[JAX] Collective GEMM with FP8 and MXFP8 support #2740
phu0ngng wants to merge 15 commits into NVIDIA:main
Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci JAX L1
Greptile Summary

This PR extends JAX Collective GEMM (CGEMM) to support the FP8 quantization recipes DelayedScaling, CurrentScaling, and MXFP8, and lays the groundwork for NVFP4 (currently commented out). The implementation adds block-scale reordering helpers for MXFP8.
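For intuition on why MXFP8 needs block-scale handling at all: MXFP8 1D scaling stores one shared scale per 32-element block along the scaled axis, so the scale_inv tensor is a factor-of-32 reduction of the data shape. A minimal sketch, assuming the standard MX block size of 32 (the helper name and shape convention are illustrative, not the Transformer Engine API):

```python
MX_BLOCK = 32  # MXFP8 shares one scale per 32 contiguous values (MX block size)

def mxfp8_scale_inv_shape(data_shape, block=MX_BLOCK):
    """Shape of the per-block scale_inv tensor for a 1D-scaled operand.

    Illustrative helper, not the actual TE implementation: the last
    (scaled) dimension is split into blocks of `block` values, and each
    block shares a single inverse scale.
    """
    *leading, last = data_shape
    assert last % block == 0, "scaled dim must be divisible by the block size"
    return (*leading, last // block)

print(mxfp8_scale_inv_shape((256, 128)))  # (256, 4)
```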
Confidence Score: 4/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[tex.gemm / dense / layernorm_mlp called\nwith quantizer_set + collective_op] --> B{scaling_mode?}
    B -->|NO_SCALING / BF16| C[Pass through — no quantization]
    B -->|DELAYED / CURRENT\nTENSOR_SCALING| D[Tensor scales — no padding,\nno block reorder]
    B -->|MXFP8_1D_SCALING| E{collective_op?}
    B -->|NVFP4_1D_SCALING| F{collective_op?}
    E -->|NONE / outer| G[Apply padding to scale_inv\nSwizzle scale_inv]
    E -->|ALL_GATHER or\nREDUCE_SCATTER| H[Assert dims % 128 == 0\nSkip padding\nSwizzle scale_inv]
    H --> I{collective_op}
    I -->|REDUCE_SCATTER| J[_reorder_tpsp_leading LHS\n_reorder_tpsp_leading lhs_scale_inv]
    I -->|ALL_GATHER| K[_reorder_tpsp_leading lhs_scale_inv]
    J --> L[GemmPrimitive.inner_primitive.bind]
    K --> L
    G --> L
    L --> M{collective_op}
    M -->|ALL_GATHER| N[_reorder_dp_leading output]
    M -->|REDUCE_SCATTER / NONE| O[Return output as-is]
    N --> P[Return output]
    O --> P
    F -->|NONE| Q[NVFP4 GEMM — no collective]
    F -->|not NONE| R[AssertionError:\nNVFP4 + collective not supported]
    D --> S{collective_op}
    S -->|ALL_GATHER or RS| T[Scale sharding: lhs_scale_specs\nwith sequence_dim=None for AG]
    S -->|NONE| U[Scale sharding: none_sharding]
```
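The branching above can be summarized as a plain-Python sketch. This is an illustrative condensation of the flowchart, not the real code: the actual implementation uses TE's ScalingMode and CollectiveOp enums rather than strings, and performs the listed actions instead of returning descriptions.

```python
def cgemm_preprocessing(scaling_mode: str, collective_op: str) -> str:
    """Sketch of the dispatch in the flowchart (names are illustrative)."""
    if scaling_mode in ("NO_SCALING", "BF16"):
        return "pass through, no quantization"
    if scaling_mode in ("DELAYED_TENSOR_SCALING", "CURRENT_TENSOR_SCALING"):
        return "tensor scales: no padding, no block reorder"
    if scaling_mode == "MXFP8_1D_SCALING":
        if collective_op in ("ALL_GATHER", "REDUCE_SCATTER"):
            # collective path: shapes must already be block-aligned
            return "assert dims % 128 == 0, skip padding, swizzle scale_inv"
        return "pad scale_inv, swizzle scale_inv"
    if scaling_mode == "NVFP4_1D_SCALING":
        if collective_op != "NONE":
            raise AssertionError("NVFP4 + collective not supported")
        return "NVFP4 GEMM, no collective"
    raise ValueError(f"unknown scaling mode: {scaling_mode}")
```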
Last reviewed commit: 3edacb6
```python
assert not scaling_mode.is_nvfp4_scaling, (
    f"Collective GEMM is not yet supported with {scaling_mode} quantization. "
    "Only DELAYED_TENSOR_SCALING and CURRENT_TENSOR_SCALING are supported."
)
```
Outdated error message — MXFP8 is now supported
The assertion message was written before MXFP8 support was added. Now that this PR introduces MXFP8 + Collective GEMM, the error message inaccurately tells users that only DELAYED_TENSOR_SCALING and CURRENT_TENSOR_SCALING are valid options — MXFP8 should be included.
```suggestion
if not collective_op.is_none:
    assert not scaling_mode.is_nvfp4_scaling, (
        f"Collective GEMM is not yet supported with {scaling_mode} quantization. "
        "Only DELAYED_TENSOR_SCALING, CURRENT_TENSOR_SCALING, and MXFP8_1D_SCALING are supported."
    )
```
```shell
  # Remove the log files after processing them
  wait
  rm ${TEST_FILE}_gpu_*.log
 rm ${TEST_NAME}_gpu_*.log
```
Inconsistent indentation on rm command
This line uses 1-space indentation while the rest of the for loop body consistently uses 2 spaces (e.g., wait, if grep, echo). This is a minor formatting inconsistency introduced in the refactor.
```suggestion
  rm ${TEST_NAME}_gpu_*.log
```
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Description
This PR extends the JAX Collective GEMM support with DelayedScalingFP8, CurrentScalingFP8, and MXFP8.
Unit tests for those quantization recipes are added. In addition, this PR cleans up the test infrastructure in the collective GEMM tests.
Note that Collective GEMM + MXFP8 requires all dimensions of the GEMM operands to be divisible by 128.
Note that in the case of CGEMM + MXFP8 + AllGather, the block scales are still all-gathered on the critical path, whereas the quantized data is gathered collectively, overlapping with the computation.
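The divisibility requirement above could be validated up front with a shape check along these lines. This is a hypothetical helper for illustration, not part of this PR:

```python
def check_cgemm_mxfp8_shapes(lhs_shape, rhs_shape, divisor=128):
    """Hypothetical validation: CGEMM + MXFP8 requires every dimension
    of both GEMM operands to be divisible by 128."""
    for name, shape in (("lhs", lhs_shape), ("rhs", rhs_shape)):
        for dim in shape:
            if dim % divisor != 0:
                raise ValueError(
                    f"{name} dimension {dim} is not divisible by {divisor}; "
                    "CGEMM + MXFP8 requires aligned operand shapes"
                )

check_cgemm_mxfp8_shapes((256, 128), (128, 512))  # passes silently
```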
Type of change
Checklist: