Skip to content

[OMNIML-5072] Triton fakequant adapter for N-quantizer per-expert path#1717

Draft
hychiang-git wants to merge 11 commits into
mainfrom
hungyuehc/omniml-5072
Draft

[OMNIML-5072] Triton fakequant adapter for N-quantizer per-expert path#1717
hychiang-git wants to merge 11 commits into
mainfrom
hungyuehc/omniml-5072

Conversation

@hychiang-git

Copy link
Copy Markdown
Contributor

What does this PR do?

Type of change: new feature

Implements OMNIML-5072, built on top of PR #1550 (the N-quantizer per-expert foundation, still WIP). Three additions:

  1. Triton fakequant dispatch for the N-quantizer per-expert path. Wires the single-launch grouped_axis0_fakequant kernel from PR #1671 (the One-Vec-quanitzer path) into _QuantTEGroupedLinear.te_grouped_quantized_linear_fn when _per_expert_weight_quantizer == True. Soft-gated behind _triton_kernels.IS_AVAILABLE and q._if_calib; falls back to the existing cuda_ext per-quantizer loop when the gate is False.
  2. Cached _gather_per_expert_amax helper eliminates per-forward O(N) Python overhead (walks the N submodules from PR WIP Support per expert amax in TEGroupedMLP #1550 and stacks the N scalar _amax buffers; lazily cached, invalidated from modelopt_post_restore).
  3. sharded_state_dict save + EP-aware load on _QuantMegatronTEGroupedLinear's N-quantizer case — gather-once-cache pattern adapted to N scalar _amax buffers across the EP group, so the dist-ckpt round-trip that PR [OMNIML-4998] Per-expert weight quantization on TEGroupedLinear #1671 ships for One-Vec-quanitzer also works on the N-quantizer foundation from PR WIP Support per expert amax in TEGroupedMLP #1550.

This PR is stacked on top of PR #1550 (jennifchen/te_per_expert, still WIP). The diff against main includes PR #1550's commits underneath; the OMNIML-5072-specific work is the top 6 commits (fd77b53d8..51b4c9226). Once PR #1550 lands, rebase onto main to shrink the review surface.

Usage

import os

# Enable the N-quantizer per-expert path on TEGroupedMLP. With this PR loaded,
# the Triton kernel dispatch activates automatically when triton is available.
os.environ["MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER"] = "1"

import modelopt.torch.quantization as mtq
# ... build a Megatron model with TEGroupedMLP ...
mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_fn)
# Each TEGroupedLinear now has N weight_quantizer_i submodules;
# forward uses Triton `grouped_axis0_fakequant` when available, otherwise
# falls back to cuda_ext per-quantizer.

Testing

