Add Direct Logit Attribution tool for TransformerBridge#1316
Conversation
jlarson4
left a comment
There was a problem hiding this comment.
Hi @TravisHaa! I have reviewed your PR and left a few comments. I cannot run the code in its current state, until some of these comments are addressed. Let me know if you have any questions.
Also, the PR is marked as "tests added", but I am not seeing any tests in the diff. Could you add tests/unit/tools/test_direct_logit_attribution.py covering at least:
- Correct path on a small bridge (e.g., gpt2 with compatibility mode)
ValueErrorwhen compatibility mode is offNotImplementedErrorwhen a Mamba-like adapter is present- Both
accumulated=Trueandaccumulated=False
Feel free to tag me once these edits are in and I will re-review. Thank you for your work on this, it is coming along nicely!
|
|
||
| # Variant submodule names that indicate a hybrid block. Mamba, SSM, Mixer, and LinearAttention layers don't have the usual attn_out / mlp_out decomposition that ActivationCache.decompose_resid expects. | ||
| _HYBRID_VARIANT_NAMES = ("mamba", "ssm", "mixer", "linear_attn") | ||
| dont |
|
|
||
| from transformer_lens.tools.analysis.direct_logit_attribution import DLA | ||
|
|
||
| __all__ = ["DLA"]c No newline at end of file |
| ``1`` or ``2``. | ||
| """ | ||
|
|
||
| assert len(prompts) == answer_tokens.shape[0] |
There was a problem hiding this comment.
assert is stripped under python -O, so user-facing argument validation should use if not ...: raise ValueError(...) instead. The function already raises ValueError for the compatibility-mode check elsewhere, please follow that same pattern here and on the following line
| _HYBRID_VARIANT_NAMES = ("mamba", "ssm", "mixer", "linear_attn") | ||
| dont | ||
|
|
||
| def DLA( |
There was a problem hiding this comment.
Please rename the function to either dla or direct_logit_attribution. All-caps is reserved for constants.
| correct_token_direction - incorrect_token_direction | ||
| ) | ||
|
|
||
| #turns residual stream contributions into logit-difference scores accounting for layerNorm |
| f"ActivationCache.decompose_resid — tracked separately." | ||
| ) | ||
|
|
||
| #grab residiual directions from bridge (essentially unembedding matrix transposed) |
| prompts, | ||
| return_type=None, | ||
| names_filter=lambda x: x == get_act_name("ln_final.hook_scale") | ||
| or x.endswith("embed") |
There was a problem hiding this comment.
This endswith("embed") can match with hooks we are not trying to target. Consider tightening this to something along the lines of or x in ("hook_embed", "hook_pos_embed")
Description
TransformerBridgesystem, closes [Proposal] Direct Logit Attribution Tool #1263. Based on the stale PR (Draft) Add DLA function to utils #466 but adapted to utilize 3.0 TransformerBridge.Acknowledged limitations
NotImplementedError.ActivationCache.decompose_residonly knows how to decomposeattn_out + mlp_outper layer; supporting hybrid blocks requires extending that method and is out of scopefor this PR. (will be working on this in next steps)
ValueErrorifbridge.enable_compatibility_mode()hasn't been called. Without folded LayerNorm weights, the projection direction is wrong and
per-component scores don't reflect actual logit contributions.
actual_logit_diff − (b_U[correct] − b_U[wrong]). This matches the convention incache.decompose_residand PR (Draft) Add DLA function to utils #466 — the bias is a constant offsetType of change
Screenshots
Checklist: