diff --git a/Changelog.txt b/Changelog.txt index 8d310d6..1f4594d 100644 --- a/Changelog.txt +++ b/Changelog.txt @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.8.0] +Added `analyze_traps(...)` for randomized correlation-trap analysis +Added per-trap vector and order-invariant trap metrics +Added trap diffuseness/risk heuristic (`trap_diffuseness_score`, `trap_risk_score`, `trap_assessment`) +Added trap-analysis unit tests + ## [0.7.7] Multiple OpenAI Codex updates Changed mp_fit plot colors diff --git a/README.md b/README.md index e33155c..35491d2 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ And in the notebooks provided in the [WeightWatcher-examples github repo](https: If you have some models you would like to analyze and get feedback on, check out [WeightWatcher-Pro](https://weightwatcher-ai.com). It's currently in beta and free. -## Installation: Version 0.7.6 +## Installation: Version 0.8.0 ```sh pip install weightwatcher @@ -56,7 +56,7 @@ pip install weightwatcher if this fails try -### Current TestPyPI Version 0.7.5.5 +### Current TestPyPI Version 0.8.0 ```sh python3 -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple weightwatcher @@ -113,6 +113,12 @@ watcher.get_ESD() watcher.distances(model_1, model_2) ``` +New in v0.8.0: trap-level randomized diagnostics: + +```python +trap_df = watcher.analyze_traps(layers=[3, 5], plot=True, savefig="trap_images") +``` + ## PEFT / LORA models (experimental) To analyze an PEFT / LORA fine-tuned model, specify the peft option. 
@@ -274,6 +280,13 @@ This is good way to visualize the correlations in the true ESD, and detect signa details = watcher.analyze(randomize=True, plot=True) ``` +Trap analysis example: + +```python +watcher = ww.WeightWatcher(model=my_model) +trap_df = watcher.analyze_traps(layers=[3, 5], plot=True, savefig="trap_images") +``` + Fig (a) is well trained; Fig (b) may be over-fit. That orange spike on the far right is the tell-tale clue; it's caled a **Correlation Trap**. diff --git a/tests/test_analyze_traps.py b/tests/test_analyze_traps.py new file mode 100644 index 0000000..58d33f1 --- /dev/null +++ b/tests/test_analyze_traps.py @@ -0,0 +1,142 @@ +import unittest +import numpy as np +import pandas as pd +try: + import torch + import torch.nn as nn + TORCH_AVAILABLE = True +except Exception: + TORCH_AVAILABLE = False + +import weightwatcher as ww + + +if TORCH_AVAILABLE: + class TinyTrapNet(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(16, 12, bias=False) + self.fc2 = nn.Linear(12, 10, bias=False) + with torch.no_grad(): + u = torch.linspace(1.0, 2.0, steps=12) + v = torch.linspace(-2.0, 1.0, steps=16) + self.fc1.weight.copy_(35.0 * torch.outer(u, v)) + + u2 = torch.linspace(1.0, 1.5, steps=10) + v2 = torch.linspace(-1.0, 2.0, steps=12) + self.fc2.weight.copy_(20.0 * torch.outer(u2, v2)) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + +@unittest.skipUnless(TORCH_AVAILABLE, "torch is required for analyze_traps tests") +class TestAnalyzeTraps(unittest.TestCase): + + def setUp(self): + self.model = TinyTrapNet() + self.watcher = ww.WeightWatcher(model=self.model) + + def test_analyze_traps_method_exists(self): + self.assertTrue(hasattr(self.watcher, "analyze_traps")) + + def test_analyze_traps_returns_dataframe(self): + np.random.seed(123) + df = self.watcher.analyze_traps(plot=False, savefig=False) + self.assertIsInstance(df, pd.DataFrame) + + def test_analyze_traps_columns(self): + np.random.seed(123) + df = 
self.watcher.analyze_traps(plot=False, savefig=False) + expected_cols = { + "layer_id", "name", "trap_index", "perm_mode_index", + "sigma_perm", "mp_bulk_max", "left_top_mass", "right_top_mass" + } + self.assertTrue(expected_cols.issubset(set(df.columns))) + + def test_analyze_traps_no_powerlaw_columns_required(self): + np.random.seed(123) + df = self.watcher.analyze_traps(plot=False, savefig=False) + self.assertNotIn("alpha", df.columns) + self.assertNotIn("xmin", df.columns) + self.assertNotIn("xmax", df.columns) + + def test_analyze_traps_reproducible_when_seed_fixed(self): + np.random.seed(999) + df1 = self.watcher.analyze_traps(plot=False, savefig=False) + np.random.seed(999) + df2 = self.watcher.analyze_traps(plot=False, savefig=False) + + self.assertEqual(len(df1), len(df2)) + self.assertListEqual(df1["layer_id"].tolist(), df2["layer_id"].tolist()) + self.assertListEqual(df1["perm_mode_index"].tolist(), df2["perm_mode_index"].tolist()) + + def test_analyze_traps_reproducible_with_seed_arg(self): + df1 = self.watcher.analyze_traps(plot=False, savefig=False, seed=1337) + df2 = self.watcher.analyze_traps(plot=False, savefig=False, seed=1337) + + self.assertEqual(len(df1), len(df2)) + self.assertListEqual(df1["layer_id"].tolist(), df2["layer_id"].tolist()) + self.assertListEqual(df1["perm_mode_index"].tolist(), df2["perm_mode_index"].tolist()) + + def test_analyze_traps_respects_layer_filter(self): + np.random.seed(123) + all_df = self.watcher.analyze_traps(plot=False, savefig=False) + if len(all_df) == 0: + self.skipTest("No traps detected in this environment") + + layer_id = int(all_df["layer_id"].iloc[0]) + np.random.seed(123) + layer_df = self.watcher.analyze_traps(layers=[layer_id], plot=False, savefig=False) + self.assertTrue(set(layer_df["layer_id"].unique()).issubset({layer_id})) + + def test_analyze_traps_skips_ambiguous_multi_Wmat_layers_safely(self): + conv_model = nn.Conv2d(3, 8, kernel_size=3, bias=False) + watcher = 
ww.WeightWatcher(model=conv_model) + + np.random.seed(123) + df = watcher.analyze_traps(plot=False, savefig=False, pool=True) + self.assertIsInstance(df, pd.DataFrame) + + def test_analyze_traps_contains_vector_metric_columns(self): + np.random.seed(123) + df = self.watcher.analyze_traps(plot=False, savefig=False) + required = { + "u_entropy", "u_discrete_entropy", "u_localization_ratio", "u_participation_ratio", + "v_entropy", "v_discrete_entropy", "v_localization_ratio", "v_participation_ratio" + } + self.assertTrue(required.issubset(set(df.columns))) + + def test_analyze_traps_contains_order_invariant_stat_columns(self): + np.random.seed(123) + df = self.watcher.analyze_traps(plot=False, savefig=False) + required = { + "u_l2_fourth_moment", "u_effective_support", "u_gini_abs", "u_top10_mass", + "u_squared_amp_entropy", "u_stable_rank_surrogate", + "v_l2_fourth_moment", "v_effective_support", "v_gini_abs", "v_top10_mass", + "v_squared_amp_entropy", "v_stable_rank_surrogate", "trap_balance_ratio", + "trap_diffuseness_score", "trap_risk_score", "trap_assessment" + } + self.assertTrue(required.issubset(set(df.columns))) + + def test_order_invariant_stats_are_finite(self): + np.random.seed(123) + df = self.watcher.analyze_traps(plot=False, savefig=False) + if len(df) == 0: + self.skipTest("No traps detected in this environment") + + row = df.iloc[0] + for col in [ + "u_l2_fourth_moment", "u_l2_sixth_moment", "u_effective_support", "u_gini_abs", + "u_top1_mass", "u_top5_mass", "u_top10_mass", "u_squared_amp_entropy", "u_stable_rank_surrogate", + "v_l2_fourth_moment", "v_l2_sixth_moment", "v_effective_support", "v_gini_abs", + "v_top1_mass", "v_top5_mass", "v_top10_mass", "v_squared_amp_entropy", "v_stable_rank_surrogate", + "trap_balance_ratio", + ]: + self.assertTrue(np.isfinite(row[col])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_trap_diffuseness.py b/tests/test_trap_diffuseness.py new file mode 100644 index 0000000..b9b5054 --- 
/dev/null +++ b/tests/test_trap_diffuseness.py @@ -0,0 +1,76 @@ +import unittest + +import weightwatcher as ww + + +class TestTrapDiffuseness(unittest.TestCase): + + def setUp(self): + self.watcher = ww.WeightWatcher() + + def test_assess_trap_diffuseness_returns_expected_keys(self): + trap = { + "u_length": 20, + "v_length": 20, + "u_effective_support": 10, + "v_effective_support": 11, + "u_squared_amp_entropy": 2.0, + "v_squared_amp_entropy": 2.1, + "u_top1_mass": 0.10, + "v_top1_mass": 0.11, + "u_gini_abs": 0.20, + "v_gini_abs": 0.25, + "left_top_mass": 0.12, + "right_top_mass": 0.13, + "trap_eval_minus_bulk": 0.5, + "mp_bulk_max": 1.0, + } + out = self.watcher.assess_trap_diffuseness(trap) + self.assertIn("trap_diffuseness_score", out) + self.assertIn("trap_risk_score", out) + self.assertIn("trap_assessment", out) + + def test_assess_trap_diffuseness_localized_vs_diffuse(self): + localized = { + "u_length": 20, + "v_length": 20, + "u_effective_support": 1.5, + "v_effective_support": 1.8, + "u_squared_amp_entropy": 0.2, + "v_squared_amp_entropy": 0.3, + "u_top1_mass": 0.85, + "v_top1_mass": 0.80, + "u_gini_abs": 0.92, + "v_gini_abs": 0.90, + "left_top_mass": 0.88, + "right_top_mass": 0.86, + "trap_eval_minus_bulk": 4.0, + "mp_bulk_max": 1.0, + } + + diffuse = { + "u_length": 20, + "v_length": 20, + "u_effective_support": 15.0, + "v_effective_support": 16.0, + "u_squared_amp_entropy": 2.7, + "v_squared_amp_entropy": 2.8, + "u_top1_mass": 0.10, + "v_top1_mass": 0.10, + "u_gini_abs": 0.20, + "v_gini_abs": 0.22, + "left_top_mass": 0.15, + "right_top_mass": 0.14, + "trap_eval_minus_bulk": 0.2, + "mp_bulk_max": 1.0, + } + + loc_out = self.watcher.assess_trap_diffuseness(localized) + dif_out = self.watcher.assess_trap_diffuseness(diffuse) + + self.assertLess(loc_out["trap_diffuseness_score"], dif_out["trap_diffuseness_score"]) + self.assertGreaterEqual(loc_out["trap_risk_score"], dif_out["trap_risk_score"]) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/weightwatcher/RMT_Util.py b/weightwatcher/RMT_Util.py index 1fc68cc..0c2a7cd 100644 --- a/weightwatcher/RMT_Util.py +++ b/weightwatcher/RMT_Util.py @@ -893,12 +893,15 @@ def plot_loghist(x, bins=100, xmin=None): plt.xscale('log') -def permute_matrix(W): +def permute_matrix(W, rng=None): """permute a matrix in a reversible way""" num_params = np.prod(W.shape) vec = W.reshape(num_params) - p_ids = np.random.permutation(np.arange(num_params)) + if rng is None: + p_ids = np.random.permutation(np.arange(num_params)) + else: + p_ids = rng.permutation(np.arange(num_params)) p_vec = vec[p_ids] p_W = p_vec.reshape(W.shape) @@ -1175,4 +1178,3 @@ def combine_weights_and_biases(W,b): Wb = np.vstack([W.T,b]).T return Wb - diff --git a/weightwatcher/__init__.py b/weightwatcher/__init__.py index f9293ef..d1a1b95 100644 --- a/weightwatcher/__init__.py +++ b/weightwatcher/__init__.py @@ -18,7 +18,7 @@ __name__ = "weightwatcher" -__version__ = "0.7.7" +__version__ = "0.8.0" __license__ = "Apache License, Version 2.0" __description__ = "Diagnostic Tool for Deep Neural Networks" __url__ = "https://calculationconsulting.com/" @@ -30,4 +30,3 @@ "__url__", "__author__", "__email__", "__copyright__"] - diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index c5be1a1..d592d24 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -3035,9 +3035,10 @@ def apply_permute_W(self, ww_layer, params=None): logger.debug("apply permute W on Layer {} {} ".format(layer_id, name)) logger.debug("params {} ".format(params)) + rng = params.get("rng", None) Wmats, permute_ids = [], [] for W in ww_layer.Wmats: - W, p_ids = permute_matrix(W) + W, p_ids = permute_matrix(W, rng=rng) Wmats.append(W) permute_ids.append(p_ids) @@ -3682,6 +3683,527 @@ def analyze(self, model=None, layers=[], def get_details(self): """get the current details, created by analyze""" return self.details + + + def analyze_traps(self, model=None, layers=[], + 
min_evals=DEFAULT_MIN_EVALS, max_evals=DEFAULT_MAX_EVALS, + min_size=None, max_size=None, max_N=DEFAULT_MAX_N, + glorot_fix=False, + plot=False, savefig=DEF_SAVE_DIR, + conv2d_norm=True, + ww2x=DEFAULT_WW2X, pool=DEFAULT_POOL, + conv2d_fft=False, fft=False, channels=None, + svd_method=FAST_SVD, + start_ids=DEFAULT_START_ID, + base_model=None, + peft=DEFAULT_PEFT, + seed=None): + """Analyze randomized correlation traps and return one row per trap. + + This method follows the randomized/permuted trap workflow: + (1) normalize each selected layer matrix as in ``analyze()``, + (2) compute the original ESD for the true layer matrix, + (3) randomly permute matrix entries in a reversible way, fit MP bulk on the permuted ESD, + (4) identify trap modes beyond the MP bulk edge, + (5) isolate each trap as a rank-1 matrix in permuted space, + (6) unpermute trap-only matrices back to original space and analyze trap vectors. + + Returns a pandas DataFrame containing one row per detected trap. + This routine does not run any power-law fitting. + + Parameters + ---------- + seed : None or int + Optional seed used for reversible trap permutations. + Passing the same seed makes trap detection reproducible across runs. 
+ """ + + self.set_model_(model, base_model) + + if min_size or max_size: + logger.warning("min_size and max_size options changed to min_evals, max_evals, ignored for now") + + if ww2x: + logger.warning("WW2X option deprecated, reverting too POOL=False") + ww2x = False + pool = False + + params = DEFAULT_PARAMS.copy() + params[MIN_EVALS] = min_evals + params[MAX_EVALS] = max_evals + params[MAX_N] = max_N + + params[PLOT] = plot + params[RANDOMIZE] = True + params[MP_FIT] = True + params[GLOROT_FIT] = glorot_fix + params[CONV2D_NORM] = conv2d_norm + + params[POOL] = pool + params[WW2X] = ww2x + params[CONV2D_FFT] = conv2d_fft + params[FFT] = fft + + params[CHANNELS_STR] = channels + params[LAYERS] = layers + params[STACKED] = False + + params[DETX] = False + params[SVD_METHOD] = svd_method + params[TOLERANCE] = WEAK_RANK_LOSS_TOLERANCE + params[START_IDS] = start_ids + + params[SAVEFIG] = savefig + params[PEFT] = peft + params[INVERSE] = False + if seed is not None and not isinstance(seed, numbers.Integral): + raise ValueError("seed must be None or an integer") + rng = np.random.RandomState(int(seed)) if seed is not None else None + params["rng"] = rng + + logger.debug("params {}".format(params)) + if not WeightWatcher.valid_params(params): + msg = "Error, params not valid: \n {}".format(params) + logger.error(msg) + raise Exception(msg) + params = self.normalize_params(params) + + layer_iterator = self.make_layer_iterator(model=self.model, layers=layers, params=params, base_model=self.base_model) + trap_rows = [] + + for ww_layer in layer_iterator: + if not ww_layer.skipped and ww_layer.has_weights: + self.apply_normalize_Wmats(ww_layer, params) + + if params[FFT]: + self.apply_FFT(ww_layer, params) + + layer_rows = self.apply_analyze_traps(ww_layer, params=params) + if layer_rows: + trap_rows.extend(layer_rows) + + if len(trap_rows) > 0: + details = pd.DataFrame.from_records(trap_rows) + else: + details = pd.DataFrame(columns=self._trap_result_columns()) + + 
trap_cols = self._trap_result_columns() + details = details.reindex(columns=trap_cols + [c for c in details.columns if c not in trap_cols]) + + if len(details) > 0: + lead_cols = ["layer_id", "name"] + details = details[lead_cols + [c for c in details.columns if c not in lead_cols]] + + self.details = details + return details + + def _trap_result_columns(self): + return [ + "layer_id", "name", "longname", "layer_type", "N", "M", "rf", "Q", + "trap_index", "perm_mode_index", "sigma_perm", "eval_perm", + "mp_bulk_max", "mp_bulk_min", "sigma_mp", "num_spikes", + "rank1_mass_after_unpermute", "sigma_trap_top", + "left_top_mode", "right_top_mode", "left_top_mass", "right_top_mass", + "left_overlap_entropy", "right_overlap_entropy", "left_overlap_ipr", "right_overlap_ipr", + "u_length", "u_entropy", "u_discrete_entropy", "u_localization_ratio", "u_participation_ratio", + "v_length", "v_entropy", "v_discrete_entropy", "v_localization_ratio", "v_participation_ratio", + "u_l2_fourth_moment", "u_l2_sixth_moment", "u_effective_support", "u_gini_abs", + "u_top1_mass", "u_top5_mass", "u_top10_mass", "u_squared_amp_entropy", "u_stable_rank_surrogate", + "v_l2_fourth_moment", "v_l2_sixth_moment", "v_effective_support", "v_gini_abs", + "v_top1_mass", "v_top5_mass", "v_top10_mass", "v_squared_amp_entropy", "v_stable_rank_surrogate", + "trap_balance_ratio", "trap_detected", "trap_eval_minus_bulk", + "trap_diffuseness_score", "trap_risk_score", "trap_assessment", + ] + + + def apply_analyze_traps(self, ww_layer, params=None): + if params is None: params = DEFAULT_PARAMS.copy() + + if len(ww_layer.Wmats) != 1: + logger.warning("Skipping trap analysis for layer %s %s: expected exactly one Wmat, found %d", + ww_layer.layer_id, ww_layer.name, len(ww_layer.Wmats)) + return [] + + self.apply_esd(ww_layer, params) + original_basis_cache = self.compute_original_basis_for_traps(ww_layer, params=params) + + self.apply_permute_W(ww_layer, params) + self.apply_trap_mp_fit(ww_layer, 
params=params) + trap_mode_indices = self.identify_trap_mode_indices(ww_layer, params=params) + + trap_rows = [] + for trap_index, mode_index in enumerate(trap_mode_indices): + trap_row = self.analyze_single_trap( + ww_layer, + trap_mode_index=mode_index, + original_basis_cache=original_basis_cache, + params=params, + trap_index=trap_index, + ) + trap_rows.append(trap_row) + + self.apply_unpermute_W(ww_layer, params) + return trap_rows + + + def apply_trap_mp_fit(self, ww_layer, params=None): + if params is None: params = DEFAULT_PARAMS.copy() + self.apply_esd(ww_layer, params) + self.apply_mp_fit(ww_layer, random=False, params=params) + return ww_layer + + + def identify_trap_mode_indices(self, ww_layer, params=None): + if params is None: params = DEFAULT_PARAMS.copy() + + evals = np.array(ww_layer.evals) + if evals is None or len(evals) == 0: + return [] + + Q = ww_layer.N / ww_layer.M if ww_layer.M > 0 else 1.0 + Wscale = ww_layer.W_scale if ww_layer.has_column('W_scale') else 1.0 + sigma_mp = ww_layer.sigma_mp if ww_layer.has_column('sigma_mp') else None + + if sigma_mp is None or sigma_mp <= 0: + threshold = ww_layer.bulk_max if ww_layer.has_column('bulk_max') else np.inf + else: + bulk_max_scaled = (sigma_mp * (1 + 1 / np.sqrt(Q))) ** 2 + TW = 1 / np.sqrt(Q) * np.power(bulk_max_scaled, 2 / 3) * np.power(ww_layer.M, -2 / 3) + bulk_max_tw_scaled = bulk_max_scaled + np.sqrt(TW) + threshold = bulk_max_tw_scaled / (Wscale * Wscale) + + evals_desc = evals[::-1] + trap_ids = [i for i, e in enumerate(evals_desc) if e > threshold] + return trap_ids + + + def compute_original_basis_for_traps(self, ww_layer, params=None): + if params is None: params = DEFAULT_PARAMS.copy() + if len(ww_layer.Wmats) != 1: + return None + + W_true = ww_layer.Wmats[0].astype(float) + U0, S0, V0h = svd_full(W_true, method=params[SVD_METHOD]) + return { + "W_true": W_true, + "U0": U0, + "S0": S0, + "V0": V0h.T, + } + + + def analyze_single_trap(self, ww_layer, trap_mode_index, 
original_basis_cache=None, params=None, trap_index=0): + if params is None: params = DEFAULT_PARAMS.copy() + if original_basis_cache is None: + original_basis_cache = self.compute_original_basis_for_traps(ww_layer, params=params) + + W_perm = ww_layer.Wmats[0].astype(float) + p_ids = ww_layer.permute_ids[0] + + U_perm, S_perm, Vh_perm = svd_full(W_perm, method=params[SVD_METHOD]) + V_perm = Vh_perm.T + + sigma_perm = float(S_perm[trap_mode_index]) + u_perm = U_perm[:, trap_mode_index] + v_perm = V_perm[:, trap_mode_index] + + T_perm = sigma_perm * np.outer(u_perm, v_perm) + T_orig = unpermute_matrix(T_perm, p_ids) + + Ut, St, Vht = svd_full(T_orig, method=params[SVD_METHOD]) + u_trap = Ut[:, 0] + v_trap = Vht.T[:, 0] + + U0 = original_basis_cache["U0"] + V0 = original_basis_cache["V0"] + + left_overlaps = np.abs(U0.T @ u_trap) ** 2 + right_overlaps = np.abs(V0.T @ v_trap) ** 2 + + left_top_mode = int(np.argmax(left_overlaps)) + right_top_mode = int(np.argmax(right_overlaps)) + left_top_mass = float(np.max(left_overlaps)) + right_top_mass = float(np.max(right_overlaps)) + + eps = 1e-12 + left_overlap_entropy = float(-np.sum((left_overlaps + eps) * np.log(left_overlaps + eps))) + right_overlap_entropy = float(-np.sum((right_overlaps + eps) * np.log(right_overlaps + eps))) + left_overlap_ipr = float(np.sum(left_overlaps ** 2)) + right_overlap_ipr = float(np.sum(right_overlaps ** 2)) + + st_sq = St * St + rank1_mass_after_unpermute = float(st_sq[0] / (np.sum(st_sq) + eps)) + + u_metrics = self._trap_vector_metrics(u_trap) + v_metrics = self._trap_vector_metrics(v_trap) + u_oi = self._trap_vector_order_invariant_stats(u_trap) + v_oi = self._trap_vector_order_invariant_stats(v_trap) + + eval_perm = sigma_perm ** 2 + trap_result = { + "layer_id": ww_layer.layer_id, + "name": ww_layer.name, + "longname": ww_layer.longname, + "layer_type": str(ww_layer.the_type), + "N": ww_layer.N, + "M": ww_layer.M, + "rf": ww_layer.rf, + "Q": ww_layer.N / ww_layer.M if ww_layer.M > 0 else 
np.nan, + "trap_index": int(trap_index), + "perm_mode_index": int(trap_mode_index), + "sigma_perm": sigma_perm, + "eval_perm": float(eval_perm), + "mp_bulk_max": float(ww_layer.bulk_max), + "mp_bulk_min": float(ww_layer.bulk_min), + "sigma_mp": float(ww_layer.sigma_mp), + "num_spikes": int(ww_layer.num_spikes), + "rank1_mass_after_unpermute": rank1_mass_after_unpermute, + "sigma_trap_top": float(St[0]), + "left_top_mode": left_top_mode, + "right_top_mode": right_top_mode, + "left_top_mass": left_top_mass, + "right_top_mass": right_top_mass, + "left_overlap_entropy": left_overlap_entropy, + "right_overlap_entropy": right_overlap_entropy, + "left_overlap_ipr": left_overlap_ipr, + "right_overlap_ipr": right_overlap_ipr, + "trap_detected": True, + "trap_eval_minus_bulk": float(eval_perm - ww_layer.bulk_max), + } + + for k, v in u_metrics.items(): + trap_result[f"u_{k}"] = v + for k, v in v_metrics.items(): + trap_result[f"v_{k}"] = v + for k, v in u_oi.items(): + trap_result[f"u_{k}"] = v + for k, v in v_oi.items(): + trap_result[f"v_{k}"] = v + + trap_result["trap_balance_ratio"] = float( + trap_result["u_effective_support"] / (trap_result["v_effective_support"] + 1e-12) + ) + trap_result.update(self.assess_trap_diffuseness(trap_result)) + + trap_result["left_overlaps"] = left_overlaps + trap_result["right_overlaps"] = right_overlaps + trap_result["u_trap"] = u_trap + trap_result["v_trap"] = v_trap + trap_result["T_orig"] = T_orig + trap_result["perm_evals_sorted"] = np.array(ww_layer.evals).copy() + + if params[PLOT]: + self.plot_trap_analysis(ww_layer, trap_result, params=params) + + trap_result.pop("left_overlaps", None) + trap_result.pop("right_overlaps", None) + trap_result.pop("u_trap", None) + trap_result.pop("v_trap", None) + trap_result.pop("T_orig", None) + trap_result.pop("perm_evals_sorted", None) + + return trap_result + + + def assess_trap_diffuseness(self, trap_result): + """Heuristic classifier for trap severity in original weight space. 
+ + Localized traps are treated as higher risk, while diffuse traps are treated as + potentially benign overfitting. This is intentionally a separate function so it + can be unit-tested and adjusted independently. + """ + eps = 1e-12 + + u_len = max(float(trap_result.get("u_length", 0.0)), 1.0) + v_len = max(float(trap_result.get("v_length", 0.0)), 1.0) + + u_eff = float(trap_result.get("u_effective_support", 0.0)) / u_len + v_eff = float(trap_result.get("v_effective_support", 0.0)) / v_len + + u_ent = float(trap_result.get("u_squared_amp_entropy", 0.0)) / np.log(u_len + eps) + v_ent = float(trap_result.get("v_squared_amp_entropy", 0.0)) / np.log(v_len + eps) + + top1_local = 0.5 * (float(trap_result.get("u_top1_mass", 1.0)) + float(trap_result.get("v_top1_mass", 1.0))) + gini_local = 0.5 * (float(trap_result.get("u_gini_abs", 1.0)) + float(trap_result.get("v_gini_abs", 1.0))) + overlap_local = 0.5 * (float(trap_result.get("left_top_mass", 1.0)) + float(trap_result.get("right_top_mass", 1.0))) + + diffuseness_score = float(np.clip( + 0.30 * u_eff + + 0.30 * v_eff + + 0.20 * u_ent + + 0.20 * v_ent - + 0.30 * top1_local - + 0.20 * gini_local - + 0.20 * overlap_local, + 0.0, 1.0 + )) + + trap_strength = float(trap_result.get("trap_eval_minus_bulk", 0.0)) + bulk = abs(float(trap_result.get("mp_bulk_max", 0.0))) + eps + normalized_strength = trap_strength / bulk + risk_score = float(np.clip((1.0 - diffuseness_score) * max(normalized_strength, 0.0), 0.0, 1.0)) + + if diffuseness_score >= 0.55 and risk_score < 0.30: + assessment = "benign_diffuse" + elif diffuseness_score <= 0.35 or risk_score >= 0.50: + assessment = "localized_risky" + else: + assessment = "mixed" + + return { + "trap_diffuseness_score": diffuseness_score, + "trap_risk_score": risk_score, + "trap_assessment": assessment, + } + + + def plot_trap_analysis(self, ww_layer, trap_result, params=None): + if params is None: params = DEFAULT_PARAMS.copy() + + savefig = params[SAVEFIG] + savedir = params[SAVEDIR] 
+ + trap_idx = trap_result["trap_index"] + mode_idx = trap_result["perm_mode_index"] + plot_id = f"{ww_layer.plot_id}.trap{trap_idx}" + + evals = trap_result["perm_evals_sorted"] + x = np.arange(len(evals)) + plt.plot(x, evals, marker='.', linestyle='None', alpha=0.7) + plt.axhline(trap_result["mp_bulk_max"], color='red', linestyle='--', label='bulk_max') + plt.axhline(trap_result["mp_bulk_min"], color='orange', linestyle=':', label='bulk_min') + plt.title(f"Permuted spectrum L{ww_layer.layer_id} {ww_layer.name} trap {mode_idx}") + plt.legend() + if savefig: + save_fig(plt, "trap.mpfit", plot_id, savedir) + plt.show(); plt.clf() + + T_orig = trap_result["T_orig"] + vmax = np.quantile(np.abs(T_orig), 0.995) + plt.imshow(T_orig, cmap='coolwarm', vmin=-vmax, vmax=vmax, aspect='auto') + plt.colorbar() + plt.title(f"Unpermuted trap matrix L{ww_layer.layer_id} {ww_layer.name} trap {mode_idx}") + if savefig: + save_fig(plt, "trap.heatmap", plot_id, savedir) + plt.show(); plt.clf() + + left_overlaps = trap_result["left_overlaps"] + right_overlaps = trap_result["right_overlaps"] + + plt.plot(np.arange(len(left_overlaps)), left_overlaps, marker='.') + plt.title(f"Left overlaps L{ww_layer.layer_id} trap {mode_idx}") + if savefig: + save_fig(plt, "trap.left_overlap", plot_id, savedir) + plt.show(); plt.clf() + + plt.plot(np.arange(len(right_overlaps)), right_overlaps, marker='.') + plt.title(f"Right overlaps L{ww_layer.layer_id} trap {mode_idx}") + if savefig: + save_fig(plt, "trap.right_overlap", plot_id, savedir) + plt.show(); plt.clf() + + u_trap = trap_result["u_trap"] + v_trap = trap_result["v_trap"] + + plt.plot(np.arange(len(u_trap)), u_trap) + plt.title(f"Left trap vector L{ww_layer.layer_id} trap {mode_idx}") + if savefig: + save_fig(plt, "trap.left_vec", plot_id, savedir) + plt.show(); plt.clf() + + plt.hist(np.abs(u_trap), bins=50, alpha=0.8) + plt.title(f"Left trap |coeff| histogram L{ww_layer.layer_id} trap {mode_idx}") + plt.xlabel("|coefficient|") + 
plt.ylabel("count") + if savefig: + save_fig(plt, "trap.left_vec_hist", plot_id, savedir) + plt.show(); plt.clf() + + plt.plot(np.arange(len(v_trap)), v_trap) + plt.title(f"Right trap vector L{ww_layer.layer_id} trap {mode_idx}") + if savefig: + save_fig(plt, "trap.right_vec", plot_id, savedir) + plt.show(); plt.clf() + + plt.hist(np.abs(v_trap), bins=50, alpha=0.8) + plt.title(f"Right trap |coeff| histogram L{ww_layer.layer_id} trap {mode_idx}") + plt.xlabel("|coefficient|") + plt.ylabel("count") + if savefig: + save_fig(plt, "trap.right_vec_hist", plot_id, savedir) + plt.show(); plt.clf() + + + def _trap_vector_metrics(self, vec): + return { + "length": float(len(vec)), + "entropy": float(vector_entropy(vec)), + "discrete_entropy": float(discrete_entropy(vec)), + "localization_ratio": float(localization_ratio(vec)), + "participation_ratio": float(participation_ratio(vec)), + } + + + def _gini_abs(self, vec): + p = np.abs(np.asarray(vec).flatten()) + if len(p) == 0: + return 0.0 + s = np.sum(p) + if s <= 0: + return 0.0 + p = np.sort(p) + n = len(p) + idx = np.arange(1, n + 1) + return float((2 * np.sum(idx * p)) / (n * s) - (n + 1) / n) + + + def _topk_mass_fractions(self, vec, ks=(1, 5, 10)): + p = np.sort(np.abs(np.asarray(vec).flatten()))[::-1] + s = np.sum(p) + out = {} + for k in ks: + kk = min(k, len(p)) + if s <= 0 or kk == 0: + out[f"top{k}_mass"] = 0.0 + else: + out[f"top{k}_mass"] = float(np.sum(p[:kk]) / s) + return out + + + def _trap_vector_order_invariant_stats(self, vec): + x = np.asarray(vec).flatten() + abs_x = np.abs(x) + l2 = np.linalg.norm(abs_x) + if l2 <= 0: + l2_norm = abs_x + else: + l2_norm = abs_x / l2 + + l2_fourth = float(np.sum(l2_norm ** 4)) + l2_sixth = float(np.sum(l2_norm ** 6)) + + sq = x * x + sq_sum = np.sum(sq) + if sq_sum <= 0: + q = np.zeros_like(x) + else: + q = sq / sq_sum + + effective_support = float(1.0 / (np.sum(q ** 2) + 1e-12)) + squared_amp_entropy = float(-np.sum(q * np.log(q + 1e-12))) + + stable_rank_surrogate = 
float((np.sum(abs_x ** 2) ** 2) / (np.sum(abs_x ** 4) + 1e-12)) + + out = { + "l2_fourth_moment": l2_fourth, + "l2_sixth_moment": l2_sixth, + "effective_support": effective_support, + "gini_abs": float(self._gini_abs(vec)), + "squared_amp_entropy": squared_amp_entropy, + "stable_rank_surrogate": stable_rank_surrogate, + } + out.update(self._topk_mass_fractions(vec, ks=(1, 5, 10))) + return out def get_summary(self, details=None): """Return metric averages, as dict, if available """