Skip to content

[PyTorch] torch.compile support for permutation functions#2686

Open
pggPL wants to merge 11 commits intoNVIDIA:mainfrom
pggPL:moe_torch_compile
Open

[PyTorch] torch.compile support for permutation functions#2686
pggPL wants to merge 11 commits intoNVIDIA:mainfrom
pggPL:moe_torch_compile

Conversation

@pggPL
Copy link
Collaborator

@pggPL pggPL commented Feb 17, 2026

Description

This PR adds torch.compile(fullgraph=True) support for MoE permutation operations (moe_permute, moe_unpermute, moe_sort_chunks_by_index) by converting all torch.autograd.Function implementations to PyTorch custom operators using torch.library.custom_op.

Note that this PR does not add torch.compile support for QuantizedTensor as an input.

Related to #2590

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the moe_torch_compile branch from 41e22ef to 8159d26 Compare February 18, 2026 17:31
pre-commit-ci bot and others added 4 commits February 18, 2026 17:32
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL marked this pull request as ready for review February 19, 2026 15:45
@pggPL
Copy link
Collaborator Author

pggPL commented Feb 19, 2026

/te-ci pytorch

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 19, 2026

Greptile Summary

This PR refactors all MoE permutation operations (moe_permute, moe_unpermute, moe_sort_chunks_by_index) from torch.autograd.Function subclasses to torch.library.custom_op registrations, enabling torch.compile(fullgraph=True) support. The structural approach — separating forward ops, backward ops, fake (shape-inference) implementations, and context setup into distinct functions — is correct for the PyTorch custom-op contract.

Key issues found:

  • Broken fake function for mask-map permute under default usage (permutation.py line 995): _moe_permute_mask_map_forward_fake uses num_out_tokens directly as a tensor dimension without guarding against the sentinel value -1 (the moe_permute default) or None. During torch.compile tracing this results in torch.empty((-1, hidden_size)) which raises an invalid-shape error. The index-map fake already resolves this with output_tokens = num_out_tokens if num_out_tokens > 0 else num_tokens * topK; the same guard is needed in the mask-map fake. The probs-output line on line 1000 has the identical problem.

  • Incomplete test skip conditions for torch.compile (test_permutation.py line 254): test_permutation_mask_map and test_permutation_mask_map_alongside_probs skip only on num_expert/topK combinations, but the None output-token-count parametrized case still runs under compile and will hit the fake-function bug described above.

  • The removal of the assert num_out_tokens is not None guard from the old _moe_permute_mask_map.forward without an equivalent check in the new custom-op path means None can now silently propagate further than before.

Confidence Score: 3/5

  • Not safe to merge as-is: torch.compile tracing will fail for mask-map permute when called with the default sentinel or None output token count.
  • The overall refactor is well-structured and the index-map and chunk-sort paths appear correct. However, the mask-map fake function has a clear bug that aborts torch.compile tracing for any caller using the default token-count sentinel, which is the most common usage. The companion test skip gap means this broken path is not currently caught by the new tests. Both issues need to be resolved before merging.
  • transformer_engine/pytorch/permutation.py (fake function for mask-map permute) and tests/pytorch/test_permutation.py (skip conditions for torch.compile with None token count).

Important Files Changed

Filename Overview
transformer_engine/pytorch/quantized_tensor.py Minimal, clean change: adds a module-level _quantized_tensor_passthrough_ops set and a corresponding short-circuit in __torch_dispatch__ so registered custom ops receive FP8 tensors unwrapped.
transformer_engine/pytorch/permutation.py Large refactor replacing all torch.autograd.Function classes with torch.library.custom_op registrations. The fake (abstract) function for moe_permute_mask_map_forward is missing a guard for the default sentinel value (-1 or None) for the output token count, which will crash torch.compile tracing for the mask-map path when invoked with default arguments.
tests/pytorch/test_permutation.py Good breadth of torch.compile coverage added with per-test skip conditions to limit combinatorial explosion. The skip guard for the mask-map tests is missing a check for the None output-token-count case, allowing a broken compile path to be exercised in test_permutation_mask_map and test_permutation_mask_map_alongside_probs.

Sequence Diagram

sequenceDiagram
    participant User
    participant moe_permute
    participant CustomOp as torch.ops.te_moe.*
    participant FakeImpl as register_fake
    participant RealImpl as Real Forward
    participant Autograd as register_autograd
    participant BwdOp as Backward Custom Op

    User->>moe_permute: call (inp, routing_map, ...)
    alt torch.compile tracing
        moe_permute->>CustomOp: dispatch
        CustomOp->>FakeImpl: shape inference only
        FakeImpl-->>CustomOp: fake output tensors
    else eager execution
        moe_permute->>CustomOp: dispatch
        CustomOp->>RealImpl: moe_permute_mask_map_forward / index_map_forward
        RealImpl->>Autograd: setup_context (saves row_id_map, etc.)
        RealImpl-->>CustomOp: (output, row_id_map, permuted_probs)
    end
    CustomOp-->>moe_permute: outputs

    User->>moe_permute: .backward()
    moe_permute->>Autograd: backward wrapper
    Autograd->>BwdOp: torch.ops.te_moe.*_bwd
    BwdOp-->>Autograd: act_grad, probs_grad
    Autograd-->>User: gradients
Loading

Comments Outside Diff (2)

  1. transformer_engine/pytorch/permutation.py, line 995-1000 (link)

    Fake function breaks when num_out_tokens is not a positive integer

    num_out_tokens is used directly as a tensor dimension without validation. moe_permute defaults num_out_tokens to -1 (the "no dropping" sentinel), and the public API also accepts None. When torch.compile calls this fake during tracing with either value, torch.empty((-1, hidden_size)) or torch.empty((None, hidden_size)) raises an invalid-shape error and aborts compilation.

    The companion fake for the index-map path already handles this correctly:

    output_tokens = num_out_tokens if num_out_tokens > 0 else num_tokens * topK

    The mask-map fake on lines 995 and 1000 both use num_out_tokens as a raw dimension and need the same guard. For example:

    resolved = num_out_tokens if (num_out_tokens is not None and num_out_tokens > 0) else num_tokens * routing_map.shape[1]
    fake_output = torch.empty((resolved, hidden_size), dtype=inp.dtype, device=inp.device)

    Line 1000 (fake_permuted_probs) has the same issue and needs the same fix.

  2. tests/pytorch/test_permutation.py, line 254-255 (link)

    Skip guard misses the None token-count case for torch.compile

    The test is parametrized over token counts of [None, 2039], but the skip guard only filters on num_expert and topK. This means the combination of use_torch_compile=True with a None token count will reach the compiled path. That None is forwarded directly into torch.ops.te_moe.permute_mask_map_fwd, whose signature declares an integer argument. The fake function then tries to allocate a tensor with a None dimension, causing a tracing failure.

    The same gap exists in test_permutation_mask_map_alongside_probs at the corresponding skip line, since te_permute_with_probs forwards the same argument to the same op.

    The skip guard should also exclude the None token-count case to avoid exercising the broken fake-function path until it is hardened to handle sentinel values.

Last reviewed commit: f5186b2

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

pggPL and others added 2 commits February 19, 2026 15:57
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Copy link
Collaborator Author

pggPL commented Feb 19, 2026

/te-ci pytorch

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines +225 to +227
import torch._functorch.config as functorch_config

functorch_config.donated_buffer = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it do and why do we need to do that? Could we add a comment here, especially since we would be using the internal function here (and so it will most probably break at some point).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is optimization of torch.compile which is not compatible with retain_graph=True used in tests.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some comment.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where?

# ===================== _moe_permute_index_map custom ops =====================

topK = index.size(1)
# Workspace state for moe_permute_index_map
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like it (although I realize this is not really the problem with this PR, but rather the original implementation).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we can figure out how to change that however, that would be great. Maybe we could make moe_compute a functor (struct MoECompute with __call__ methods and the workspaces, then moe_compute would just be a object of that class that we would create at the very beginning).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why? I mean what you don't like about it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, the main thing is the fact that we implicitly rely on the fact that there is only one permutation happening at a time (and that problem would not be solved by my proposal BTW - this would need a change of this to be actual nn.Module but that has its own problems by effectively being an API break, we should still do it for TE 3.0 though). If you run 2 permutations in 2 streams then that has a chance of silent data corruption since both of those kernels would be using the same underlying workspace. This is something that the user has no way of knowing about without consulting the code. And with torch.compile the chance of this happening may be even bigger - we are at the whim of the compiler optimizations at this point.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we change it in TE 3.0 then? I can indeed change it to functor, but as you said this will not solve a problem.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reply to offline discussion:

  1. there is no support for autograd for ops which mutate args,
  2. torch.compile does not put thing in different streams

Signed-off-by: root <pgadzinski@nvidia.com>
Signed-off-by: root <pgadzinski@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@pggPL pggPL closed this Mar 5, 2026
@pggPL pggPL reopened this Mar 5, 2026
@pggPL
Copy link
Collaborator Author

pggPL commented Mar 5, 2026

/te-ci pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants