Skip to content

Add Direct Logit Attribution tool for TransformerBridge#1316

Open
TravisHaa wants to merge 1 commit into
TransformerLensOrg:devfrom
TravisHaa:feat/dla-tool-1263
Open

Add Direct Logit Attribution tool for TransformerBridge#1316
TravisHaa wants to merge 1 commit into
TransformerLensOrg:devfrom
TravisHaa:feat/dla-tool-1263

Conversation

@TravisHaa
Copy link
Copy Markdown

@TravisHaa TravisHaa commented May 20, 2026

Description

  • Implemented a Direct Logit Attribution (DLA) tool for the new TransformerBridge system, closes [Proposal] Direct Logit Attribution Tool #1263. Based on the stale PR (Draft) Add DLA function to utils #466 but adapted to utilize 3.0 TransformerBridge.
  • returns per-component (or per-layer) contributions to logit difference ebtween correct and wrong token, decomposing residual stream based off accumulated bool.
  • generated docstring (according to contributing.md) highlighting important warnings of limitations of DLA tool (currently does not support hybrid layers like mamba, requires bridge compatibility mode)

Acknowledged limitations

  • Strict mode only. hybrid architectures (Mamba, SSM, Mixer, LinearAttention) raise
    NotImplementedError. ActivationCache.decompose_resid only knows how to decompose attn_out + mlp_out per layer; supporting hybrid blocks requires extending that method and is out of scope
    for this PR. (will be working on this in next steps)
  • Requires compatibility mode. Raises ValueError if bridge.enable_compatibility_mode()
    hasn't been called. Without folded LayerNorm weights, the projection direction is wrong and
    per-component scores don't reflect actual logit contributions.
  • Excludes unembedding bias. Per-component scores sum to actual_logit_diff − (b_U[correct] − b_U[wrong]). This matches the convention in cache.decompose_resid and PR (Draft) Add DLA function to utils #466 — the bias is a constant offset

Type of change

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Screenshots

image

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

Copy link
Copy Markdown
Collaborator

@jlarson4 jlarson4 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @TravisHaa! I have reviewed your PR and left a few comments. I cannot run the code in its current state, until some of these comments are addressed. Let me know if you have any questions.

Also, the PR is marked as "tests added", but I am not seeing any tests in the diff. Could you add tests/unit/tools/test_direct_logit_attribution.py covering at least:

  • Correct path on a small bridge (e.g., gpt2 with compatibility mode)
  • ValueError when compatibility mode is off
  • NotImplementedError when a Mamba-like adapter is present
  • Both accumulated=True and accumulated=False

Feel free to tag me once these edits are in and I will re-review. Thank you for your work on this, it is coming along nicely!


# Variant submodule names that indicate a hybrid block. Mamba, SSM, Mixer, and LinearAttention layers don't have the usual attn_out / mlp_out decomposition that ActivationCache.decompose_resid expects.
_HYBRID_VARIANT_NAMES = ("mamba", "ssm", "mixer", "linear_attn")
dont
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra dont here


from transformer_lens.tools.analysis.direct_logit_attribution import DLA

__all__ = ["DLA"]c No newline at end of file
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra "c" here

``1`` or ``2``.
"""

assert len(prompts) == answer_tokens.shape[0]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert is stripped under python -O, so user-facing argument validation should use if not ...: raise ValueError(...) instead. The function already raises ValueError for the compatibility-mode check elsewhere, please follow that same pattern here and on the following line

_HYBRID_VARIANT_NAMES = ("mamba", "ssm", "mixer", "linear_attn")
dont

def DLA(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename the function to either dla or direct_logit_attribution. All-caps is reserved for constants.

correct_token_direction - incorrect_token_direction
)

#turns residual stream contributions into logit-difference scores accounting for layerNorm
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix indentation here

f"ActivationCache.decompose_resid — tracked separately."
)

#grab residiual directions from bridge (essentially unembedding matrix transposed)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in resid(i)ual

prompts,
return_type=None,
names_filter=lambda x: x == get_act_name("ln_final.hook_scale")
or x.endswith("embed")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This endswith("embed") can match with hooks we are not trying to target. Consider tightening this to something along the lines of or x in ("hook_embed", "hook_pos_embed")

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants