
[JAX] Collective GEMM with FP8 and MXFP8 support#2740

Open
phu0ngng wants to merge 15 commits into NVIDIA:main from phu0ngng:cgemm_fp8

Conversation

@phu0ngng
Collaborator

@phu0ngng phu0ngng commented Mar 5, 2026

Description

This PR extends JAX Collective GEMM support to the DelayedScalingFP8, CurrentScalingFP8, and MXFP8 quantization recipes.
Unit tests for these recipes are added, and the test infrastructure in the collective GEMM tests is cleaned up.

Note that Collective GEMM + MXFP8 requires all dimensions of the GEMM operands to be divisible by 128.
Also, in the CGEMM + MXFP8 + AllGather case, the block scales are still all-gathered in the critical path, unlike the quantized data, whose gather is overlapped with the computation.
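The divisibility requirement above can be sketched as a simple shape check. This is a hypothetical helper for illustration only; in the PR the real check lives as assertions inside transformer_engine/jax/cpp_extensions/gemm.py.

```python
MXFP8_BLOCK = 128  # block-scale granularity stated in the PR description


def check_mxfp8_cgemm_dims(lhs_shape, rhs_shape):
    """Return True if every GEMM operand dimension is divisible by 128.

    Hypothetical sketch of the eligibility rule for Collective GEMM + MXFP8.
    """
    return all(d % MXFP8_BLOCK == 0 for d in (*lhs_shape, *rhs_shape))


# A (256, 512) x (512, 1024) GEMM is eligible; a 250-row LHS is not.
print(check_mxfp8_cgemm_dims((256, 512), (512, 1024)))  # True
print(check_mxfp8_cgemm_dims((250, 512), (512, 1024)))  # False
```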

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@phu0ngng phu0ngng changed the title [JAX] CGEMM + FP8 [JAX] CGEMM + FP8MXFP8 Mar 10, 2026
phu0ngng and others added 14 commits March 10, 2026 15:35
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng phu0ngng changed the title [JAX] CGEMM + FP8MXFP8 [JAX] CGEMM + FP8/MXFP8 Mar 10, 2026
@phu0ngng
Collaborator Author

/te-ci JAX L1

@phu0ngng phu0ngng marked this pull request as ready for review March 10, 2026 23:23
@phu0ngng phu0ngng changed the title [JAX] CGEMM + FP8/MXFP8 [JAX] Collective GEMM with FP8 and MXFP8 support Mar 10, 2026
@greptile-apps
Contributor

greptile-apps bot commented Mar 10, 2026

Greptile Summary

This PR extends JAX Collective GEMM (CGEMM) to support FP8 quantization recipes — DelayedScaling, CurrentScaling, and MXFP8 — and lays groundwork for NVFP4 (currently commented out). The implementation adds block-scale reordering helpers for MXFP8 (_reorder_tpsp_leading / _reorder_dp_leading), updates the _parse_operand_output_specs sharding logic to correctly shard scale inverses for collective ops, and introduces two new public helper functions (is_quantize_recipe_supported, get_quantization_recipe) in helper.py. Test infrastructure is also overhauled to run individual test cases rather than whole test files, and the --fp8-recipe argument is renamed to --quantize-recipe.

Key changes:

  • gemm.py: Extracted LHS/output reorder logic into _reorder_tpsp_leading / _reorder_dp_leading helpers; MXFP8 + collective path skips padding and instead asserts alignment-to-128 requirements; scale-inverse sharding is now properly expressed in the partition rule (sequence dim unsharded for LHS scale in AllGather); NVFP4 + collective GEMM is guarded with an assertion.
  • helper.py: is_quantize_recipe_supported() and get_quantization_recipe() provide a string-based API for recipe lookup, used throughout the test files.
  • Test files: All three test files (test_gemm.py, test_dense_grad.py, test_layernorm_mlp_grad.py) gain FP8/MXFP8 test variants with hardware-capability skip guards and correct tolerance selection via get_tolerance_dtype().
  • One minor issue: the NVFP4 guard assertion in _te_gemm still lists only two supported quantization modes in its error message, omitting the newly added MXFP8.
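The string-based recipe lookup added in helper.py can be pictured as a name-to-recipe map. The sketch below is illustrative only: the recipe names, return values, and mapping are assumptions, not the actual Transformer Engine API.

```python
# Hypothetical sketch of a string-keyed recipe lookup in the spirit of
# is_quantize_recipe_supported() / get_quantization_recipe(); the mapping
# and names here are illustrative, not TE's real tables.
_RECIPES = {
    "DelayedScaling": "DELAYED_TENSOR_SCALING",
    "CurrentScaling": "CURRENT_TENSOR_SCALING",
    "MXFP8BlockScaling": "MXFP8_1D_SCALING",
}


def is_quantize_recipe_supported(name: str) -> bool:
    """Check whether a recipe name is known."""
    return name in _RECIPES


def get_quantization_recipe(name: str) -> str:
    """Resolve a recipe name, raising on unknown names."""
    if not is_quantize_recipe_supported(name):
        raise ValueError(f"Unknown quantization recipe: {name}")
    return _RECIPES[name]


print(get_quantization_recipe("MXFP8BlockScaling"))  # MXFP8_1D_SCALING
```

A string-keyed API like this lets test files select recipes from a CLI flag such as --quantize-recipe without importing recipe classes directly.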

Confidence Score: 4/5

  • This PR is safe to merge; the only flagged issue is a stale error message string, not a functional defect.
  • The core logic additions (reorder helpers, sharding spec updates, scale-padding bypass for MXFP8+collective) are well-structured and match the described design. New tests cover all the advertised recipe+collective-op combinations with hardware-skip guards. The only non-trivial issue found is the assertion error message in _te_gemm that still claims only two modes are supported after MXFP8 support was added, which would mislead NVFP4 users. No functional correctness issues were found.
  • transformer_engine/jax/cpp_extensions/gemm.py — the NVFP4 assertion error message should be updated to list MXFP8 as a supported mode.

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/gemm.py Core GEMM primitive extended with MXFP8 + collective support: new _reorder_tpsp_leading/_reorder_dp_leading helpers, updated sharding specs to include scale inverses, and NVFP4 guard. One issue: the NVFP4 assertion error message still lists only two supported modes, omitting the newly added MXFP8.
transformer_engine/jax/quantize/helper.py Clean additions of is_quantize_recipe_supported() and get_quantization_recipe() with name-to-enum/class maps; exported in __all__ and well-documented.
examples/jax/collective_gemm/common.py Imports reorganized, FP8 tolerances added to dtype_tols(), new get_tolerance_dtype() helper introduced, dead PARAMS_KEY constant and assert_allclose_print_index() removed, --fp8-recipe renamed to --quantize-recipe with expanded choices.
examples/jax/collective_gemm/test_gemm.py Duplicate _get_dp_and_tp_sizes removed, quantizer_set wired through _jitted_cgemm, and new test methods added for DelayedScaling/CurrentScaling/MXFP8 + AllGather/ReduceScatter.
examples/jax/collective_gemm/test_dense_grad.py FP8/MXFP8 test variants added with graceful hardware-skip guards; quantizer_set propagated through gradient helpers; parse_args([]) correctly replaced with parse_args() to enable CLI usage.
examples/jax/collective_gemm/test_layernorm_mlp_grad.py FP8/MXFP8 test variants added; quantizer_sets propagated to MLP. A single quantizer_set is reused for both MLP layers — acceptable for functional correctness tests, since the same object is used identically in both the reference and collective-op paths.
examples/jax/collective_gemm/run_test_cgemm.sh Migrated from file-level to individual test-case granularity; new FP8/MXFP8 cases added. Minor: the rm line at the end of the loop body has 1-space indentation instead of the 2-space style used elsewhere.
transformer_engine/jax/csrc/extensions/gemm.cpp Whitespace-only change (blank lines added for readability around conditionals); no functional change.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[tex.gemm / dense / layernorm_mlp called\nwith quantizer_set + collective_op] --> B{scaling_mode?}
    B -->|NO_SCALING / BF16| C[Pass through — no quantization]
    B -->|DELAYED / CURRENT\nTENSOR_SCALING| D[Tensor scales — no padding,\nno block reorder]
    B -->|MXFP8_1D_SCALING| E{collective_op?}
    B -->|NVFP4_1D_SCALING| F{collective_op?}

    E -->|NONE / outer| G[Apply padding to scale_inv\nSwizzle scale_inv]
    E -->|ALL_GATHER or\nREDUCE_SCATTER| H[Assert dims % 128 == 0\nSkip padding\nSwizzle scale_inv]

    H --> I{collective_op}
    I -->|REDUCE_SCATTER| J[_reorder_tpsp_leading LHS\n_reorder_tpsp_leading lhs_scale_inv]
    I -->|ALL_GATHER| K[_reorder_tpsp_leading lhs_scale_inv]

    J --> L[GemmPrimitive.inner_primitive.bind]
    K --> L
    G --> L

    L --> M{collective_op}
    M -->|ALL_GATHER| N[_reorder_dp_leading output]
    M -->|REDUCE_SCATTER / NONE| O[Return output as-is]
    N --> P[Return output]
    O --> P

    F -->|NONE| Q[NVFP4 GEMM — no collective]
    F -->|not NONE| R[AssertionError:\nNVFP4 + collective not supported]

    D --> S{collective_op}
    S -->|ALL_GATHER or RS| T[Scale sharding: lhs_scale_specs\nwith sequence_dim=None for AG]
    S -->|NONE| U[Scale sharding: none_sharding]

Last reviewed commit: 3edacb6
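The scaling-mode dispatch shown in the flowchart can be sketched as plain control flow. This is a minimal sketch under the assumption of string-valued modes and ops; the real code branches on ScalingMode/CollectiveOp enums inside gemm.py.

```python
# Illustrative dispatch mirroring the flowchart; strings stand in for
# the ScalingMode / CollectiveOp enums used in the actual implementation.
def plan_cgemm(scaling_mode: str, collective_op: str) -> str:
    if scaling_mode in ("DELAYED_TENSOR_SCALING", "CURRENT_TENSOR_SCALING"):
        # Tensor scales: no padding, no block reorder needed.
        return "tensor scales: no padding, no block reorder"
    if scaling_mode == "MXFP8_1D_SCALING":
        if collective_op in ("ALL_GATHER", "REDUCE_SCATTER"):
            # Collective path: assert alignment instead of padding.
            return "assert dims % 128 == 0, skip padding, swizzle scale_inv"
        return "pad scale_inv, swizzle scale_inv"
    if scaling_mode == "NVFP4_1D_SCALING" and collective_op != "NONE":
        raise AssertionError("NVFP4 + collective GEMM not supported")
    return "pass through, no quantization"
```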

Comment on lines +1253 to +1256
assert not scaling_mode.is_nvfp4_scaling, (
f"Collective GEMM is not yet supported with {scaling_mode} quantization. "
"Only DELAYED_TENSOR_SCALING and CURRENT_TENSOR_SCALING are supported."
)
Contributor


Outdated error message — MXFP8 is now supported

The assertion message was written before MXFP8 support was added. Now that this PR introduces MXFP8 + Collective GEMM, the error message inaccurately tells users that only DELAYED_TENSOR_SCALING and CURRENT_TENSOR_SCALING are valid options — MXFP8 should be included.

Suggested change
assert not scaling_mode.is_nvfp4_scaling, (
f"Collective GEMM is not yet supported with {scaling_mode} quantization. "
"Only DELAYED_TENSOR_SCALING and CURRENT_TENSOR_SCALING are supported."
)
if not collective_op.is_none:
assert not scaling_mode.is_nvfp4_scaling, (
f"Collective GEMM is not yet supported with {scaling_mode} quantization. "
"Only DELAYED_TENSOR_SCALING, CURRENT_TENSOR_SCALING, and MXFP8_1D_SCALING are supported."
)

# Remove the log files after processing them
wait
rm ${TEST_FILE}_gpu_*.log
rm ${TEST_NAME}_gpu_*.log
Contributor


Inconsistent indentation on rm command

This line uses 1-space indentation while the rest of the for loop body consistently uses 2 spaces (e.g., wait, if grep, echo). This is a minor formatting inconsistency introduced in the refactor.

Suggested change
rm ${TEST_NAME}_gpu_*.log
rm ${TEST_NAME}_gpu_*.log

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

