Skip to content

Feat/external architecture registration#1307

Open
huseyincavusbi wants to merge 7 commits into
TransformerLensOrg:devfrom
huseyincavusbi:feat/external-architecture-registration
Open

Feat/external architecture registration#1307
huseyincavusbi wants to merge 7 commits into
TransformerLensOrg:devfrom
huseyincavusbi:feat/external-architecture-registration

Conversation

@huseyincavusbi
Copy link
Copy Markdown
Contributor

Hi @jlarson4,

Description

Three changes to allow users to register custom architecture adapters without forking TransformerLens:

  1. register_adapter() method — Classmethod on ArchitectureAdapterFactory that adds custom architectures to the registry at runtime
  2. Entry-point discoverydiscover_entry_points() scans installed packages for adapters declared via transformer_lens.architectures entry group, auto-called on first select_architecture_adapter()
  3. Documentation — New page under Contributing with examples for runtime registration, entry-point declaration, and example package layout

Closes #1298

Type of change

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist

  • I have commented my code, particularly in hard-to-understand areas
  • My changes generate no new warnings
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

Copilot AI review requested due to automatic review settings May 16, 2026 11:05
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a public mechanism for registering custom architecture adapters in TransformerLens without forking the project, plus accompanying tests and documentation.

Changes:

  • New ArchitectureAdapterFactory.register_adapter() classmethod for runtime registration, and discover_entry_points() that reads adapters from the transformer_lens.architectures entry-point group on first use.
  • New unit test module covering registration, discovery idempotency, and select error paths.
  • New documentation page under Contributing describing both registration paths and an example external package layout.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

File Description
transformer_lens/factories/architecture_adapter_factory.py Adds register_adapter and discover_entry_points classmethods; calls discovery at the start of select_architecture_adapter.
tests/unit/model_bridge/test_architecture_adapter_factory.py New tests for SUPPORTED_ARCHITECTURES contents, runtime registration, error handling, and discovery idempotency.
docs/source/content/contributing.md Adds the new doc page to the toctree.
docs/source/content/adapter_development/external-adapter-registration.md New doc page explaining runtime and entry-point registration, with example package layout.
Comments suppressed due to low confidence (1)

tests/unit/model_bridge/test_architecture_adapter_factory.py:58

  • This test does not actually verify "overwrite" behavior. It registers MockArchitectureAdapter twice for the same key, so the post-condition cls._adapters[key] is first is trivially true regardless of whether the second register_adapter call overwrote anything. To meaningfully test overwrite behavior, register a second, different adapter class and assert that the registry now points at the new class (and not at first).
    def test_register_overwrites_existing(self):
        key = "TestOverwriteForCausalLM"
        ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
        first = ArchitectureAdapterFactory._adapters[key]
        ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
        assert ArchitectureAdapterFactory._adapters[key] is first

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +176 to +181
try:
eps = entry_points(group="transformer_lens.architectures")
for ep in eps:
cls._adapters[ep.name] = ep.load()
except Exception:
pass
Comment on lines +45 to +97
class TestRegisterAdapter:
"""Verify runtime adapter registration."""

def test_register_adds_to_adapters(self):
key = "TestMockForCausalLM"
ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
assert key in ArchitectureAdapterFactory._adapters

def test_register_overwrites_existing(self):
key = "TestOverwriteForCausalLM"
ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
first = ArchitectureAdapterFactory._adapters[key]
ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
assert ArchitectureAdapterFactory._adapters[key] is first

def test_select_returns_registered_adapter(self):
key = "TestSelectForCausalLM"
ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
cfg = _make_cfg(architecture=key)
adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg)
assert isinstance(adapter, MockArchitectureAdapter)


class TestSelectErrors:
"""Verify error handling in select_architecture_adapter."""

def test_unknown_architecture_raises(self):
cfg = _make_cfg(architecture="NonExistentForCausalLM")
with pytest.raises(ValueError, match="Unsupported architecture"):
ArchitectureAdapterFactory.select_architecture_adapter(cfg)

def test_none_architecture_raises(self):
cfg = _make_cfg(architecture=None)
with pytest.raises(ValueError, match="must have architecture set"):
ArchitectureAdapterFactory.select_architecture_adapter(cfg)


class TestDiscoverEntryPoints:
"""Verify entry-point discovery behavior."""

def test_discover_is_idempotent(self):
ArchitectureAdapterFactory._entry_points_discovered = False
ArchitectureAdapterFactory.discover_entry_points()
first_run = ArchitectureAdapterFactory._entry_points_discovered
ArchitectureAdapterFactory.discover_entry_points()
assert ArchitectureAdapterFactory._entry_points_discovered is first_run is True

def test_discover_does_not_remove_existing(self):
key = "TestPreserveForCausalLM"
ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
ArchitectureAdapterFactory._entry_points_discovered = False
ArchitectureAdapterFactory.discover_entry_points()
assert key in ArchitectureAdapterFactory._adapters
discovery of adapters from installed packages via entry points.
"""

_adapters = SUPPORTED_ARCHITECTURES
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants