[PyTorch] torch.compile support for permutation functions #2686
pggPL wants to merge 11 commits into NVIDIA:main from
Conversation
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Compare: 41e22ef to 8159d26
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
/te-ci pytorch
Greptile Summary
This PR refactors all MoE permutation operations (moe_permute, moe_unpermute, moe_sort_chunks_by_index) into PyTorch custom operators. Key issues found:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant moe_permute
participant CustomOp as torch.ops.te_moe.*
participant FakeImpl as register_fake
participant RealImpl as Real Forward
participant Autograd as register_autograd
participant BwdOp as Backward Custom Op
User->>moe_permute: call (inp, routing_map, ...)
alt torch.compile tracing
moe_permute->>CustomOp: dispatch
CustomOp->>FakeImpl: shape inference only
FakeImpl-->>CustomOp: fake output tensors
else eager execution
moe_permute->>CustomOp: dispatch
CustomOp->>RealImpl: moe_permute_mask_map_forward / index_map_forward
RealImpl->>Autograd: setup_context (saves row_id_map, etc.)
RealImpl-->>CustomOp: (output, row_id_map, permuted_probs)
end
CustomOp-->>moe_permute: outputs
User->>moe_permute: .backward()
moe_permute->>Autograd: backward wrapper
Autograd->>BwdOp: torch.ops.te_moe.*_bwd
BwdOp-->>Autograd: act_grad, probs_grad
Autograd-->>User: gradients
for more information, see https://pre-commit.ci
/te-ci pytorch
import torch._functorch.config as functorch_config

functorch_config.donated_buffer = False
What does it do, and why do we need to do that? Could we add a comment here, especially since we are using an internal function here (and so it will most probably break at some point)?

This is a torch.compile optimization that is not compatible with the retain_graph=True used in the tests.

I added a comment.
# ===================== _moe_permute_index_map custom ops =====================

topK = index.size(1)
# Workspace state for moe_permute_index_map
I don't like it (although I realize this is not really a problem with this PR, but rather with the original implementation).

If we can figure out how to change that, however, that would be great. Maybe we could make moe_compute a functor (a struct MoECompute with __call__ methods and the workspaces); moe_compute would then just be an object of that class that we create at the very beginning.

Why? I mean, what is it that you don't like about it?

Well, the main thing is that we implicitly rely on the fact that there is only one permutation happening at a time (and that problem would not be solved by my proposal, BTW - it would require changing this into an actual nn.Module, which has its own problems by effectively being an API break; we should still do it for TE 3.0, though). If you run 2 permutations on 2 streams, there is a chance of silent data corruption, since both kernels would be using the same underlying workspace. This is something the user has no way of knowing about without consulting the code. And with torch.compile the chance of this happening may be even bigger - we are at the whim of the compiler optimizations at this point.

Can we change it in TE 3.0 then? I can indeed change it to a functor, but as you said this will not solve the problem.

Reply to the offline discussion:
- there is no autograd support for ops which mutate their args,
- torch.compile does not put things on different streams.
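A rough sketch of the functor idea floated above, with per-instance workspace state instead of module-level globals. All names here (MoECompute, the workspace attribute) are hypothetical illustrations, not TE's actual implementation:

```python
# Hypothetical sketch of the proposed functor: each instance owns its own
# workspace, so two instances on two streams no longer share scratch memory.
import torch

class MoECompute:
    def __init__(self):
        # per-instance workspace (replaces the shared module-level state)
        self.row_id_map_workspace = None

    def __call__(self, inp, index):
        topK = index.size(1)
        # lazily (re)allocate the workspace to the required size
        needed = inp.size(0) * topK
        if (self.row_id_map_workspace is None
                or self.row_id_map_workspace.numel() < needed):
            self.row_id_map_workspace = torch.empty(
                needed, dtype=torch.int32, device=inp.device)
        # ... a real implementation would launch the permutation kernel
        # here, writing the row-id map into the workspace ...
        return inp  # placeholder return for the sketch

moe_compute = MoECompute()  # one global instance preserves today's behavior
```

This alone does not fix the single-permutation assumption (the global instance is still shared), which is why the discussion points at an nn.Module-based API for TE 3.0.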
Signed-off-by: root <pgadzinski@nvidia.com>
/te-ci pytorch
Description
This PR adds torch.compile(fullgraph=True) support for MoE permutation operations (moe_permute, moe_unpermute, moe_sort_chunks_by_index) by converting all torch.autograd.Function implementations to PyTorch custom operators using torch.library.custom_op.
Note that this PR does not add torch.compile support for QuantizedTensor as an input.
Related to #2590
Type of change
Checklist: