
Create a Protocol for the MLP layer of TransformerLayer#3435

Open
nschank wants to merge 2 commits into NVIDIA:main from nschank:mlplayer

Conversation

@nschank
Contributor

@nschank nschank commented Feb 15, 2026

What does this PR do ?

Defines a Protocol representing the mlp submodule of TransformerLayer, and uses that instead of ModuleSpec to enable typechecking of its configuration.

  • To propagate type-checking in layerspec construction, this required replacing several layers of ModuleSpec with MlpBuilder. Nobody except a single unit test appears to introspect into the ModuleSpec's contents (other than a single use of metainfo, which was easy to replace locally with some manual logic), so this should be fairly simple and safe.
  • TransformerLayer was doing some spicy internal kwarg-management based on the specific type being passed; I moved this logic into factory methods on the relevant classes themselves, and updated callers to prefer to pass that method directly instead, but the ModuleSpec-based special casing was left there for backward compatibility. Note that some of the types that were being special-cased are simply not supported by TransformerLayer at this point (they need to be wrapped in MoeLayer to support the correct forward interface), so I just removed them entirely.

Associated design doc: Typed ModuleSpec.pdf
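To make the idea concrete, here is a minimal sketch of what such a Protocol could look like. This is illustrative only: `MlpBuilder` and `MLP` are names mentioned in the PR, but the signatures and the `TransformerConfig` stand-in here are assumptions, not the actual diff.

```python
from dataclasses import dataclass
from typing import Optional, Protocol


@dataclass
class TransformerConfig:
    """Stand-in for megatron.core's TransformerConfig (illustrative fields only)."""
    hidden_size: int = 1024
    ffn_hidden_size: int = 4096


class MLP:
    """Stand-in for Megatron's MLP module."""
    def __init__(self, config: TransformerConfig, ffn_hidden_size: Optional[int] = None):
        self.ffn_hidden_size = ffn_hidden_size or config.ffn_hidden_size


class MlpBuilder(Protocol):
    """Any callable with this signature can serve as the mlp submodule builder,
    so a static type checker can validate what the spec provides."""
    def __call__(
        self,
        config: TransformerConfig,
        is_mtp_layer: bool = False,
        ffn_hidden_size: Optional[int] = None,
    ) -> MLP: ...


def build_mlp(
    config: TransformerConfig,
    is_mtp_layer: bool = False,
    ffn_hidden_size: Optional[int] = None,
) -> MLP:
    """A conforming builder; a plain MLP has no use for is_mtp_layer."""
    del is_mtp_layer
    return MLP(config, ffn_hidden_size=ffn_hidden_size)


mlp = build_mlp(TransformerConfig(), ffn_hidden_size=2048)
```

The benefit over a bare ModuleSpec is that assigning a non-conforming callable to the mlp slot becomes a type error at analysis time rather than a runtime failure.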

Contribution process

```mermaid
flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]
```

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run `autoformatter.sh` on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@nschank nschank requested review from a team as code owners February 15, 2026 18:58
@copy-pr-bot

copy-pr-bot Bot commented Feb 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ko3n1g ko3n1g requested a review from a team February 15, 2026 18:58
@Phlip79 Phlip79 added Expert Review [deprecated] Apply this label to indicate that your PR is ready for expert review. complexity: medium labels Feb 17, 2026
@Phlip79
Member

Phlip79 commented Feb 17, 2026

/ok to test e0cf1f9

@nschank
Contributor Author

nschank commented Feb 18, 2026

Note: This one may be a bit blocked on the get_submodules method in #3426, because I think there are a good number of tests introspecting into mlp (see #3425) - will update once available.

@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Feb 20, 2026
@nschank
Contributor Author

nschank commented Mar 7, 2026

I'm waiting on #3426 to resync, since I want to reuse the get_submodules thing

@chtruong814 chtruong814 removed the needs-follow-up Issue needs follow-up label Mar 17, 2026

Comment thread: megatron/core/transformer/mlp.py

```python
    ffn_hidden_size: int | None = None,
) -> MLP:
    """Helper function to build an MLP as a TransformerLayer's mlp submodule."""
    del is_mtp_layer
```
Contributor


Sorry, maybe dumb question... can you explain what is going on here?

Is this function taking in the arguments for a TransformerLayer and "converting" to an MLP or something?

Contributor Author


Sure thing! Not a dumb question at all, this took some thought lol. This is trying to 'decentralize' the logic that TransformerLayer is currently doing here:

https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/transformer_layer.py#L373-L395

Basically, TransformerLayer is currently trying to introspect into the submodule being constructed, and to change the arguments it passes to build_module in response. But this has several drawbacks:

  1. It's confusing: you need to look in 3 different places to understand how these modules are being constructed (the config, the interface of the submodule, and TransformerLayer's special-cased conversion between them, hidden deep in its initializer).
  2. It's circular: TransformerLayer needs to know about its own dependencies in order to construct them, hence the lazy imports.
  3. It's inflexible: Only these classes get this special treatment, so for instance if a user subclasses one of them they suddenly get a different behavior.
  4. Finally (most relevantly here), it's type-checker incompatible: the parameters are different depending on what the caller provides, so at the very least I'd need to provide an overly flexible interface in order to specify the protocol.

So this PR instead has TransformerLayer consistently construct its MLP submodule using exactly the same interface (in particular, the "maximal" interface that satisfies all current callers). It is then the responsibility of whoever constructs TransformerLayerSubmodules to satisfy that interface, regardless of the class they want to construct. If they want to provide a class that does not take all of the arguments, that's not something TransformerLayer should be expected to fix for them - it can instead be handled by providing a callable which simply discards those arguments and forwards the rest to the class they want to construct!

So these classmethods I added are basically that extra 'translation layer' - these classes sorta "know" they want to be provided to TransformerLayerSubmodules, so they can simply provide an extra method which satisfies the interface TransformerLayerSubmodules requires, and users can then provide MLP.as_mlp_submodule instead of just MLP. External users can easily imitate this pattern as well on their own custom classes (if they don't want to have their initializer accept all the arguments), or anyone can write such a 'translation function' for any alternative class they wish to construct too. So basically you get all the flexibility of the original TransformerLayer conversion thing, with none of the drawbacks.


I hope that helps clarify things - any thoughts on how to make this more self-documenting? Perhaps as_transformer_layer_submodule would convey this better?
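A rough sketch of the classmethod 'translation layer' pattern described above, with illustrative names (`CustomMLP` and this particular `as_mlp_submodule` signature are assumptions for demonstration, not the PR's actual code):

```python
from typing import Optional


class TransformerConfig:
    """Stand-in config object (the real one lives in megatron.core)."""
    ffn_hidden_size = 4096


class CustomMLP:
    """Illustrative class whose __init__ does not accept every argument
    TransformerLayer wants to supply."""

    def __init__(self, config: TransformerConfig, ffn_hidden_size: Optional[int] = None):
        self.ffn_hidden_size = ffn_hidden_size or config.ffn_hidden_size

    @classmethod
    def as_mlp_submodule(
        cls,
        config: TransformerConfig,
        is_mtp_layer: bool = False,
        ffn_hidden_size: Optional[int] = None,
    ) -> "CustomMLP":
        """Translation layer: satisfies the interface TransformerLayerSubmodules
        requires by discarding the arguments this class does not take."""
        del is_mtp_layer  # this class has no use for the flag
        return cls(config, ffn_hidden_size=ffn_hidden_size)


# Callers provide CustomMLP.as_mlp_submodule rather than CustomMLP itself.
submodule = CustomMLP.as_mlp_submodule(TransformerConfig(), is_mtp_layer=True)
```

The class still keeps its own clean `__init__`; the classmethod is the only place that knows about the parent's wider calling convention.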

Contributor

@jaredcasper jaredcasper Mar 24, 2026


Blegh, how did we ever let in those lines 373-395... <shaking head>.

The original idea was that all MLPs have the same interface (and indeed, anything that can be swapped out in a spec should have the same interface as the thing it's replacing). It seems some MLPs have snuck in that have different interfaces, and we get that ugly if/else block that shouldn't have passed review.

I don't get what the advantage is to having an "as_mlp_submodule" that takes the extra args, then throws them away and creates the class, vs just having the init function take the extra args directly and throw them away. Why the extra step? Why not just add is_mtp_layer to this class's __init__? Why should we let them provide a class that does not take all the extra arguments? If something is so different that it needs an entirely different argument list, then it shouldn't be swapped in as an "MLP".

Contributor Author


I don't get what the advantage is to having an "as_mlp_submodule" that takes the extra args then throws them away

It's not an advantage, it's a safe refactoring - I don't necessarily condone having the alternative interface for constructing these classes, but I'm trying to provide the minimal clean transition for each class currently being provided. Adding extra unused arguments in order to match a specific interface feels worse to me than providing a clean shim layer which callers can use easily.

More broadly, I think that "everyone needs to have the same interface" is technically a bit more restrictive than necessary - we need the thing passed in to have a particular interface, but I don't think it makes sense to force everyone to solely pass in unadulterated __init__ methods to call. Being able to pass subclasses which desire extra parameters (and then providing them using functools.partial) is a valuable ability, and I basically view this as the dual version of that (i.e. having a class which needs fewer/transformed parameters).

Why should we let them provide a class that does not take all the extra arguments? If something is so different that it needs an entirely different argument list, then it shouldn't be swapped in as an "MLP".

Why should a class let someone provide an extra argument that it doesn't want to use? Not every interface needs to be perfectly met in order for something to be useful - IdentityFn isn't exactly the most choosy about how it's used, but it makes a lot of sense.

The parent module has a fixed amount of information it wants to provide, but it's not actually that helpful IMO to say that the 'constructee' should "care" about every single piece of information. If we have 5 classes that are legal MLP layers, and only one cares about is_mtp_layer, why does it make sense to require them all to specifically accept an unused parameter in their __init__? I don't really see that as cleaner than just having a documented shim layer which can let the class do its thing, while documenting the way that some other class uses this class.

This same pattern (having a callable which adapts to the appropriate interface) is something that is useful to demonstrate, so I think there's some value in putting it into the codebase somewhere. It provides flexibility (like not forcing you to update every tp_group recipient to pg_collection at once), and lets you do some useful things (like using a custom Module someone has somewhere, which was already being used by someone, which has a different interface but could be used in Megatron safely).


I would be fine with updating the other classes' __init__s if you wanted, but I don't think I should be expected to convert any classes to switch between tp_group and pg_collection, so at the very least there will definitely be some conversion layer here - at least temporarily. Do you have a different form you'd prefer?
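The functools.partial duality mentioned above can be sketched as follows (all names here are hypothetical stand-ins, not Megatron-LM API):

```python
from functools import partial
from typing import Optional


class TransformerConfig:
    """Stand-in config (illustrative)."""
    ffn_hidden_size = 4096


class ExpertMLP:
    """Hypothetical subclass wanting an extra parameter beyond the interface."""
    def __init__(self, config, ffn_hidden_size: Optional[int] = None, num_experts: int = 1):
        self.num_experts = num_experts


# One direction: bind an *extra* argument up front with partial, leaving a
# callable that still matches the builder interface the parent expects.
expert_builder = partial(ExpertMLP, num_experts=8)


# The dual direction: a shim that drops an argument the class does not want.
def lean_builder(config, is_mtp_layer: bool = False, ffn_hidden_size: Optional[int] = None):
    del is_mtp_layer  # discarded; ExpertMLP's __init__ does not accept it
    return ExpertMLP(config, ffn_hidden_size=ffn_hidden_size)


mlp_a = expert_builder(TransformerConfig())
mlp_b = lean_builder(TransformerConfig(), is_mtp_layer=True)
```

In both directions the parent module sees one fixed calling convention, and the adaptation lives entirely on the caller's side.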

Contributor Author


Note: https://github.com/NVIDIA/Megatron-LM/pull/3435/changes#diff-6745b82c932c5947fd3383c31f326639093c33d34cb5def59dc5a843d4e2ebbcR165 This is perhaps a good example of the additional flexibility that an intermediate callable can provide. Previously, in order to customize a submodule's parameters, it was necessary to actually subclass TransformerLayer and change how it was providing the parameter to the submodule; but now, you can simply intercept the parameters that TransformerLayer is providing and do whatever you want.
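A small sketch of that interception idea, under assumed names (`narrow_mlp` and the stand-in classes are illustrative, not the linked diff): the wrapper rewrites a parameter the parent supplies, with no TransformerLayer subclass required.

```python
from typing import Optional


class TransformerConfig:
    """Stand-in config (illustrative)."""
    ffn_hidden_size = 4096


class MLP:
    """Stand-in MLP module."""
    def __init__(self, config, ffn_hidden_size: Optional[int] = None):
        self.ffn_hidden_size = ffn_hidden_size or config.ffn_hidden_size


def narrow_mlp(config, is_mtp_layer: bool = False, ffn_hidden_size: Optional[int] = None):
    """Intercept the width the parent layer supplies and halve it before
    delegating to the real class."""
    width = ffn_hidden_size or config.ffn_hidden_size
    return MLP(config, ffn_hidden_size=width // 2)


mlp = narrow_mlp(TransformerConfig())
```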

@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Mar 19, 2026
@nschank
Contributor Author

nschank commented Mar 27, 2026

At long last, rebased to capture #3426 and updated tests to use it! Sorry it blows up the file count a bit, but all the test changes are pretty mechanical.

@nschank nschank requested a review from jaredcasper April 1, 2026 13:59
@Phlip79 Phlip79 removed the Expert Review [deprecated] Apply this label to indicate that your PR is ready for expert review. label Apr 3, 2026
Contributor

@kevalmorabia97 kevalmorabia97 left a comment


LGTM

@Phlip79 Phlip79 requested a review from a team April 3, 2026 16:46
@Phlip79
Member

Phlip79 commented Apr 3, 2026

/ok to test e7b8792

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Final Review PR is in the "final review" stage label Apr 3, 2026
@chtruong814 chtruong814 removed the needs-follow-up Issue needs follow-up label Apr 3, 2026
Contributor

@jaredcasper jaredcasper left a comment


Sorry for the delay... I've been debating how I feel about this "as_mlp_submodule" business of letting users swap in modules with different interfaces by providing a wrapper. It kind of goes to the heart of what the spec should be used for and what it shouldn't. The wrapper method is not intuitive to me because it could make it somewhat difficult to track where arguments are coming in from and how they get set. So I'm still not sure I'm a fan of the wrapper method, but we can put it in for now.

@chtruong814 chtruong814 added needs-follow-up Issue needs follow-up and removed needs-follow-up Issue needs follow-up labels Apr 17, 2026
@Phlip79 Phlip79 requested a review from a team April 20, 2026 18:26
@svcnvidia-nemo-ci svcnvidia-nemo-ci added waiting-on-maintainers Waiting on maintainers to respond and removed needs-follow-up Issue needs follow-up labels Apr 21, 2026
@ericharper
Contributor

@nschank, could you resolve conflicts?


Labels

community-request · complexity: medium · Final Review (PR is in the "final review" stage) · waiting-on-maintainers (Waiting on maintainers to respond)
