Draft
28 commits
496e3ab  initial impl (tdophung, Apr 21, 2026)
f453137  clean up any link to Maxtext. Permutation backends. clean up foward b… (tdophung, Apr 22, 2026)
0044bf2  add distributed test. (tdophung, Apr 23, 2026)
d78bc01  refactor to a2a from roe (tdophung, Apr 30, 2026)
6f87629  fix test_distributed issues with unpopulated LogicallyPartition pytre… (tdophung, Apr 30, 2026)
6aeb491  add option to choose weight fsdp sharding axis (tdophung, May 5, 2026)
25e1eb8  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 5, 2026)
d7fef5a  address greptile comments (tdophung, May 6, 2026)
3a51708  address jeremys comments + relax the sum(group_size) <= dim_m constra… (nvjax, May 7, 2026)
dafaad4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 7, 2026)
27c18fe  revert C++ changes and will put in a new branch, tighten distributed… (tdophung, May 12, 2026)
abbb2c6  address more comments: ep_resource look up, perm backend enum, accept… (tdophung, May 12, 2026)
b375db7  tests/jax/test_distributed_moe_block.py (tdophung, May 12, 2026)
37c871c  change naming and add message for experimental feature (tdophung, May 12, 2026)
3206244  refactor moeBlock into a giant VJP, unrolling most ops, but have help… (tdophung, May 15, 2026)
84a7c00  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 15, 2026)
f6c6e43  some test scripts to add some to delete (tdophung, May 15, 2026)
43fcbdd  WIP: iteration on moe vjp distributed hang (tdophung, May 19, 2026)
fb11714  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 19, 2026)
dfc25bf  test(jax): unblock distributed MoE smoke suite with CUDA_LAUNCH_BLOCKING (tdophung, May 19, 2026)
f31bbf6  docs(jax): rewrite test_distributed_moe_vjp docstring (tdophung, May 19, 2026)
317de4f  fix(jax): drop invalid all_threads kwarg from dump_traceback_later (tdophung, May 19, 2026)
afc7406  test(jax): add NVTE_MOE_OPT_BARRIER flag for experiment C (tdophung, May 19, 2026)
04436aa  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 19, 2026)
b639f81  revert(jax): drop NVTE_MOE_OPT_BARRIER flag (experiment C disproven) (tdophung, May 19, 2026)
dbe9407  test(jax): multiprocess MoE VJP test (one GPU per process) (tdophung, May 19, 2026)
37bf0b1  fix(jax): use process_allgather + numpy assertions in MP test (tdophung, May 20, 2026)
230b711  fix(jax): check local shards instead of process_allgather in MP test (tdophung, May 20, 2026)
6 changes: 6 additions & 0 deletions docs/envvars.rst
@@ -458,6 +458,12 @@ JAX Triton Extensions
   :Default: ``0``
   :Description: Raise a ``RuntimeError`` when the installed JAX is too old to safely run ``TritonAutotunedKernelCall`` (`jax-ml/jax#35218 <https://github.com/jax-ml/jax/pull/35218>`_) instead of silently falling back to non-autotuned dispatch. Useful for CI or debugging to ensure Triton autotuning is active. When set to ``0`` (default), old JAX versions silently fall back to single-config (non-autotuned) kernel dispatch for compatibility.

.. envvar:: NVTE_TRITON_PERMUTATION_BLOCK_SIZES

   :Type: comma-separated list of ``int`` (e.g. ``"128"`` or ``"64,128,256"``)
   :Default: ``"64,128,256,512,1024,2048,4096"`` (the full sweep)
   :Description: Override the ``BLOCK_SIZE`` configs evaluated by ``triton.autotune`` for the MoE permutation kernels in ``transformer_engine/common/triton/permutation.py`` (``_permute_kernel``, ``_unpermute_kernel``, ``_unpermute_bwd_with_merging_probs_kernel``, ``_sort_chunks_by_map_kernel``). The default 7-config sweep yields the best runtime on production shapes but costs ~1-5 s of MLIR→LLVM→PTX→cubin compile per config-per-kernel on a cold start (≈2-5 min total per backend, serialized on a single GPU). Set to a single value (e.g. ``"128"``) to skip autotuning entirely for tests / CI where correctness -- not throughput -- is the goal. Must be a comma-separated list of positive ints; malformed values raise ``ValueError`` at kernel-registration time. **Do NOT set this in production runs** -- you will lose autotuned performance.
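
The validation rule described above (comma-separated positive ints, ``ValueError`` on malformed input) could look roughly like the sketch below. The helper name `parse_block_sizes` and the constant `DEFAULT_BLOCK_SIZES` are hypothetical and are not taken from `permutation.py`; only the environment variable name and the default sweep come from the description.

```python
import os

# Default sweep quoted in the envvar description; the constant name is hypothetical.
DEFAULT_BLOCK_SIZES = (64, 128, 256, 512, 1024, 2048, 4096)


def parse_block_sizes(raw=None):
    """Hypothetical parser: return the BLOCK_SIZE candidates as a list of ints."""
    if raw is None:
        raw = os.environ.get("NVTE_TRITON_PERMUTATION_BLOCK_SIZES")
    if raw is None:
        # Variable unset: keep the full autotune sweep.
        return list(DEFAULT_BLOCK_SIZES)
    sizes = []
    for token in raw.split(","):
        token = token.strip()
        # Reject empty tokens, signs, decimals, and any non-ASCII-digit text.
        if not (token.isascii() and token.isdigit()) or int(token) <= 0:
            raise ValueError(
                "NVTE_TRITON_PERMUTATION_BLOCK_SIZES must be a comma-separated "
                f"list of positive ints, got {raw!r}"
            )
        sizes.append(int(token))
    return sizes
```

With this sketch, `parse_block_sizes("128")` yields a single candidate, which is the case the description recommends for tests and CI.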

Examples
--------

27 changes: 27 additions & 0 deletions qa/L0_jax_distributed_unittest/test.sh
@@ -37,6 +37,33 @@ wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh"
wait

# MoE custom_vjp distributed (Level 2 smoke + Level 3 perf). Single-host
# multi-GPU; requires >=4 visible GPUs.
#
# Flags required for this file (mirrored in tests/jax/run_distributed_moe_vjp.sh):
#
# * ``-p no:typeguard``: jaxtyping's pytest plugin auto-loads typeguard,
#   whose @typechecked import hook materialises JAX tracers via isinstance()
#   checks during shard_map tracing. We disable it only here (other jax tests
#   need it for type-hint validation).
# * ``XLA_PYTHON_CLIENT_PREALLOCATE=false`` +
#   ``XLA_PYTHON_CLIENT_MEM_FRACTION=0.5``: prevents NCCL OOM during EP
#   all-to-all communicator setup (JAX's default preallocation, 75% of GPU
#   memory in current releases, leaves no room).
# * ``CUDA_LAUNCH_BLOCKING=1``: workaround for an async-dispatch hang
#   between Triton custom_calls with ``input_output_aliases`` and the
#   downstream NCCL ragged_all_to_all in this test's bwd path. Without it,
#   MainThread parks in _pjit_call_impl_python and one GPU pins at 100%
#   forever. With it, the smoke suite passes in <1 min. See the
#   ``tests/jax/test_distributed_moe_vjp.py`` module docstring for the
#   bisection record and the TODO for the proper fix.
XLA_PYTHON_CLIENT_PREALLOCATE=false XLA_PYTHON_CLIENT_MEM_FRACTION=0.5 \
CUDA_LAUNCH_BLOCKING=1 \
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v -s \
-p no:typeguard \
--junitxml=$XML_LOG_DIR/pytest_test_distributed_moe_vjp.xml \
$TE_PATH/tests/jax/test_distributed_moe_vjp.py || test_fail "test_distributed_moe_vjp.py"
wait
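
For local debugging outside the QA script, the same invocation can be assembled from Python. This is a minimal sketch that only mirrors the flags and environment variables shown in the shell hunk above; the function name `moe_vjp_test_command` and its two parameters are hypothetical.

```python
import os
import sys


def moe_vjp_test_command(te_path, xml_log_dir):
    """Return (env, argv) mirroring the shell invocation in test.sh."""
    env = dict(os.environ)
    env.update(
        {
            # Avoid up-front GPU memory preallocation and cap the JAX pool.
            "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
            "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.5",
            # Workaround for the async-dispatch hang described in the comment.
            "CUDA_LAUNCH_BLOCKING": "1",
        }
    )
    xml_path = os.path.join(xml_log_dir, "pytest_test_distributed_moe_vjp.xml")
    argv = [
        sys.executable, "-m", "pytest",
        "-c", os.path.join(te_path, "tests/jax/pytest.ini"),
        "-v", "-s",
        "-p", "no:typeguard",  # disable the typeguard plugin for this file only
        f"--junitxml={xml_path}",
        os.path.join(te_path, "tests/jax/test_distributed_moe_vjp.py"),
    ]
    return env, argv


# Usage (not executed here):
#   import subprocess
#   env, argv = moe_vjp_test_command(os.environ["TE_PATH"], "/tmp/logs")
#   subprocess.run(argv, env=env, check=False)
```

Returning the environment and argument list instead of launching the process keeps the sketch importable on machines without GPUs.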

if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
exit 1
14 changes: 14 additions & 0 deletions tests/jax/conftest.py
@@ -86,6 +86,20 @@ def pytest_sessionfinish(self, session, exitstatus):
print("=" * 80)


def pytest_addoption(parser):
    """CLI options for multiprocess JAX tests.

    Mirrors examples/jax/encoder/conftest.py so multiprocess tests in
    tests/jax/ can be launched one-process-per-GPU via a sibling shell
    script. Required by tests/jax/test_multiprocess_moe_vjp.py to work
    around the JAX/XLA + lazy Triton kernel load + active NCCL deadlock
    documented in past_JAX_XLA_deadlock.txt and nvbug/5564750. Harmless
    for other tests; defaults to 0 (= "not a multiprocess launch").
    """
    parser.addoption("--num-process", action="store", default=0)
    parser.addoption("--process-id", action="store", default=0)
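
Because the two options declare no `type`, values passed on the command line arrive as strings while the default `0` stays an int, so a consumer has to normalise them before use. A minimal sketch of that step follows; `resolve_multiprocess_opts` is a hypothetical helper, not code from this PR.

```python
def resolve_multiprocess_opts(num_process, process_id):
    """Hypothetical helper: normalise the raw option values.

    Returns None for a single-process run (num_process == 0), otherwise
    the tuple (num_process, process_id) as ints.
    """
    # int() accepts both the int default and a string such as "4".
    num_process = int(num_process)
    process_id = int(process_id)
    if num_process == 0:
        return None
    if not 0 <= process_id < num_process:
        raise ValueError(
            f"--process-id must be in [0, {num_process}), got {process_id}"
        )
    return num_process, process_id


# Usage inside a test or fixture (pytest's ``request`` fixture assumed):
#   opts = resolve_multiprocess_opts(
#       request.config.getoption("--num-process"),
#       request.config.getoption("--process-id"),
#   )
```

Treating `0` as "not a multiprocess launch" follows the docstring above; the range check on the process id is an added assumption.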


def pytest_configure(config):
    config.addinivalue_line(
        "markers",