From eb31f621efa6c47fa6b802a8a48181e1c60952f1 Mon Sep 17 00:00:00 2001 From: Sunny Joshi Date: Fri, 15 May 2026 17:55:14 +0100 Subject: [PATCH 1/6] add Gpt2 Model Bridge tests --- .../test_gpt2_adapter.py | 281 ++++++++++++++++++ 1 file changed, 281 insertions(+) create mode 100644 tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py diff --git a/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py new file mode 100644 index 000000000..6abb7e15f --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py @@ -0,0 +1,281 @@ +"""Unit tests for GPT2ArchitectureAdapter. + +Tests cover: +- Config attribute validation (all required attributes are set correctly) +- Component mapping structure (correct bridge types and HF module names) +- Weight conversion keys and count +- QKVSplitRearrangeConversion numerical correctness +- Factory registration (GPT2LMHeadModel maps to the right adapter) +""" + +import pytest +import torch + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + JointQKVAttentionBridge, + LinearBridge, + MLPBridge, + NormalizationBridge, + PosEmbedBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.gpt2 import ( + GPT2ArchitectureAdapter, + QKVSplitRearrangeConversion, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + +def _make_cfg( + n_heads: int = 4, + d_model: int = 64, + n_layers: int = 2, + d_mlp: int = 256, + d_vocab: int = 1000, + n_ctx: int = 512, +) -> TransformerBridgeConfig: + """Return a minimal TransformerBridgeConfig for GPT2 adapter tests.""" + return TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + 
n_layers=n_layers, + n_ctx=n_ctx, + n_heads=n_heads, + d_vocab=d_vocab, + d_mlp=d_mlp, + default_prepend_bos=True, + architecture="GPT2LMHeadModel", + ) + +@pytest.fixture +def cfg() -> TransformerBridgeConfig: + return _make_cfg() + +@pytest.fixture +def adapter(cfg: TransformerBridgeConfig) -> GPT2ArchitectureAdapter: + return GPT2ArchitectureAdapter(cfg) + +# --------------------------------------------------------------------------- +# Config attribute tests +# --------------------------------------------------------------------------- + +class TestGPT2AdapterConfig: + """Tests that the adapter sets required config attributes correctly.""" + + def test_normalization_type_is_ln(self, adapter: GPT2ArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "LN" + + def test_positional_embedding_type_is_standard(self, adapter: GPT2ArchitectureAdapter) -> None: + assert adapter.cfg.positional_embedding_type == "standard" + + def test_final_rms_is_false(self, adapter: GPT2ArchitectureAdapter) -> None: + assert adapter.cfg.final_rms is False + + def test_gated_mlp_is_false(self, adapter: GPT2ArchitectureAdapter) -> None: + assert adapter.cfg.gated_mlp is False + + def test_attn_only_is_false(self, adapter: GPT2ArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_split_attention_weights_is_true(self, adapter: GPT2ArchitectureAdapter) -> None: + assert adapter.cfg.split_attention_weights is True + + def test_uses_combined_qkv_is_true(self, adapter: GPT2ArchitectureAdapter) -> None: + """GPT-2 stores Q, K, V in a single combined c_attn matrix.""" + assert adapter.uses_combined_qkv is True + +# --------------------------------------------------------------------------- +# Component mapping structure tests +# --------------------------------------------------------------------------- + +class TestGPT2AdapterComponentMapping: + """Tests that component_mapping has the correct bridge types and HF module names.""" + + # -- Top-level 
keys -- + + def test_embed_is_embedding_bridge(self, adapter: GPT2ArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_embed_name(self, adapter: GPT2ArchitectureAdapter) -> None: + assert adapter.component_mapping["embed"].name == "transformer.wte" + + def test_pos_embed_is_pos_embed_bridge(self, adapter: GPT2ArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["pos_embed"], PosEmbedBridge) + + def test_pos_embed_name(self, adapter: GPT2ArchitectureAdapter) -> None: + assert adapter.component_mapping["pos_embed"].name == "transformer.wpe" + + def test_blocks_is_block_bridge(self, adapter: GPT2ArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["blocks"], BlockBridge) + + def test_blocks_name(self, adapter: GPT2ArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].name == "transformer.h" + + def test_ln_final_is_normalization_bridge(self, adapter: GPT2ArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["ln_final"], NormalizationBridge) + + def test_ln_final_name(self, adapter: GPT2ArchitectureAdapter) -> None: + assert adapter.component_mapping["ln_final"].name == "transformer.ln_f" + + def test_unembed_is_unembedding_bridge(self, adapter: GPT2ArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + def test_unembed_name(self, adapter: GPT2ArchitectureAdapter) -> None: + assert adapter.component_mapping["unembed"].name == "lm_head" + + # -- Block submodules -- + + def test_blocks_ln1_is_normalization_bridge(self, adapter: GPT2ArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["blocks"].submodules["ln1"], NormalizationBridge) + + def test_blocks_ln1_name(self, adapter: GPT2ArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].submodules["ln1"].name == "ln_1" + + def test_blocks_ln2_is_normalization_bridge(self, 
adapter: GPT2ArchitectureAdapter) -> None: + """GPT-2 has a second layer norm before the MLP (no parallel attn/MLP).""" + assert isinstance(adapter.component_mapping["blocks"].submodules["ln2"], NormalizationBridge) + + def test_blocks_ln2_name(self, adapter: GPT2ArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].submodules["ln2"].name == "ln_2" + + def test_attn_is_joint_qkv_attention_bridge(self, adapter: GPT2ArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["attn"], JointQKVAttentionBridge) + + def test_attn_name(self, adapter: GPT2ArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["attn"].name == "attn" + + def test_attn_qkv_is_linear_bridge(self, adapter: GPT2ArchitectureAdapter) -> None: + """The combined QKV projection is a single LinearBridge wrapping c_attn.""" + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn.submodules["qkv"], LinearBridge) + + def test_attn_qkv_name(self, adapter: GPT2ArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.submodules["qkv"].name == "c_attn" + + def test_attn_o_is_linear_bridge(self, adapter: GPT2ArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn.submodules["o"], LinearBridge) + + def test_attn_o_name(self, adapter: GPT2ArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.submodules["o"].name == "c_proj" + + def test_mlp_is_mlp_bridge(self, adapter: GPT2ArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["mlp"], MLPBridge) + + def test_mlp_in_name(self, adapter: GPT2ArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert mlp.submodules["in"].name == "c_fc" + + def 
test_mlp_out_name(self, adapter: GPT2ArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert mlp.submodules["out"].name == "c_proj" + +# --------------------------------------------------------------------------- +# Weight processing conversion tests +# --------------------------------------------------------------------------- + +class TestGPT2AdapterWeightConversions: + """Tests that weight_processing_conversions has exactly the expected keys.""" + + @pytest.mark.parametrize( + "key", + [ + "blocks.{i}.attn.q.weight", + "blocks.{i}.attn.k.weight", + "blocks.{i}.attn.v.weight", + "blocks.{i}.attn.q.bias", + "blocks.{i}.attn.k.bias", + "blocks.{i}.attn.v.bias", + "blocks.{i}.attn.o.weight", + "unembed.weight", + ], + ) + def test_conversion_key_present(self, adapter: GPT2ArchitectureAdapter, key: str) -> None: + assert key in adapter.weight_processing_conversions + + def test_exactly_eight_conversion_keys(self, adapter: GPT2ArchitectureAdapter) -> None: + assert len(adapter.weight_processing_conversions) == 8 + +# --------------------------------------------------------------------------- +# QKVSplitRearrangeConversion — numerical correctness tests +# --------------------------------------------------------------------------- + + +class TestQKVSplitRearrangeConversion: + """Numerical correctness of GPT-2's combined-QKV (c_attn) split.""" + + N_HEADS, D_HEAD, D_MODEL = 4, 16, 64 # D_MODEL = N_HEADS * D_HEAD + + def _make_conv( + self, qkv_index: int, n_heads: int = 4, d_head: int = 16 + ) -> QKVSplitRearrangeConversion: + """Helper: build a QKVSplitRearrangeConversion for weight tensors.""" + return QKVSplitRearrangeConversion( + qkv_index=qkv_index, + rearrange_pattern="d_model (n h) -> n d_model h", + n=n_heads, + ) + + @pytest.mark.parametrize( + "shape, expected", + [((64, 192), True), ((192, 64), True), ((64, 64), False), ((64, 128), False)], + ) + def test_combined_detection(self, shape, expected) -> None: + assert 
self._make_conv(0)._is_combined_qkv(torch.zeros(*shape)) is expected + + def test_q_k_v_extracted_from_correct_thirds(self) -> None: + """Q/K/V split from the first/second/third third of the combined weight.""" + blocks = [torch.full((self.D_MODEL, self.D_MODEL), float(v)) for v in (1, 2, 3)] + combined = torch.cat(blocks, dim=1) + for idx, const in enumerate((1.0, 2.0, 3.0)): + out = self._make_conv(idx).handle_conversion(combined) + assert out.shape == (self.N_HEADS, self.D_MODEL, self.D_HEAD) + assert torch.all(out == const) + + def test_already_split_weight_roundtrips(self) -> None: + """handle_conversion -> revert recovers an already-split nn.Linear weight.""" + torch.manual_seed(2) + conv = self._make_conv(0) + original = torch.randn(self.N_HEADS * self.D_HEAD, self.D_MODEL) + recovered = conv.revert(conv.handle_conversion(original)) + assert recovered.shape == original.shape + assert torch.allclose(original, recovered) + +# --------------------------------------------------------------------------- +# Factory registration tests +# --------------------------------------------------------------------------- + +class TestGPT2FactoryRegistration: + """Tests that the factory maps GPT2LMHeadModel to the correct adapter.""" + + def test_factory_returns_gpt2_adapter(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, + ) + + cfg = _make_cfg() + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, GPT2ArchitectureAdapter), ( + f"Expected GPT2ArchitectureAdapter, got {type(adapter).__name__}" + ) + + def test_factory_key_is_registered(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) + + assert "GPT2LMHeadModel" in SUPPORTED_ARCHITECTURES, ( + "GPT2LMHeadModel must be registered in SUPPORTED_ARCHITECTURES" + ) + + From b0779d64534efc25f7d5bd9145c9c1dd02e35dc9 Mon Sep 17 00:00:00 2001 From: 
Sunny Joshi Date: Fri, 15 May 2026 18:16:54 +0100 Subject: [PATCH 2/6] removing unused params --- .../model_bridge/supported_architectures/test_gpt2_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py index 6abb7e15f..c71e37587 100644 --- a/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py @@ -217,7 +217,7 @@ class TestQKVSplitRearrangeConversion: N_HEADS, D_HEAD, D_MODEL = 4, 16, 64 # D_MODEL = N_HEADS * D_HEAD def _make_conv( - self, qkv_index: int, n_heads: int = 4, d_head: int = 16 + self, qkv_index: int, n_heads: int = 4 ) -> QKVSplitRearrangeConversion: """Helper: build a QKVSplitRearrangeConversion for weight tensors.""" return QKVSplitRearrangeConversion( From 579967c7e4dc3790ae07d7a0ae6f189dc7e68cd1 Mon Sep 17 00:00:00 2001 From: Sunny Joshi Date: Fri, 15 May 2026 18:24:36 +0100 Subject: [PATCH 3/6] adding missing bos and attn checks --- .../supported_architectures/test_gpt2_adapter.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py index c71e37587..6eccd695a 100644 --- a/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py @@ -89,6 +89,14 @@ def test_uses_combined_qkv_is_true(self, adapter: GPT2ArchitectureAdapter) -> No """GPT-2 stores Q, K, V in a single combined c_attn matrix.""" assert adapter.uses_combined_qkv is True + def test_default_prepend_bos_is_true(self, adapter: GPT2ArchitectureAdapter) -> None: + """GPT-2 prepends a BOS token by default (adapter inherits this).""" + assert adapter.cfg.default_prepend_bos is True + + def 
test_default_cfg_uses_split_attention(self, adapter: GPT2ArchitectureAdapter) -> None: + """default_cfg flags that GPT-2's combined QKV must be split.""" + assert adapter.default_cfg["uses_split_attention"] is True + # --------------------------------------------------------------------------- # Component mapping structure tests # --------------------------------------------------------------------------- @@ -151,6 +159,11 @@ def test_attn_name(self, adapter: GPT2ArchitectureAdapter) -> None: blocks = adapter.component_mapping["blocks"] assert blocks.submodules["attn"].name == "attn" + def test_attn_does_not_require_attention_mask(self, adapter: GPT2ArchitectureAdapter) -> None: + """GPT-2 attention applies a causal mask internally, so no external mask is needed.""" + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.requires_attention_mask is False + def test_attn_qkv_is_linear_bridge(self, adapter: GPT2ArchitectureAdapter) -> None: """The combined QKV projection is a single LinearBridge wrapping c_attn.""" attn = adapter.component_mapping["blocks"].submodules["attn"] From 8eeaac839bc058315766ff3c7c16e1dcf92c584e Mon Sep 17 00:00:00 2001 From: Sunny Joshi Date: Fri, 15 May 2026 18:29:51 +0100 Subject: [PATCH 4/6] updating FactoryRegistration test --- .../supported_architectures/test_gpt2_adapter.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py index 6eccd695a..97df7795f 100644 --- a/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py @@ -286,9 +286,4 @@ def test_factory_key_is_registered(self) -> None: from transformer_lens.factories.architecture_adapter_factory import ( SUPPORTED_ARCHITECTURES, ) - - assert "GPT2LMHeadModel" in SUPPORTED_ARCHITECTURES, ( - "GPT2LMHeadModel must be 
registered in SUPPORTED_ARCHITECTURES" - ) - - + assert SUPPORTED_ARCHITECTURES["GPT2LMHeadModel"] is GPT2ArchitectureAdapter From 786e59b3b2b7d5e2c0170047b8bf0718738b16a9 Mon Sep 17 00:00:00 2001 From: Sunny Joshi Date: Fri, 15 May 2026 18:31:21 +0100 Subject: [PATCH 5/6] formatting via black --- .../test_gpt2_adapter.py | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py index 97df7795f..8579f7f34 100644 --- a/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_gpt2_adapter.py @@ -31,6 +31,7 @@ # Helpers / fixtures # --------------------------------------------------------------------------- + def _make_cfg( n_heads: int = 4, d_model: int = 64, @@ -52,18 +53,22 @@ def _make_cfg( architecture="GPT2LMHeadModel", ) + @pytest.fixture def cfg() -> TransformerBridgeConfig: return _make_cfg() + @pytest.fixture def adapter(cfg: TransformerBridgeConfig) -> GPT2ArchitectureAdapter: return GPT2ArchitectureAdapter(cfg) + # --------------------------------------------------------------------------- # Config attribute tests # --------------------------------------------------------------------------- + class TestGPT2AdapterConfig: """Tests that the adapter sets required config attributes correctly.""" @@ -97,10 +102,12 @@ def test_default_cfg_uses_split_attention(self, adapter: GPT2ArchitectureAdapter """default_cfg flags that GPT-2's combined QKV must be split.""" assert adapter.default_cfg["uses_split_attention"] is True + # --------------------------------------------------------------------------- # Component mapping structure tests # --------------------------------------------------------------------------- + class TestGPT2AdapterComponentMapping: """Tests that component_mapping has the correct bridge types and HF module 
names.""" @@ -139,14 +146,18 @@ def test_unembed_name(self, adapter: GPT2ArchitectureAdapter) -> None: # -- Block submodules -- def test_blocks_ln1_is_normalization_bridge(self, adapter: GPT2ArchitectureAdapter) -> None: - assert isinstance(adapter.component_mapping["blocks"].submodules["ln1"], NormalizationBridge) + assert isinstance( + adapter.component_mapping["blocks"].submodules["ln1"], NormalizationBridge + ) def test_blocks_ln1_name(self, adapter: GPT2ArchitectureAdapter) -> None: assert adapter.component_mapping["blocks"].submodules["ln1"].name == "ln_1" def test_blocks_ln2_is_normalization_bridge(self, adapter: GPT2ArchitectureAdapter) -> None: """GPT-2 has a second layer norm before the MLP (no parallel attn/MLP).""" - assert isinstance(adapter.component_mapping["blocks"].submodules["ln2"], NormalizationBridge) + assert isinstance( + adapter.component_mapping["blocks"].submodules["ln2"], NormalizationBridge + ) def test_blocks_ln2_name(self, adapter: GPT2ArchitectureAdapter) -> None: assert adapter.component_mapping["blocks"].submodules["ln2"].name == "ln_2" @@ -193,10 +204,12 @@ def test_mlp_out_name(self, adapter: GPT2ArchitectureAdapter) -> None: mlp = adapter.component_mapping["blocks"].submodules["mlp"] assert mlp.submodules["out"].name == "c_proj" + # --------------------------------------------------------------------------- # Weight processing conversion tests # --------------------------------------------------------------------------- + class TestGPT2AdapterWeightConversions: """Tests that weight_processing_conversions has exactly the expected keys.""" @@ -219,6 +232,7 @@ def test_conversion_key_present(self, adapter: GPT2ArchitectureAdapter, key: str def test_exactly_eight_conversion_keys(self, adapter: GPT2ArchitectureAdapter) -> None: assert len(adapter.weight_processing_conversions) == 8 + # --------------------------------------------------------------------------- # QKVSplitRearrangeConversion — numerical correctness tests # 
--------------------------------------------------------------------------- @@ -229,16 +243,14 @@ class TestQKVSplitRearrangeConversion: N_HEADS, D_HEAD, D_MODEL = 4, 16, 64 # D_MODEL = N_HEADS * D_HEAD - def _make_conv( - self, qkv_index: int, n_heads: int = 4 - ) -> QKVSplitRearrangeConversion: + def _make_conv(self, qkv_index: int, n_heads: int = 4) -> QKVSplitRearrangeConversion: """Helper: build a QKVSplitRearrangeConversion for weight tensors.""" return QKVSplitRearrangeConversion( qkv_index=qkv_index, rearrange_pattern="d_model (n h) -> n d_model h", n=n_heads, ) - + @pytest.mark.parametrize( "shape, expected", [((64, 192), True), ((192, 64), True), ((64, 64), False), ((64, 128), False)], @@ -264,10 +276,12 @@ def test_already_split_weight_roundtrips(self) -> None: assert recovered.shape == original.shape assert torch.allclose(original, recovered) + # --------------------------------------------------------------------------- # Factory registration tests # --------------------------------------------------------------------------- + class TestGPT2FactoryRegistration: """Tests that the factory maps GPT2LMHeadModel to the correct adapter.""" @@ -278,12 +292,13 @@ def test_factory_returns_gpt2_adapter(self) -> None: cfg = _make_cfg() adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) - assert isinstance(adapter, GPT2ArchitectureAdapter), ( - f"Expected GPT2ArchitectureAdapter, got {type(adapter).__name__}" - ) + assert isinstance( + adapter, GPT2ArchitectureAdapter + ), f"Expected GPT2ArchitectureAdapter, got {type(adapter).__name__}" def test_factory_key_is_registered(self) -> None: from transformer_lens.factories.architecture_adapter_factory import ( SUPPORTED_ARCHITECTURES, ) + assert SUPPORTED_ARCHITECTURES["GPT2LMHeadModel"] is GPT2ArchitectureAdapter From 3eee04bfac4904ddc34fc4994e23ce247c3b24fb Mon Sep 17 00:00:00 2001 From: Sunny Joshi Date: Sat, 16 May 2026 14:03:36 +0100 Subject: [PATCH 6/6] adding custom head GPT2 tests --- 
.../test_gpt2_lm_head_custom_adapter.py | 199 ++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 tests/unit/model_bridge/supported_architectures/test_gpt2_lm_head_custom_adapter.py diff --git a/tests/unit/model_bridge/supported_architectures/test_gpt2_lm_head_custom_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gpt2_lm_head_custom_adapter.py new file mode 100644 index 000000000..f60d84b90 --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_gpt2_lm_head_custom_adapter.py @@ -0,0 +1,199 @@ +"""Unit tests for Gpt2LmHeadCustomArchitectureAdapter. + +Tests cover: +- Component mapping structure (correct bridge types and HF module names) +- Weight conversion keys and count +- Factory registration (GPT2LMHeadCustomModel maps to the right adapter) + +Note: unlike GPT2ArchitectureAdapter, this adapter sets no cfg.* attributes +and uses a plain AttentionBridge (no combined-QKV split), so it has no +config-attribute tests. +""" + +import pytest + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ArchitectureAdapterFactory, +) +from transformer_lens.model_bridge.generalized_components import ( + AttentionBridge, + BlockBridge, + EmbeddingBridge, + MLPBridge, + NormalizationBridge, + PosEmbedBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.gpt2_lm_head_custom import ( + Gpt2LmHeadCustomArchitectureAdapter, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +def _make_cfg( + n_heads: int = 4, + d_model: int = 64, + n_layers: int = 2, + d_mlp: int = 256, + d_vocab: int = 1000, + n_ctx: int = 512, +) -> TransformerBridgeConfig: + """Return a minimal TransformerBridgeConfig for custom GPT-2 adapter tests.""" + return 
TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + n_layers=n_layers, + n_ctx=n_ctx, + n_heads=n_heads, + d_vocab=d_vocab, + d_mlp=d_mlp, + default_prepend_bos=True, + architecture="GPT2LMHeadCustomModel", + ) + + +@pytest.fixture +def adapter() -> Gpt2LmHeadCustomArchitectureAdapter: + return Gpt2LmHeadCustomArchitectureAdapter(_make_cfg()) + + +# --------------------------------------------------------------------------- +# Component mapping structure tests +# --------------------------------------------------------------------------- + + +class TestCustomAdapterComponentMapping: + """Component mapping must have the correct bridge types and HF module names.""" + + def test_embed_is_embedding_bridge(self, adapter: Gpt2LmHeadCustomArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_embed_name(self, adapter: Gpt2LmHeadCustomArchitectureAdapter) -> None: + assert adapter.component_mapping["embed"].name == "transformer.wte" + + def test_pos_embed_is_pos_embed_bridge( + self, adapter: Gpt2LmHeadCustomArchitectureAdapter + ) -> None: + assert isinstance(adapter.component_mapping["pos_embed"], PosEmbedBridge) + + def test_pos_embed_name(self, adapter: Gpt2LmHeadCustomArchitectureAdapter) -> None: + assert adapter.component_mapping["pos_embed"].name == "transformer.wpe" + + def test_blocks_is_block_bridge(self, adapter: Gpt2LmHeadCustomArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["blocks"], BlockBridge) + + def test_blocks_name(self, adapter: Gpt2LmHeadCustomArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].name == "transformer.h" + + def test_ln_final_is_normalization_bridge( + self, adapter: Gpt2LmHeadCustomArchitectureAdapter + ) -> None: + assert isinstance(adapter.component_mapping["ln_final"], NormalizationBridge) + + def test_ln_final_name(self, adapter: Gpt2LmHeadCustomArchitectureAdapter) -> None: + assert 
adapter.component_mapping["ln_final"].name == "transformer.ln_f" + + def test_unembed_is_unembedding_bridge( + self, adapter: Gpt2LmHeadCustomArchitectureAdapter + ) -> None: + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + def test_unembed_name(self, adapter: Gpt2LmHeadCustomArchitectureAdapter) -> None: + assert adapter.component_mapping["unembed"].name == "lm_head" + + # -- Block submodules -- + + def test_ln1_is_normalization_bridge( + self, adapter: Gpt2LmHeadCustomArchitectureAdapter + ) -> None: + assert isinstance( + adapter.component_mapping["blocks"].submodules["ln1"], NormalizationBridge + ) + + def test_ln1_name(self, adapter: Gpt2LmHeadCustomArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].submodules["ln1"].name == "ln_1" + + def test_attn_is_attention_bridge(self, adapter: Gpt2LmHeadCustomArchitectureAdapter) -> None: + """The custom adapter uses a plain AttentionBridge, not a JointQKVAttentionBridge.""" + assert isinstance(adapter.component_mapping["blocks"].submodules["attn"], AttentionBridge) + + def test_attn_name(self, adapter: Gpt2LmHeadCustomArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].submodules["attn"].name == "attn" + + def test_ln2_is_normalization_bridge( + self, adapter: Gpt2LmHeadCustomArchitectureAdapter + ) -> None: + assert isinstance( + adapter.component_mapping["blocks"].submodules["ln2"], NormalizationBridge + ) + + def test_ln2_name(self, adapter: Gpt2LmHeadCustomArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].submodules["ln2"].name == "ln_2" + + def test_mlp_is_mlp_bridge(self, adapter: Gpt2LmHeadCustomArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["blocks"].submodules["mlp"], MLPBridge) + + def test_mlp_name(self, adapter: Gpt2LmHeadCustomArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].submodules["mlp"].name == "mlp" + + +# 
--------------------------------------------------------------------------- +# Weight processing conversion tests +# --------------------------------------------------------------------------- + + +class TestCustomAdapterWeightConversions: + """Adapter must define exactly the expected QKVO weight conversion keys.""" + + @pytest.mark.parametrize( + "key", + [ + "blocks.{i}.attn.q", + "blocks.{i}.attn.k", + "blocks.{i}.attn.v", + "blocks.{i}.attn.b_Q", + "blocks.{i}.attn.b_K", + "blocks.{i}.attn.b_V", + "blocks.{i}.attn.o", + ], + ) + def test_conversion_key_present( + self, adapter: Gpt2LmHeadCustomArchitectureAdapter, key: str + ) -> None: + assert key in adapter.weight_processing_conversions + + def test_exactly_seven_conversion_keys( + self, adapter: Gpt2LmHeadCustomArchitectureAdapter + ) -> None: + assert len(adapter.weight_processing_conversions) == 7 + + def test_qkv_conversions_source_from_combined_c_attn( + self, adapter: Gpt2LmHeadCustomArchitectureAdapter + ) -> None: + """Q/K/V weights are all fetched from the single combined c_attn.weight.""" + for key in ("blocks.{i}.attn.q", "blocks.{i}.attn.k", "blocks.{i}.attn.v"): + conversion = adapter.weight_processing_conversions[key] + assert conversion.source_key == "transformer.h.{i}.attn.c_attn.weight" + + +# --------------------------------------------------------------------------- +# Factory registration tests +# --------------------------------------------------------------------------- + + +class TestCustomFactoryRegistration: + """Factory must resolve GPT2LMHeadCustomModel -> Gpt2LmHeadCustomArchitectureAdapter.""" + + def test_factory_returns_custom_adapter(self) -> None: + adapter = ArchitectureAdapterFactory.select_architecture_adapter(_make_cfg()) + assert isinstance(adapter, Gpt2LmHeadCustomArchitectureAdapter) + + def test_custom_model_in_supported_architectures(self) -> None: + assert ( + SUPPORTED_ARCHITECTURES["GPT2LMHeadCustomModel"] is Gpt2LmHeadCustomArchitectureAdapter + )