GPU-validated on aws-cmh (B300, nemo:26.02 / nemo:25.11 containers):

  • Parity test (4 pytest cases, 6.14s): N-quantizer-Triton vs N-quantizer-cuda_ext on Ultra production shape (N=32, [5120, 8192] bf16) — fwd within 1-ULP floor, bwd bit-exact under pass_through_bwd=True. Test at tests/gpu/torch/quantization/plugins/test_te_grouped_triton_parity.py.
  • Microbench (4 cells: Nano / Super / Ultra at EP=4 and EP=8): N-quantizer-Triton ≈ One-Vec-quanitzer + Triton (PR [OMNIML-4998] Per-expert weight quantization on TEGroupedLinear #1671) within ~10%; the N-quantizer vs One-Vec-quanitzer topology has no measurable perf effect once both share the Triton kernel. Full matrix on OMNIML-5064.
  • Dist-ckpt round-trip (2 pytest cases, 92.99s) at TP=2/EP=2 for both FP8_DEFAULT_CFG and NVFP4_DEFAULT_CFG. Test at tests/gpu_megatron/torch/quantization/plugins/test_megatron.py::test_te_grouped_n_modules_sharded_state_dict.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ — Triton path is soft-gated; when _triton_kernels.IS_AVAILABLE is False or during calibration (q._if_calib), the original cuda_ext per-quantizer loop runs unchanged.
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ — Reused the grouped_axis0_fakequant kernel module from PR [OMNIML-4998] Per-expert weight quantization on TEGroupedLinear #1671 (51b4c9226 reuses it via the new _GroupedAxis0FakeQuantFn autograd adapter). No new PIP dependencies.
  • Did you write any new necessary tests?: ✅ — Parity test + N-quantizer sharded_state_dict TP=2/EP=2 test.
  • Did you update Changelog?: ❌ — Stacked on top of WIP PR WIP Support per expert amax in TEGroupedMLP #1550; changelog entry deferred until both are merge-ready (will land with the rebase).
  • Did you get Claude approval on this PR?: ❌ — Pending; will run /claude review after rebase onto main.

Additional Information

Related work:

  • OMNIML-5072 — this ticket.
  • OMNIML-5064 — N-quantizer vs One-Vec-quanitzer comparison study; full microbench matrix here.
  • PR #1550 — N-quantizer foundation (WIP); this PR is stacked on top.
  • PR #1671 — One-Vec-quanitzer + Triton kernel (the kernel this PR reuses).

🤖 Generated with Claude Code

jenchen13 and others added 10 commits May 27, 2026 15:41
…nfra fixes

modelopt/torch/quantization/plugins/transformer_engine.py:
  MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1 opts into per-gemm
  weight_quantizer_0..N-1 inside _QuantTEGroupedLinear (deepcopied from
  the shared weight_quantizer). Lets TEGroupedMLP recover per-expert
  amax granularity, matching SequentialMLP's default behavior.

modelopt/torch/distill/plugins/megatron.py:
  LogitsKLLoss.forward prints student/teacher logit stats (mean/std/
  min/max/shape) on rank 0 each call. Diagnostic for the QAD loss-spike
  investigation — confirms which spec produces which logits without
  changing the KL math.

tests/gpu_megatron/torch/quantization/plugins/test_megatron.py:
  New test_te_grouped_vs_sequential_default_amax + ..._default_loss
  cover the structural amax asymmetry between TEGroupedMLP and
  SequentialMLP (TEGrouped per-linear amax = max-over-Sequential-experts
  amax) and a finiteness sanity check on the resulting quant error.

tools/launcher/common/service_utils.sh:
  - Fall back to SLURM_PROCID / SLURM_LOCALID when PMIX_*/OMPI_* are
    unset, so `[[ "$mpi_local_rank" -eq 0 ]]` doesn't silently pass on
    every rank under plain srun.
  - util_install_extra_dep: per-node marker so concurrent ranks wait
    for rank 0 to finish installing (concurrent pip on a shared FS
    leaves a broken state); also installs nvidia-resiliency-ext.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
- transformer_engine.py: dedup `import copy`/`import os` left over from the
  rebase, sort the four imports alphabetically.
- transformer_engine.py: comment near the per-expert weight_quantizer setup
  explaining that base modelopt_post_restore won't re-calibrate the
  weight_quantizer_{i} modules, so save/restore is only safe when TP/EP is
  unchanged. Per-channel _amax shape depends on the TP-sliced output dim.
- service_utils.sh: drop the duplicated mpi_rank / mpi_local_rank
  re-assignments — main already carries the SLURM fallback, the extra two
  lines were leftover rebase noise.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Copies the single-launch tensor-of-pointers fake-quant kernel module from
hungyuehc/omniml-4998-umbrella (kernel landed in 0bf4838, libdevice.rint
rounding refined in 1080e68). Kernel file is unchanged.

Wires the new module into modelopt.torch.kernels.quantization.gemm via the
existing IS_AVAILABLE/triton-import block in __init__.py.

The transformer_engine.py adapter that calls this kernel from the N-modules
per-expert path follows in the next commit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…for N-modules per-expert path

Wires grouped_axis0_fakequant into _QuantTEGroupedLinear's forward when
_per_expert_weight_quantizer is on. Adds:

- _GroupedAxis0FakeQuantFn (torch.autograd.Function): single forward call
  for N expert weights; backward honors pass_through_bwd=True (identity)
  and dispatches to the Triton bwd kernel when False.
- _gather_per_expert_amax: stacks N weight_quantizer_i._amax scalars into
  a [N] fp32 vector matching the kernel's amax-input contract.
- _can_use_triton_per_expert_path: soft-gate on IS_AVAILABLE, all per-expert
  quantizers being TensorQuantizer with _amax set, and not currently
  calibrating (q._if_calib).
- te_grouped_quantized_linear_fn now branches: Triton path when gate passes;
  original per-quantizer cuda_ext loop otherwise.

Replaces N cuda_ext kernel launches with 1 Triton launch on the forward
hot path. No behavior change when the env var opt-in is off.

Untested at runtime yet; AC2 parity test (Ultra production shape, N=32,
pass_through_bwd=True) is the next step.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds tests/gpu/torch/quantization/plugins/test_te_grouped_triton_parity.py
mirroring nmm-sandbox/studies/omniml-5064/microbench/parity_a_vs_btriton.py
as a pytest in the modelopt tests/ surface.

Two checks at each parameterized shape (N=4/8/32):
- Forward parity within 1 ULP. Known rounding-mode mismatch floor between
  Triton's libdevice.rint and cuda_ext's banker's rounding at some bf16
  boundary values.
- Backward parity bit-exact under pass_through_bwd=True. Both paths must
  return grad_out unchanged regardless of forward kernel.

Plus a slow-marked Ultra production shape variant (N=32, [5120, 8192] bf16)
for full-scale validation. Marked slow because the unquantized + quantized
+ gradient copies of 32 expert weights at that shape use ~5 GB of GPU
memory; CI default-suite stays on the smaller parameterized cases.

Test not yet run — requires GPU + container with modelopt installed.
Expected to PASS per the matching standalone parity_a_vs_btriton.py output
on B's path after libdevice.rint refinement (1080e68): forward parity
within 1 ULP, backward bit-exact under pass_through_bwd=True.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…les per-expert path

Adds dist-checkpoint support to A's N-submodule per-expert weight quantization
path on _QuantMegatronTEGroupedLinear. Mirrors B's gather-once-cache pattern
(hungyuehc/omniml-4998-umbrella) adapted for A's storage layout — N separate
weight_quantizer_i._amax scalars across N submodules instead of one [N,1,1]
buffer on a single quantizer.

Methods added:
- _ep_group: returns the EP group if initialized and world_size > 1.
- _gather_global_per_expert_amax_n_modules: stacks N local scalar amaxes from
  the submodules, all-gathers across EP, returns [N_global]. None when the
  layer is not in per-expert mode.
- sharded_state_dict: caches the gathered global vector before delegating to
  super so the EP collective completes BEFORE Megatron's dist-checkpoint save
  fires default-PG ALLGATHER metadata exchanges (interleaving EP + default-PG
  collectives deadlocks NCCL — codified in
  [[feedback-no-custom-collectives-in-dist-ckpt-save]]).

Methods replaced:
- _process_quantizer_amax: emits the cached global [N_global] vector under
  every weight_quantizer_i._amax key in per-expert mode. Suboptimal disk
  usage (N copies of same vector per layer) but mirrors B's pattern and
  avoids surgery into the base-class state-dict iteration.
- _load_from_state_dict: preserves the existing _extra_state{i} filter and
  adds the per-expert narrow — pulls element (ep_rank * N_local + i) out of
  each saved [N_global] vector for the i-th local submodule. Falls through
  unchanged when v.numel() != global_size (legacy / EP=1 save format).

Validated via the AC4 parity test (next commit) at TP=2, EP=2 across
FP8_DEFAULT_CFG and NVFP4_DEFAULT_CFG.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…dict round-trip

Adds test_te_grouped_n_modules_sharded_state_dict parameterized over
FP8_DEFAULT_CFG and NVFP4_DEFAULT_CFG. Builds a TEGroupedMLP model at
TP=2 EP=2 num_moe_experts=4, quantizes with the
MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1 env-var path, saves dist-ckpt,
restores into a fresh meta-device model, asserts equivalence via the
existing sharded_state_dict_test_helper.

Mirrors the layout of B's test_te_grouped_per_expert_sharded_state_dict
(hidden_size=256, dist_workers fixture) but triggers A's env-var-gated
per-expert path instead of B's axis=0 quant_cfg knob.

Also cherry-picks the OMNIML-5030 sequence_parallel fix that A's branch
predates: get_mcore_gpt_model in tests/_test_utils/torch/megatron/models.py
gains a sequence_parallel parameter that threads through TransformerConfig,
and the non-hybrid call site in _gpt_model_provider passes
sequence_parallel=(tp_size > 1). Without this, Megatron-Core ValueError-s
during MoE + TP > 1 model construction. The hybrid path on A's branch
already had this fix.

Validated on aws-cmh slurm — 2 passed, 92.99s wall (job 537620).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…d N-loop overhead

The Triton dispatch in te_grouped_quantized_linear_fn calls
_gather_per_expert_amax on every forward. The gather walks N submodules
(O(N) Python overhead) and stacks N scalar _amax buffers into a [N] fp32
vector. The result is invariant across forwards (per-expert _amax does
not change outside calibration, and the gate already blocks the Triton
path when q._if_calib is True on any quantizer).

Caching the gathered tensor lazily on first call eliminates the per-forward
overhead. Invalidation hook _invalidate_per_expert_amax_cache is called from
modelopt_post_restore (where dist-ckpt reload may have changed _amax).

Measured impact on OMNIML-5064 microbench (Nemotron Nano EP=4, N=32):

  fwd_us:   1918 (no-cache)  ->  1244 (cached)   (35% drop)
  step_us:  3444 (no-cache)  ->  2785 (cached)   (19% drop)
  vs Btriton5 (B's path with same Triton kernel):
            ATriton-cached 1244 vs Btriton5 1208  ->  effectively tied
            ATriton-cached step 2785 vs Btriton5 step 2815 -> ATriton edges B

Without the cache the gap to Btriton5 grows with N (1.59x at N=32, 2.18x
at N=128, observed in the no-cache nano_ep4 + super_ep4 runs). With the
cache, the gap closes to within run-to-run noise.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 14, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: c3a4143a-ff8a-466c-b50e-d8bfd23d293b

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch hungyuehc/omniml-5072

Comment @coderabbitai help to get the list of available commands and usage tips.

The doc was a placeholder for B-triton's pre-validation tracking (commit
0bf4838 on PR #1671); validation has since landed and the file is no
longer load-bearing. The kernel module references it from a docstring
line that is stale and will be cleaned up by PR #1671's review pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants