8 changes: 8 additions & 0 deletions CHANGELOG.rst
@@ -1,6 +1,14 @@
NVIDIA Model Optimizer Changelog (Linux)
========================================

0.43 (2026-03-xx)
^^^^^^^^^^^^^^^^^

**New Features**

- Users no longer need to manually register MoE modules to get expert calibration coverage in the PTQ workflow.
- ``hf_ptq.py`` now saves the quantization summary and the MoE expert token count table to the export directory.

0.42 (2026-02-xx)
^^^^^^^^^^^^^^^^^

8 changes: 7 additions & 1 deletion examples/llm_ptq/hf_ptq.py
@@ -53,6 +53,7 @@
export_hf_checkpoint,
export_tensorrt_llm_checkpoint,
get_model_type,
save_expert_token_count_table,
)
from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
@@ -726,7 +727,12 @@ def post_quantize(
"""

if args.verbose:
mtq.print_quant_summary(full_model)
try:
mtq.print_quant_summary(full_model, args.export_path)
save_expert_token_count_table(full_model, args.export_path)
except Exception as e:
print(f"Error saving quant summary: {e}")
Comment on lines +730 to +734
Copilot AI Feb 19, 2026

The exception handler message says "Error saving quant summary" even if the failure came from saving the MoE HTML report. Consider making the message accurate (or logging which step failed) to avoid confusion when debugging export failures.

Suggested change (before):

    try:
        mtq.print_quant_summary(full_model, args.export_path)
        save_expert_token_count_table(full_model, args.export_path)
    except Exception as e:
        print(f"Error saving quant summary: {e}")

Suggested change (after):

    error_occurred = False
    try:
        mtq.print_quant_summary(full_model, args.export_path)
    except Exception as e:
        print(f"Error saving quant summary: {e}")
        error_occurred = True
    try:
        save_expert_token_count_table(full_model, args.export_path)
    except Exception as e:
        print(f"Error saving expert token count table: {e}")
        error_occurred = True
    if error_occurred:

print("Continuing with generation...")
Comment on lines +730 to +735
Copilot AI Feb 19, 2026

Both print_quant_summary(..., args.export_path) and save_expert_token_count_table(..., args.export_path) write fixed filenames into the export directory. Under multi-process runs (e.g., torchrun), every rank will attempt to write the same files, risking races/corruption. Consider guarding these calls to rank-0 only (using the project’s distributed helpers) or including the rank in the filename.

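One way to address this (a minimal sketch, assuming plain torch.distributed rather than the project's own distributed helpers; the helper name is illustrative):

    import torch.distributed as dist

    def _should_write_reports() -> bool:
        # Single-process runs have no initialized process group; otherwise only rank 0 writes.
        return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0

    if args.verbose and _should_write_reports():
        try:
            mtq.print_quant_summary(full_model, args.export_path)
            save_expert_token_count_table(full_model, args.export_path)
        except Exception as e:
            print(f"Error saving quant summary: {e}")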

# Run some samples
torch.cuda.empty_cache()
1 change: 1 addition & 0 deletions modelopt/torch/export/__init__.py
@@ -19,6 +19,7 @@
from .model_config import *
from .model_config_export import *
from .model_utils import *
from .moe_utils import *
from .plugins import *
from .transformer_engine import *
from .unified_export_hf import *
77 changes: 77 additions & 0 deletions modelopt/torch/export/moe_utils.py
@@ -0,0 +1,77 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for Mixture-of-Experts (MoE) model export."""

from pathlib import Path

import torch.nn as nn


def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None):
"""Collect expert_token_count from all quantized MoE layers and save as an HTML table.

The table has rows for each MoE layer and columns for each expert, with cell values
showing the number of tokens routed to that expert during calibration.

Args:
model: The model containing quantized MoE layers with ``expert_token_count`` attributes.
output_dir: Directory to save the HTML file. Defaults to current directory.
"""
rows = []
for name, module in model.named_modules():
if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0:
rows.append((name, module.expert_token_count))

if not rows:
return

num_experts = rows[0][1].shape[0]
assert all(r[1].shape[0] == num_experts for r in rows), (
"All MoE layers must have the same number of experts"
)
Comment on lines +42 to +44
Copilot AI Feb 19, 2026

Avoid using assert for runtime validation here: Python can strip asserts with -O, and this will hard-crash export for models with heterogeneous expert counts. Prefer raising a user-facing exception (e.g., ValueError) or handling differing expert counts by generating a ragged table / padding columns per-layer.

Suggested change (before):

    assert all(r[1].shape[0] == num_experts for r in rows), (
        "All MoE layers must have the same number of experts"
    )

Suggested change (after):

    if not all(r[1].shape[0] == num_experts for r in rows):
        layer_names = [r[0] for r in rows]
        expert_counts = [r[1].shape[0] for r in rows]
        raise ValueError(
            "All MoE layers must have the same number of experts; "
            f"found expert counts {expert_counts} for layers {layer_names}."
        )

html_parts = [
"<html><head><style>",
"table { border-collapse: collapse; font-family: monospace; }",
"th, td { border: 1px solid #ccc; padding: 4px 8px; text-align: right; }",
"th { background: #f0f0f0; }",
"</style></head><body>",
"<h2>Expert Token Counts (per MoE layer)</h2>",
"<table><tr><th>Layer/Expert</th>",
]
html_parts.extend(f"<th>{i}</th>" for i in range(num_experts))
html_parts.append("</tr>")

for name, counts in rows:
avg = counts.float().mean().item()
html_parts.append(f"<tr><td>{name}</td>")
for c in counts.tolist():
Comment on lines +57 to +60
Copilot AI Feb 19, 2026

The module name is interpolated into HTML without escaping. For models that come from remote/custom code, a module name containing characters such as <, > or & could inject markup into the report. Escape the name (and any other dynamic text) before writing HTML (e.g., via html.escape).

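A minimal sketch of the suggested escaping, using the standard-library html module:

    import html

    safe_name = html.escape(name)  # escapes &, <, > and quotes before interpolation into markup
    html_parts.append(f"<tr><td>{safe_name}</td>")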
if avg > 0 and c < avg * 0.05:
style = ' style="background: #ff6666;"'
elif avg > 0 and c < avg * 0.1:
style = ' style="background: #ffcccc;"'
else:
style = ""
html_parts.append(f"<td{style}>{c}</td>")
html_parts.append("</tr>")

html_parts.append("</table></body></html>")
html_content = "\n".join(html_parts)

if output_dir is None:
output_dir = Path(".")
output_path = Path(output_dir) / ".moe.html"
output_path.write_text(html_content, encoding="utf-8")
print(f"\033[1mExpert token count table saved to {output_path}\033[0m")
26 changes: 19 additions & 7 deletions modelopt/torch/quantization/model_quant.py
@@ -508,14 +508,26 @@ def enable_quantizer(model: nn.Module, wildcard_or_filter_func: str | Callable):


@atomic_print
def print_quant_summary(model: nn.Module):
def print_quant_summary(model: nn.Module, output_dir: str | None = None):
"""Print summary of all quantizer modules in the model."""
count = 0
for name, mod in model.named_modules():
if isinstance(mod, TensorQuantizer):
print(f"{name:80} {mod}")
count += 1
print(f"{count} TensorQuantizers found in model")
lines = [
f"{name:80} {mod}"
for name, mod in model.named_modules()
if isinstance(mod, TensorQuantizer)
]
lines.append(f"{len(lines)} TensorQuantizers found in model")

if output_dir:
path = (
output_dir.joinpath(".quant_summary.txt")
if hasattr(output_dir, "joinpath")
else f"{output_dir}/.quant_summary.txt"
)
Comment on lines +511 to +525
Copilot AI Feb 19, 2026

print_quant_summary now accepts Path-like output_dir values (via joinpath), but the type annotation only allows str | None. Update the annotation/docstring to include Path (or os.PathLike) to match the actual supported inputs.

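A possible updated signature (a sketch only; os.PathLike is one option for the annotation, pathlib.Path another):

    import os
    import torch.nn as nn

    def print_quant_summary(model: nn.Module, output_dir: str | os.PathLike | None = None):
        """Print summary of all quantizer modules in the model, optionally saving it to output_dir."""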
with open(path, "w", encoding="utf-8") as f:
f.write("\n".join(lines) + "\n")
print(f"\033[1mQuant summary saved to {path}\033[0m")
else:
print("\n".join(lines))


def fold_weight(model: nn.Module):
167 changes: 98 additions & 69 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -450,20 +450,56 @@ class _QuantSparseMoe(QuantModule):
"""

def _setup(self):
pass
num_experts = 0
if hasattr(self, "gate") and hasattr(self.gate, "num_experts"):
num_experts = self.gate.num_experts
elif hasattr(self, "num_experts"):
num_experts = self.num_experts
elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"):
num_experts = self.experts.num_experts

self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu")
self._count_expert_tokens = False

if num_experts == 0:
warnings.warn(
f"{self.__class__.__name__}: could not resolve num_experts; "
"expert routing will not be tracked for this layer."
)
return

if hasattr(self, "gate"):
self.gate.register_forward_hook(self._gate_forward_hook)

def _gate_forward_hook(self, module, input, output):
if not self._count_expert_tokens:
return
with torch.no_grad():
if isinstance(output, tuple) and len(output) >= 3:
# v5.x TopKRouter: returns (logits, scores, indices)
indices = output[2]
else:
# v4.x nn.Linear gate: returns logits tensor
logits = output if not isinstance(output, tuple) else output[0]
top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k
_, indices = torch.topk(logits.float(), top_k, dim=-1)
counts = torch.bincount(
indices.reshape(-1).cpu(), minlength=len(self.expert_token_count)
)
self.expert_token_count += counts

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules())
if is_calib:
# If any of the experts are in calibration mode, we will forward all tokens to all experts
# This is used only for calibration, we need to re-calculate the actual outputs again using
# the original top_k
if TRANSFORMERS_VERSION_GE_5_0:
assert hasattr(self, "gate")
# Path for transformers >= 5.0
original_top_k = self.gate.topk
self.gate.topk = self.gate.num_experts
assert hasattr(self, "gate") and hasattr(self.gate, "top_k")
original_top_k = self.gate.top_k
self.gate.top_k = self.gate.num_experts
super().forward(hidden_states)
self.gate.topk = original_top_k
self.gate.top_k = original_top_k
else:
# Path for transformers < 5.0
original_top_k = self.top_k
@@ -475,7 +511,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise ValueError(f"Could not find num_experts in module {self}")
super().forward(hidden_states)
self.top_k = original_top_k
return super().forward(hidden_states)
# Enable counting only for the real-routing forward during calibration
self._count_expert_tokens = is_calib
output = super().forward(hidden_states)
self._count_expert_tokens = False
return output


class _QuantLlama4TextExperts(QuantModule):
@@ -765,10 +805,7 @@ def unpack_weight(self):


try:
from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe

if Llama4TextMoe not in QuantModuleRegistry:
QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe)
from transformers.models.llama4.modeling_llama4 import Llama4TextExperts

if Llama4TextExperts not in QuantModuleRegistry:
QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})(
@@ -791,16 +828,6 @@ def unpack_weight(self):
except ImportError:
pass

try:
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

if MixtralSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})(
_QuantSparseMoe
)
except ImportError:
pass

try:
from transformers.models.falcon.modeling_falcon import FalconLinear

@@ -809,36 +836,6 @@ def unpack_weight(self):
except ImportError:
pass

try:
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})(
_QuantSparseMoe
)
except ImportError:
pass

try:
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})(
_QuantSparseMoe
)
except ImportError:
pass

try:
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock

if Qwen3NextSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})(
_QuantSparseMoe
)
except ImportError:
pass

try:
from compressed_tensors.linear.compressed_linear import CompressedLinear

@@ -850,15 +847,7 @@ def unpack_weight(self):
pass

try:
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeTextExperts,
Qwen3VLMoeTextSparseMoeBlock,
)

if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register(
{Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"}
)(_QuantSparseMoe)
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts

if Qwen3VLMoeTextExperts not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})(
@@ -989,15 +978,55 @@ def register_falcon_linears_on_the_fly(model):
QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear)


def register_minimax_m2_moe_on_the_fly(model):
"""Register MiniMax M2 MoE modules as a QUANT_MODULE.
def _is_sparse_moe_block(module):
"""Check if a module is structurally a sparse MoE block compatible with _QuantSparseMoe.

MiniMax M2 MoE modules are defined in the model card, so we need to register them on the fly.
All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax, etc.)
share a common structural pattern: a ``gate`` (TopKRouter) sub-module with routing attributes
(``top_k`` and ``num_experts``), and an ``experts`` sub-module.

This function detects that pattern instead of relying on class names, making it forward-compatible
with new MoE architectures. Some MoE models (e.g. Glm4MoeMoE) have ``gate`` and ``experts`` but
use a different routing interface (``n_routed_experts`` instead of ``num_experts``, custom
``route_tokens_to_experts``), so we require ``num_experts`` to be present to avoid false positives.
"""
if type(model).__name__ in ["MiniMaxM2ForCausalLM"]:
moe_type = type(model.model.layers[0].block_sparse_moe)
if QuantModuleRegistry.get(moe_type) is None:
QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe)
if not hasattr(module, "experts"):
return False

# Primary: gate sub-module has topk/top_k + num_experts (standard TopKRouter pattern)
if hasattr(module, "gate"):
gate = module.gate
has_topk = hasattr(gate, "top_k")
has_num_experts = hasattr(gate, "num_experts")
if has_topk and has_num_experts:
return True

# Fallback: top_k + num_experts on the block itself (older transformers, e.g. v4.x Qwen3Next)
return hasattr(module, "top_k") and hasattr(module, "num_experts")


def register_sparse_moe_on_the_fly(model):
"""Auto-detect and register MOE modules as _QuantSparseMoe.

Walks the model tree, identifies MoE blocks by their structural attributes
(``gate`` + ``experts``), and registers unregistered ones with ``_QuantSparseMoe``.
"""
registered_types = set()
for name, module in model.named_modules():
mod_type = type(module)

# Avoid duplicate registration: skip if we already processed this type
# in this walk, or if it was previously registered in the QuantModuleRegistry.
if mod_type in registered_types or QuantModuleRegistry.get(mod_type) is not None:
continue
Comment on lines +1014 to +1021
Contributor

Can we also avoid re-checking the same mod_type? For example:

Suggested change (before):

    registered_types = set()
    for name, module in model.named_modules():
        mod_type = type(module)

        # Avoid duplicate registration: skip if we already processed this type
        # in this walk, or if it was previously registered in the QuantModuleRegistry.
        if mod_type in registered_types or QuantModuleRegistry.get(mod_type) is not None:
            continue

Suggested change (after):

    checked_types = set()
    for name, module in model.named_modules():
        mod_type = type(module)
        # Avoid duplicate registration: skip if we already processed this type
        # in this walk, or if it was previously registered in the QuantModuleRegistry.
        if mod_type in checked_types or QuantModuleRegistry.get(mod_type) is not None:
            continue
        checked_types.add(mod_type)


if _is_sparse_moe_block(module):
print(
f"\033[1mDetected MOE module '{name}' of type {mod_type.__name__}, "
f"registering with _QuantSparseMoe.\033[0m"
)
Comment on lines +1024 to +1027
Copilot AI Feb 19, 2026

Library code should avoid printing directly (especially with ANSI escape codes) because it pollutes stdout in normal use and will print once per rank under distributed execution. Prefer using the project logging helpers (e.g., print_rank_0/warn_rank_0) or a logging logger, and consider gating this behind a verbosity/debug flag.

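A minimal sketch of the logging-based alternative, assuming the standard logging module rather than a project-specific rank-0 helper:

    import logging

    logger = logging.getLogger(__name__)

    # Inside register_sparse_moe_on_the_fly, replace the print with:
    logger.info(
        "Detected MOE module '%s' of type %s; registering with _QuantSparseMoe.",
        name,
        mod_type.__name__,
    )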
QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantSparseMoe)
registered_types.add(mod_type)


def _is_supported_hf_model(model):
@@ -1065,7 +1094,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
[
register_falcon_linears_on_the_fly,
register_dbrx_moe_on_the_fly,
register_minimax_m2_moe_on_the_fly,
register_sparse_moe_on_the_fly,
register_hf_attentions_on_the_fly,
convert_hf_parallel_linears_on_the_fly,
]