[OMNIML-5072] Triton fakequant adapter for N-quantizer per-expert path#1717
Draft
hychiang-git wants to merge 11 commits into
Draft
[OMNIML-5072] Triton fakequant adapter for N-quantizer per-expert path#1717hychiang-git wants to merge 11 commits into
hychiang-git wants to merge 11 commits into
Conversation
…nfra fixes
modelopt/torch/quantization/plugins/transformer_engine.py:
MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1 opts into per-gemm
weight_quantizer_0..N-1 inside _QuantTEGroupedLinear (deepcopied from
the shared weight_quantizer). Lets TEGroupedMLP recover per-expert
amax granularity, matching SequentialMLP's default behavior.
modelopt/torch/distill/plugins/megatron.py:
LogitsKLLoss.forward prints student/teacher logit stats (mean/std/
min/max/shape) on rank 0 each call. Diagnostic for the QAD loss-spike
investigation — confirms which spec produces which logits without
changing the KL math.
tests/gpu_megatron/torch/quantization/plugins/test_megatron.py:
New test_te_grouped_vs_sequential_default_amax + ..._default_loss
cover the structural amax asymmetry between TEGroupedMLP and
SequentialMLP (TEGrouped per-linear amax = max-over-Sequential-experts
amax) and a finiteness sanity check on the resulting quant error.
tools/launcher/common/service_utils.sh:
- Fall back to SLURM_PROCID / SLURM_LOCALID when PMIX_*/OMPI_* are
unset, so `[[ "$mpi_local_rank" -eq 0 ]]` doesn't silently pass on
every rank under plain srun.
- util_install_extra_dep: per-node marker so concurrent ranks wait
for rank 0 to finish installing (concurrent pip on a shared FS
leaves a broken state); also installs nvidia-resiliency-ext.
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
- transformer_engine.py: dedup `import copy`/`import os` left over from the
rebase, sort the four imports alphabetically.
- transformer_engine.py: comment near the per-expert weight_quantizer setup
explaining that base modelopt_post_restore won't re-calibrate the
weight_quantizer_{i} modules, so save/restore is only safe when TP/EP is
unchanged. Per-channel _amax shape depends on the TP-sliced output dim.
- service_utils.sh: drop the duplicated mpi_rank / mpi_local_rank
re-assignments — main already carries the SLURM fallback, the extra two
lines were leftover rebase noise.
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Copies the single-launch tensor-of-pointers fake-quant kernel module from hungyuehc/omniml-4998-umbrella (kernel landed in 0bf4838, libdevice.rint rounding refined in 1080e68). Kernel file is unchanged. Wires the new module into modelopt.torch.kernels.quantization.gemm via the existing IS_AVAILABLE/triton-import block in __init__.py. The transformer_engine.py adapter that calls this kernel from the N-modules per-expert path follows in the next commit. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…for N-modules per-expert path Wires grouped_axis0_fakequant into _QuantTEGroupedLinear's forward when _per_expert_weight_quantizer is on. Adds: - _GroupedAxis0FakeQuantFn (torch.autograd.Function): single forward call for N expert weights; backward honors pass_through_bwd=True (identity) and dispatches to the Triton bwd kernel when False. - _gather_per_expert_amax: stacks N weight_quantizer_i._amax scalars into a [N] fp32 vector matching the kernel's amax-input contract. - _can_use_triton_per_expert_path: soft-gate on IS_AVAILABLE, all per-expert quantizers being TensorQuantizer with _amax set, and not currently calibrating (q._if_calib). - te_grouped_quantized_linear_fn now branches: Triton path when gate passes; original per-quantizer cuda_ext loop otherwise. Replaces N cuda_ext kernel launches with 1 Triton launch on the forward hot path. No behavior change when the env var opt-in is off. Untested at runtime yet; AC2 parity test (Ultra production shape, N=32, pass_through_bwd=True) is the next step. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds tests/gpu/torch/quantization/plugins/test_te_grouped_triton_parity.py mirroring nmm-sandbox/studies/omniml-5064/microbench/parity_a_vs_btriton.py as a pytest in the modelopt tests/ surface. Two checks at each parameterized shape (N=4/8/32): - Forward parity within 1 ULP. Known rounding-mode mismatch floor between Triton's libdevice.rint and cuda_ext's banker's rounding at some bf16 boundary values. - Backward parity bit-exact under pass_through_bwd=True. Both paths must return grad_out unchanged regardless of forward kernel. Plus a slow-marked Ultra production shape variant (N=32, [5120, 8192] bf16) for full-scale validation. Marked slow because the unquantized + quantized + gradient copies of 32 expert weights at that shape use ~5 GB of GPU memory; CI default-suite stays on the smaller parameterized cases. Test not yet run — requires GPU + container with modelopt installed. Expected to PASS per the matching standalone parity_a_vs_btriton.py output on B's path after libdevice.rint refinement (1080e68): forward parity within 1 ULP, backward bit-exact under pass_through_bwd=True. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…les per-expert path
Adds dist-checkpoint support to A's N-submodule per-expert weight quantization
path on _QuantMegatronTEGroupedLinear. Mirrors B's gather-once-cache pattern
(hungyuehc/omniml-4998-umbrella) adapted for A's storage layout — N separate
weight_quantizer_i._amax scalars across N submodules instead of one [N,1,1]
buffer on a single quantizer.
Methods added:
- _ep_group: returns the EP group if initialized and world_size > 1.
- _gather_global_per_expert_amax_n_modules: stacks N local scalar amaxes from
the submodules, all-gathers across EP, returns [N_global]. None when the
layer is not in per-expert mode.
- sharded_state_dict: caches the gathered global vector before delegating to
super so the EP collective completes BEFORE Megatron's dist-checkpoint save
fires default-PG ALLGATHER metadata exchanges (interleaving EP + default-PG
collectives deadlocks NCCL — codified in
[[feedback-no-custom-collectives-in-dist-ckpt-save]]).
Methods replaced:
- _process_quantizer_amax: emits the cached global [N_global] vector under
every weight_quantizer_i._amax key in per-expert mode. Suboptimal disk
usage (N copies of same vector per layer) but mirrors B's pattern and
avoids surgery into the base-class state-dict iteration.
- _load_from_state_dict: preserves the existing _extra_state{i} filter and
adds the per-expert narrow — pulls element (ep_rank * N_local + i) out of
each saved [N_global] vector for the i-th local submodule. Falls through
unchanged when v.numel() != global_size (legacy / EP=1 save format).
Validated via the AC4 parity test (next commit) at TP=2, EP=2 across
FP8_DEFAULT_CFG and NVFP4_DEFAULT_CFG.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…dict round-trip Adds test_te_grouped_n_modules_sharded_state_dict parameterized over FP8_DEFAULT_CFG and NVFP4_DEFAULT_CFG. Builds a TEGroupedMLP model at TP=2 EP=2 num_moe_experts=4, quantizes with the MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1 env-var path, saves dist-ckpt, restores into a fresh meta-device model, asserts equivalence via the existing sharded_state_dict_test_helper. Mirrors the layout of B's test_te_grouped_per_expert_sharded_state_dict (hidden_size=256, dist_workers fixture) but triggers A's env-var-gated per-expert path instead of B's axis=0 quant_cfg knob. Also cherry-picks the OMNIML-5030 sequence_parallel fix that A's branch predates: get_mcore_gpt_model in tests/_test_utils/torch/megatron/models.py gains a sequence_parallel parameter that threads through TransformerConfig, and the non-hybrid call site in _gpt_model_provider passes sequence_parallel=(tp_size > 1). Without this, Megatron-Core ValueError-s during MoE + TP > 1 model construction. The hybrid path on A's branch already had this fix. Validated on aws-cmh slurm — 2 passed, 92.99s wall (job 537620). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…d N-loop overhead
The Triton dispatch in te_grouped_quantized_linear_fn calls
_gather_per_expert_amax on every forward. The gather walks N submodules
(O(N) Python overhead) and stacks N scalar _amax buffers into a [N] fp32
vector. The result is invariant across forwards (per-expert _amax does
not change outside calibration, and the gate already blocks the Triton
path when q._if_calib is True on any quantizer).
Caching the gathered tensor lazily on first call eliminates the per-forward
overhead. Invalidation hook _invalidate_per_expert_amax_cache is called from
modelopt_post_restore (where dist-ckpt reload may have changed _amax).
Measured impact on OMNIML-5064 microbench (Nemotron Nano EP=4, N=32):
fwd_us: 1918 (no-cache) -> 1244 (cached) (35% drop)
step_us: 3444 (no-cache) -> 2785 (cached) (19% drop)
vs Btriton5 (B's path with same Triton kernel):
ATriton-cached 1244 vs Btriton5 1208 -> effectively tied
ATriton-cached step 2785 vs Btriton5 step 2815 -> ATriton edges B
Without the cache the gap to Btriton5 grows with N (1.59x at N=32, 2.18x
at N=128, observed in the no-cache nano_ep4 + super_ep4 runs). With the
cache, the gap closes to within run-to-run noise.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Contributor
|
Important Review skippedDraft detected. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Enterprise Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
The doc was a placeholder for B-triton's pre-validation tracking (commit 0bf4838 on PR #1671); validation has since landed and the file is no longer load-bearing. The kernel module references it from a docstring line that is stale and will be cleaned up by PR #1671's review pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Type of change: new feature
Implements OMNIML-5072, built on top of PR #1550 (the N-quantizer per-expert foundation, still WIP). Three additions:
grouped_axis0_fakequantkernel from PR #1671 (the One-Vec-quanitzer path) into_QuantTEGroupedLinear.te_grouped_quantized_linear_fnwhen_per_expert_weight_quantizer == True. Soft-gated behind_triton_kernels.IS_AVAILABLEandq._if_calib; falls back to the existingcuda_extper-quantizer loop when the gate is False._gather_per_expert_amaxhelper eliminates per-forward O(N) Python overhead (walks the N submodules from PR WIP Support per expert amax in TEGroupedMLP #1550 and stacks the N scalar_amaxbuffers; lazily cached, invalidated frommodelopt_post_restore).sharded_state_dictsave + EP-aware load on_QuantMegatronTEGroupedLinear's N-quantizer case — gather-once-cache pattern adapted to N scalar_amaxbuffers across the EP group, so the dist-ckpt round-trip that PR [OMNIML-4998] Per-expert weight quantization on TEGroupedLinear #1671 ships for One-Vec-quanitzer also works on the N-quantizer foundation from PR WIP Support per expert amax in TEGroupedMLP #1550.This PR is stacked on top of PR #1550 (
jennifchen/te_per_expert, still WIP). The diff againstmainincludes PR #1550's commits underneath; the OMNIML-5072-specific work is the top 6 commits (fd77b53d8..51b4c9226). Once PR #1550 lands, rebase ontomainto shrink the review surface.Usage
Testing
GPU-validated on aws-cmh (B300, nemo:26.02 / nemo:25.11 containers):
pass_through_bwd=True. Test attests/gpu/torch/quantization/plugins/test_te_grouped_triton_parity.py.FP8_DEFAULT_CFGandNVFP4_DEFAULT_CFG. Test attests/gpu_megatron/torch/quantization/plugins/test_megatron.py::test_te_grouped_n_modules_sharded_state_dict.Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.)._triton_kernels.IS_AVAILABLEis False or during calibration (q._if_calib), the originalcuda_extper-quantizer loop runs unchanged.CONTRIBUTING.md: ✅ — Reused thegrouped_axis0_fakequantkernel module from PR [OMNIML-4998] Per-expert weight quantization on TEGroupedLinear #1671 (51b4c9226reuses it via the new_GroupedAxis0FakeQuantFnautograd adapter). No new PIP dependencies.sharded_state_dictTP=2/EP=2 test./claude reviewafter rebase ontomain.Additional Information
Related work:
🤖 Generated with Claude Code