[JAX] Add an MoE Block (Layer) that compound router, permutation, groupedGEMM and communication#2912
tdophung wants to merge 28 commits into
Greptile Summary
This PR introduces
Confidence Score: 4/5
Safe to merge after addressing the missing ep_axis exclusion guard in the data_parallelism_axes validation loop.
The A2A-EP forward correctly addresses the recv_buffer_rows alignment fix. One gap remains: if a caller passes the EP axis name in data_parallelism_axes, the batch PartitionSpec gets a duplicate axis and dp_size is double-counted, producing an undersized ragged_all_to_all receive buffer with no useful error message.
File needing attention: transformer_engine/jax/flax/moe.py, specifically the data_parallelism_axes validation block in _forward_a2a_ep.
Sequence Diagram
sequenceDiagram
participant Caller
participant _MoEBlock
participant Router
participant GlobalPermute
participant A2A as ragged_all_to_all (EP)
participant LocalPerm as local_permute_after_a2a
participant ExpertFFN as _expert_ffn (grouped_dense x3)
participant GlobalCombine
Caller->>_MoEBlock: inputs [B, S, H]
_MoEBlock->>Router: "gate_logits -> fused_topk_with_score_function"
Router-->>_MoEBlock: sparse_probs, routing_map
_MoEBlock->>GlobalPermute: _global_permute (pure_jax or triton)
GlobalPermute-->>_MoEBlock: sorted_inputs, group_sizes [E]
alt No-EP path
_MoEBlock->>ExpertFFN: "sorted_inputs, group_sizes, n_groups=E"
ExpertFFN-->>_MoEBlock: expert_outputs
else A2A-EP path via shard_map
_MoEBlock->>A2A: all_gather(group_sizes)
A2A->>A2A: forward ragged_all_to_all over ep axis
A2A->>LocalPerm: reorder recv buffer
LocalPerm-->>A2A: sorted_x, local_group_sizes
A2A->>ExpertFFN: sorted_x, local_group_sizes
ExpertFFN-->>A2A: expert_outputs
A2A->>LocalPerm: local_unpermute_before_a2a
A2A->>A2A: reverse ragged_all_to_all
A2A-->>_MoEBlock: y_back
end
_MoEBlock->>GlobalCombine: _global_combine
GlobalCombine-->>_MoEBlock: output [B, S, H]
_MoEBlock-->>Caller: output [B, S, H], aux_loss
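The A2A-EP branch of the diagram depends on turning the all-gathered per-expert group_sizes into send/receive sizes and offsets for the ragged exchange. The following is a minimal pure-Python sketch of that bookkeeping, assuming experts are laid out contiguously across EP shards and each shard's rows are already sorted by global expert id. `ragged_params` and `simulate_exchange` are hypothetical names for illustration; this is not the PR's compute_ragged_all_to_all_params, and the exchange is simulated on the host rather than run as a collective.

```python
from itertools import accumulate


def ragged_params(all_group_sizes, num_ep):
    """all_group_sizes[src][e] = rows shard `src` routes to global expert `e`."""
    num_experts = len(all_group_sizes[0])
    per_shard = num_experts // num_ep  # assumption: contiguous expert layout
    # send[src][dst]: rows shard `src` sends to shard `dst`.
    send = [
        [sum(row[dst * per_shard:(dst + 1) * per_shard]) for dst in range(num_ep)]
        for row in all_group_sizes
    ]
    params = []
    for me in range(num_ep):
        send_sizes = send[me]
        # Exclusive prefix sum: where each destination's chunk starts locally.
        input_offsets = [0] + list(accumulate(send_sizes))[:-1]
        recv_sizes = [send[src][me] for src in range(num_ep)]
        # Offset inside the destination's buffer: lower-ranked sources come first.
        output_offsets = [
            sum(send[src][dst] for src in range(me)) for dst in range(num_ep)
        ]
        params.append((input_offsets, send_sizes, output_offsets, recv_sizes))
    return params


def simulate_exchange(buffers, params, recv_rows):
    """Host-side stand-in for the collective, used only to check the offsets."""
    out = [[None] * recv_rows for _ in buffers]
    for src, (in_off, send_sizes, out_off, _) in enumerate(params):
        for dst, n in enumerate(send_sizes):
            chunk = buffers[src][in_off[dst]:in_off[dst] + n]
            out[dst][out_off[dst]:out_off[dst] + n] = chunk
    return out


# 2 EP shards, 4 experts (2 per shard).
group_sizes = [[1, 2, 0, 1], [2, 0, 1, 1]]
params = ragged_params(group_sizes, num_ep=2)
# Rows tagged (source shard, expert), sorted by expert on each shard.
buffers = [
    [(s, e) for e, n in enumerate(row) for _ in range(n)]
    for s, row in enumerate(group_sizes)
]
received = simulate_exchange(buffers, params, recv_rows=8)
```

In this toy run the rows arriving on shard 0 are grouped by source shard, not by expert, which is why the diagram has a local permute step between the forward exchange and the expert FFN.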
Reviews (6): Last reviewed commit: "change naming and add message for experi..."
Signed-off-by: tdophung <tdophung@nvidia.com>
…ody single GPU vs. multi GPU Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…e and single device initial params in the MoEBlock. Tests should pass now Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: tdophung <tdophung@nvidia.com>
…int in C++ files, make FP8 works. Tested with current scaling Signed-off-by: JAX Toolbox <jax@nvidia.com>
for more information, see https://pre-commit.ci
… grad tol to 5e-2, move arch/align_size docs into MoEBlock class Signed-off-by: tdophung <tdophung@nvidia.com>
batch_divisor = num_ep * dp_size
if global_batch_size % batch_divisor != 0:
    raise ValueError(
        f"batch={global_batch_size} not divisible by prod(data_parallelism_axes)={dp_size}"
    )
recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk
Receive buffer undersized when align_size > 0 and EP are combined
recv_buffer_rows is computed assuming unpadded token counts, but when align_size > 0 the per-expert group_sizes are the aligned counts, so the send_sizes in compute_ragged_all_to_all_params include padding tokens. The worst-case receive per shard is num_ep * ((B/(num_ep*dp_size))*S*K + num_experts_per_shard*(align_size-1)), which exceeds the current recv_buffer_rows = (B/dp_size)*S*K by up to num_experts*(align_size-1) rows. ragged_all_to_all writing beyond the buffer produces incorrect results or a crash. The correct worst-case size is:
recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk + num_experts * (self.align_size - 1 if self.align_size > 0 else 0)
This combination (EP + align_size > 0) is not exercised by the current distributed test (which defaults to align_size=0), so the bug is latent.
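The size gap can be checked with small numbers. This is a minimal sketch, assuming each per-expert group is rounded up to a multiple of align_size and empty groups stay empty. The helper names are hypothetical and not part of the PR.

```python
def align_up(n, align_size):
    """Round n up to a multiple of align_size (no-op when align_size <= 0)."""
    if align_size <= 0:
        return n
    return -(-n // align_size) * align_size


def recv_rows_unpadded(batch, dp_size, seq_len, topk):
    """Buffer size as computed in the snippet under review."""
    return (batch // dp_size) * seq_len * topk


def recv_rows_padded(batch, dp_size, seq_len, topk, num_experts, align_size):
    """Buffer size proposed in the review: add worst-case padding per expert."""
    pad = align_size - 1 if align_size > 0 else 0
    return recv_rows_unpadded(batch, dp_size, seq_len, topk) + num_experts * pad


num_ep, experts_per_shard, align = 2, 2, 4
# B=8, dp_size=1, S=1, K=1: each of the 2 source shards holds 4 rows and
# routes all of them to the two experts owned by shard 0, split as [1, 3].
per_source_group_sizes = [[1, 3], [1, 3]]
received = sum(
    align_up(g, align) for sizes in per_source_group_sizes for g in sizes
)
unpadded = recv_rows_unpadded(batch=8, dp_size=1, seq_len=1, topk=1)
padded = recv_rows_padded(
    batch=8, dp_size=1, seq_len=1, topk=1,
    num_experts=num_ep * experts_per_shard, align_size=align,
)
```

Here shard 0 receives 16 aligned rows, which overflows the 8-row unpadded buffer but fits within the 20-row padded bound.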
phu0ngng left a comment
I think we should go with exposing GroupMLP VJP first before the MoE module to enable future possible fusions.
…ing None as group_topk, align_size rename, Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
for ax in self.data_parallelism_axes:
    if ax not in mesh.shape:
        raise ValueError(
            f"data_parallelism_axes contains {ax!r} but mesh has"
            f" axes {tuple(mesh.shape.keys())}"
        )
The validation loop checks that every axis in data_parallelism_axes exists in the mesh but does not check that the axis differs from ep_axis. If a caller passes data_parallelism_axes=("ep",) when ep_axis="ep", batch_pspec_axis becomes ("ep", "ep"), a duplicate-axis PartitionSpec that JAX rejects with a cryptic error. Independently, dp_size accumulates mesh.shape["ep"] a second time, so recv_buffer_rows is undersized by a factor of num_ep and batch_divisor becomes num_ep², both causing wrong runtime behaviour before JAX ever sees the bad spec.
Suggested change:
- for ax in self.data_parallelism_axes:
-     if ax not in mesh.shape:
-         raise ValueError(
-             f"data_parallelism_axes contains {ax!r} but mesh has"
-             f" axes {tuple(mesh.shape.keys())}"
-         )
+ for ax in self.data_parallelism_axes:
+     if ax not in mesh.shape:
+         raise ValueError(
+             f"data_parallelism_axes contains {ax!r} but mesh has"
+             f" axes {tuple(mesh.shape.keys())}"
+         )
+     if ax == ep_axis:
+         raise ValueError(
+             f"data_parallelism_axes contains {ax!r}, which is the same as the"
+             f" EP axis {ep_axis!r}. The EP axis is already included in the batch"
+             " sharding spec; listing it again produces a duplicate-axis"
+             " PartitionSpec and an undersized ragged_all_to_all receive buffer."
+         )
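The double-counting described in this comment can be reproduced without a device mesh. The sketch below uses a plain dict as a stand-in for mesh.shape and assumes dp_size is the product of the listed axis sizes. `dp_size_with_guard` and `sizes` are hypothetical helpers, not functions from the PR.

```python
def dp_size_with_guard(data_parallelism_axes, mesh_shape, ep_axis):
    """Product of the DP axis sizes; rejects unknown axes and the EP axis."""
    dp_size = 1
    for ax in data_parallelism_axes:
        if ax not in mesh_shape:
            raise ValueError(f"unknown mesh axis {ax!r}; mesh has {tuple(mesh_shape)}")
        if ax == ep_axis:
            raise ValueError(f"{ax!r} is the EP axis and must not be listed again")
        dp_size *= mesh_shape[ax]
    return dp_size


def sizes(batch, seq_len, topk, num_ep, dp_size):
    """(batch_divisor, recv_buffer_rows) as computed in the reviewed snippet."""
    return num_ep * dp_size, (batch // dp_size) * seq_len * topk


mesh_shape = {"ep": 2, "fsdp": 2}  # plain dict standing in for mesh.shape
good_dp = dp_size_with_guard(("fsdp",), mesh_shape, ep_axis="ep")

# Caller meant "no extra DP axes" (dp_size = 1) but listed the EP axis,
# so without the guard dp_size picks up mesh_shape["ep"] = 2.
intended = sizes(batch=8, seq_len=4, topk=2, num_ep=2, dp_size=1)
double_counted = sizes(batch=8, seq_len=4, topk=2, num_ep=2, dp_size=2)

try:
    dp_size_with_guard(("ep",), mesh_shape, ep_axis="ep")
    guard_fired = False
except ValueError:
    guard_fired = True
```

With these numbers the divisor grows from 2 to 4 (num_ep squared) and the receive buffer shrinks from 64 to 32 rows, while the guarded version fails early with a readable message.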
…er functions to group permute -> a2a -> local permute to dispatch and combine Signed-off-by: tdophung <tdophung@nvidia.com>
Changing back to draft to not spam people's email while I push commits to this branch for the full unrolling of ops in a big VJP.
for more information, see https://pre-commit.ci
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
The triton backend of test_distributed_moe_vjp hangs in bwd: MainThread parks in _pjit_call_impl_python, one GPU pinned at 100%, no NCCL ops enqueued.
Root cause is an async-dispatch race between our Triton kernels (which use input_output_aliases on a pre-zeroed output buffer) and the downstream NCCL ragged_all_to_all -- XLA mis-tracks the dependency edge and the collective launches before the kernel finishes writing sorted_inputs; different ranks then read different versions of the per-expert token counts, deadlocking NCCL.
Workaround: set CUDA_LAUNCH_BLOCKING=1 in the test runner. Smoke suite now passes in <1 min across 3 consecutive runs. Slowdown on these correctness shapes is negligible.
Also flip the faulthandler watchdog to all_threads=True so the next investigator can see worker-thread frames, not just MainThread.
Signed-off-by: tdophung <tdophung@nvidia.com>
The 'CRITICAL: -p no:typeguard' section was based on an early incorrect bisection. The actual root cause is an async-dispatch race between our Triton custom_calls (with input_output_aliases on pre-zeroed output buffers) and the downstream NCCL collective in the same shard_map body -- XLA mis-handles the cross-stream sync edge from the aliased custom_call to the NCCL op.
Why the old _MoEBlock path didn't hit this: each primitive (token_dispatch, permute, ragged_all_to_all, sort_chunks) sat behind its own custom_vjp boundary, which acted as an implicit sync barrier. The new unified moe() custom_vjp removes those boundaries (so ScaledTensor can survive across them), exposing the bug.
Document CUDA_LAUNCH_BLOCKING=1 as the current workaround and flag the proper fix (stream sync in triton_call_lowering, or file an upstream JAX FFI bug) for follow-up.
Signed-off-by: tdophung <tdophung@nvidia.com>
faulthandler.dump_traceback_later() takes only (timeout, repeat, file, exit) -- there is no all_threads parameter. It already dumps every Python thread by default. Don't confuse with faulthandler.register() which does take all_threads. The bad kwarg caused pytest collection to crash with TypeError before any test could run.
Signed-off-by: tdophung <tdophung@nvidia.com>
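The difference between the two faulthandler entry points can be seen with the standard library alone. This is a small sketch of a watchdog that writes to a temporary file; `arm_watchdog` and the 600-second timeout are illustrative choices, not the values used in the test suite.

```python
import faulthandler
import tempfile


def arm_watchdog(timeout_s, log_file):
    """Dump the tracebacks of all threads to log_file if still running after timeout_s."""
    faulthandler.dump_traceback_later(
        timeout_s, repeat=False, file=log_file, exit=False
    )


# The file must stay open until the watchdog is cancelled or fires.
log_file = tempfile.TemporaryFile()

# Passing all_threads to dump_traceback_later is rejected as an unknown keyword.
try:
    faulthandler.dump_traceback_later(600, all_threads=True, file=log_file)
    rejected = False
except TypeError:
    rejected = True

arm_watchdog(600, log_file)
# ... the test body would run here ...
faulthandler.cancel_dump_traceback_later()
log_file.close()
```

faulthandler.register(), which installs a handler for a user signal, is the call that accepts all_threads; it is not available on Windows, so it is left out of the sketch.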
Experiment C of the bwd-hang investigation. When NVTE_MOE_OPT_BARRIER=1, insert jax.lax.optimization_barrier on the Triton-kernel output before every immediately-following NCCL ragged_all_to_all in _dispatch, _combine_bwd, _dispatch_bwd.
If this fixes the hang WITHOUT CUDA_LAUNCH_BLOCKING=1, the bug is fixable at the lowering layer by forcing materialization between the aliased Triton custom_call and the NCCL collective, which is much cheaper than serializing every CUDA launch.
Off by default so the existing CUDA_LAUNCH_BLOCKING workaround continues to work without this opt-in.
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
Experiment C results from dlcluster job 1045311:
C.0 (no barrier, no blocking): PASS in 32s (won the race)
C.1 (with barrier, single test): TIMEOUT_HANG at 180s
C.2 (with barrier, full smoke): TIMEOUT_HANG at 480s
The optimization_barrier doesn't fix the hang and in fact makes it MORE reliably reproducible. This rules out the HLO-dataflow fix hypothesis: the bug is not a missing XLA stream-sync edge.
C.0 passing demonstrates the hang is RACY not deterministic -- matches Olli Lupton's October 2025 memo on JAX/XLA multi-GPU deadlocks from lazy CUDA module loading interleaved with active NCCL collectives (nvbug/5564750). With multi-GPU-per-process the lazy load of a Triton kernel on GPU0 can take the global driver lock and block on cuiStreamSynchronize for an active NCCL kernel that itself depends on GPU1's progress, which is blocked on the same lock.
Triton kernels can't be pre-loaded via FFI 'initialize' (JAX core owns the primitive). The proper fix is multiprocess launch (one JAX process per GPU) -- see follow-up commit.
Signed-off-by: tdophung <tdophung@nvidia.com>
Companion to test_distributed_moe_vjp.py that avoids the multi-GPU lazy-load + active-NCCL deadlock entirely by giving each GPU its own Python process / CUDA driver context. With one device per process there is no global module-load lock shared across the threads driving different GPUs, so the failure mode described in past_JAX_XLA_deadlock.txt (nvbug/5564750) cannot occur and no CUDA_LAUNCH_BLOCKING=1 workaround is needed.
Pattern mirrors examples/jax/encoder/test_multiprocessing_encoder.py:
- pytest --num-process=N --process-id=i CLI options (added to tests/jax/conftest.py, defaults to 0 = single-process so harmless for other tests)
- jax.distributed.initialize(... local_device_ids=process_id ...) at module top-level
- module-level skip when not launched via the runner so direct pytest collection on tests/jax/ is harmless
- run_multiprocess_moe_vjp.sh forks N=nvidia-smi pytest processes and waits for all of them
Tests themselves are 1:1 with TestMoeVjpDistributedSmoke from the single-process file (fwd_and_bwd_smoke, aux_loss_smoke, parity).
Keeping BOTH files in tree: single-process is simpler for dev-loop iteration (with CUDA_LAUNCH_BLOCKING=1 workaround), multiprocess is what CI should run for guaranteed correctness without workarounds.
Signed-off-by: tdophung <tdophung@nvidia.com>
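For readers unfamiliar with the one-process-per-GPU pattern, here is a rough sketch of what the two CLI options and the launcher loop could look like. The help strings, `launch_commands`, `run_all`, and the placeholder test path are hypothetical. The actual conftest.py hook and run_multiprocess_moe_vjp.sh are not reproduced here and may differ.

```python
import subprocess


def pytest_addoption(parser):
    """conftest.py hook: register the multiprocess options (0 = single-process)."""
    parser.addoption(
        "--num-process", action="store", default=0, type=int,
        help="total number of processes (0 means single-process)",
    )
    parser.addoption(
        "--process-id", action="store", default=0, type=int,
        help="rank of this process",
    )


def launch_commands(num_gpus, test_file):
    """One pytest command line per GPU, each with its own process id."""
    return [
        ["pytest", test_file, f"--num-process={num_gpus}", f"--process-id={rank}"]
        for rank in range(num_gpus)
    ]


def run_all(commands):
    """Start every process first, then wait, so no rank blocks the others."""
    procs = [subprocess.Popen(cmd) for cmd in commands]
    return [p.wait() for p in procs]


# Placeholder path; building the commands does not start any process.
commands = launch_commands(4, "path/to/test_file.py")
```

Starting all processes before waiting matters because each rank blocks in distributed initialization until its peers have joined.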
First multiprocess attempt hung at jnp.any(g != 0.0).item() in the post-grad assertion loop. In a single-process test that's a local reduction; in multi-host JAX it implicitly triggers a cross-process collective (all-gather + reduce) under the hood, and any small divergence in graph build order across processes can deadlock.
Replace the host-side reductions with: multihost_utils.process_allgather(x, tiled=True) -> np.asarray
Then run the finite / non-zero / parity asserts entirely in numpy. Every process gathers in lockstep, no surprise JAX collectives.
Also: launcher now respects MOE_VJP_MP_LOG_DIR so per-process logs survive on a host-mounted volume after the container exits.
Signed-off-by: tdophung <tdophung@nvidia.com>
process_allgather hung on procs 0+3 while procs 1+2 finished (divergence detected from per-process logs in dlcluster job 1046001). A multi-host collective inside the post-grad assertion loop is too easy to deadlock when even one assertion fires on some procs first.
Use the local addressable shard on each process via arr.addressable_data(0) -> np.asarray. Same correctness coverage (if any rank has NaN, that rank's assertion fires) without needing to emit a cross-process collective for the test machinery itself.
Signed-off-by: tdophung <tdophung@nvidia.com>
Description
Most of the MoE building-block integration work has been deeply coupled with MaxText development. This PR creates the MoE block to isolate that work from MaxText and leave more room for experimentation.
MoEBlock is a self-contained Flax-Linen module that wires together TE's fused router, pluggable token-dispatch backends (pure-JAX argsort or Triton sort_chunks_by_index), a grouped_dense-based expert FFN, and ragged-all-to-all (A2Av) expert parallelism via shard_map.
This first iteration starts with ring-of-experts EP, sharding on the batch dimension for FSDP, cuBLASLt grouped GEMM, and two permutation backends: pure JAX or Triton kernels. The block also exposes pluggable knobs for: weight layout (wi_kernel_axes/wo_kernel_axes), permutation backend, A2A vs. no-EP (single-GPU) path, data-parallelism axes for true FSDP (batch sharded across (ep, fsdp) simultaneously), top-K with optional grouped/sigmoid scoring (for DSv3 workloads), and an optional auxiliary load-balancing loss.
Fixes #2895
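To make the argsort-style dispatch and the matching combine step concrete, here is a toy pure-Python sketch that uses scalars as tokens and an identity function in place of the expert FFN. It illustrates the general technique only: `dispatch` and `combine` are hypothetical names, and the PR's unfused_token_dispatch / unfused_token_combine operate on JAX arrays and may differ in detail.

```python
def dispatch(tokens, topk_experts, num_experts):
    """Replicate each token once per selected expert, then sort copies by expert id."""
    # flat[i] = (expert, token index), with i = token_index * K + k.
    flat = [(e, t) for t, experts in enumerate(topk_experts) for e in experts]
    # Stable argsort by expert id keeps token order inside each expert group.
    order = sorted(range(len(flat)), key=lambda i: flat[i][0])
    sorted_inputs = [tokens[flat[i][1]] for i in order]
    group_sizes = [0] * num_experts
    for e, _ in flat:
        group_sizes[e] += 1
    return sorted_inputs, group_sizes, order


def combine(expert_outputs, order, topk_probs):
    """Undo the sort and sum each token's K expert outputs weighted by its probs."""
    num_tokens, topk = len(topk_probs), len(topk_probs[0])
    out = [0.0] * num_tokens
    for pos, i in enumerate(order):
        t, k = divmod(i, topk)
        out[t] += topk_probs[t][k] * expert_outputs[pos]
    return out


tokens = [10.0, 20.0, 30.0]
topk_experts = [[2, 0], [1, 2], [0, 1]]           # K = 2 experts per token
topk_probs = [[0.5, 0.5], [0.25, 0.75], [1.0, 0.0]]

sorted_inputs, group_sizes, order = dispatch(tokens, topk_experts, num_experts=3)
expert_outputs = list(sorted_inputs)              # identity "experts"
combined = combine(expert_outputs, order, topk_probs)
```

Because the experts are the identity and each token's probabilities sum to 1, the combined output equals the input tokens, which gives a quick round-trip check on the permutation.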
Type of change
Changes
transformer_engine/jax/flax/moe.py -- MoEBlock Linen module: gate -> fused topk -> global permute -> A2A EP shard_map (ragged_a2a fwd, local permute, 3x grouped GEMM SwiGLU FFN, local unpermute, ragged_a2a rev) -> global combine.
transformer_engine/jax/permutation.py with A2A param helpers (compute_ragged_all_to_all_params, compute_reverse_ragged_all_to_all_params, local_permute_after_a2a, local_unpermute_before_a2a) and the pure-JAX unfused_token_dispatch / unfused_token_combine paths with custom VJPs.
tests/jax/test_moe_block.py -- single-device shape, backward, cross-backend equivalence, aux-loss, group-topk, JIT determinism.
tests/jax/test_distributed_moe_block.py -- EP=2 x FSDP=2 mesh test using the canonical Flax-Linen sharded-init pattern (eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, out_shardings=...)) and data_parallelism_axes=("fsdp",) to exercise true FSDP (batch sharded across both axes).
Checklist: