-
Notifications
You must be signed in to change notification settings - Fork 279
Auto detect MOE layers #900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
919be0f
8baeaaf
7da77b9
2e29ee7
4b4ef63
9b9377a
0126ce7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,6 +53,7 @@ | |
| export_hf_checkpoint, | ||
| export_tensorrt_llm_checkpoint, | ||
| get_model_type, | ||
| save_expert_token_count_table, | ||
| ) | ||
| from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model | ||
| from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration | ||
|
|
@@ -726,7 +727,12 @@ def post_quantize( | |
| """ | ||
|
|
||
| if args.verbose: | ||
| mtq.print_quant_summary(full_model) | ||
| try: | ||
| mtq.print_quant_summary(full_model, args.export_path) | ||
| save_expert_token_count_table(full_model, args.export_path) | ||
| except Exception as e: | ||
| print(f"Error saving quant summary: {e}") | ||
| print("Continuing with generation...") | ||
|
Comment on lines
+730
to
+735
|
||
|
|
||
| # Run some samples | ||
| torch.cuda.empty_cache() | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,77 @@ | ||||||||||||||||||||||
| # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||
| # | ||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||||
| # | ||||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||
| # | ||||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| """Utilities for Mixture-of-Experts (MoE) model export.""" | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None): | ||||||||||||||||||||||
| """Collect expert_token_count from all quantized MoE layers and save as an HTML table. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| The table has rows for each MoE layer and columns for each expert, with cell values | ||||||||||||||||||||||
| showing the number of tokens routed to that expert during calibration. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Args: | ||||||||||||||||||||||
| model: The model containing quantized MoE layers with ``expert_token_count`` attributes. | ||||||||||||||||||||||
| output_dir: Directory to save the HTML file. Defaults to current directory. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| rows = [] | ||||||||||||||||||||||
| for name, module in model.named_modules(): | ||||||||||||||||||||||
| if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0: | ||||||||||||||||||||||
| rows.append((name, module.expert_token_count)) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if not rows: | ||||||||||||||||||||||
| return | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| num_experts = rows[0][1].shape[0] | ||||||||||||||||||||||
| assert all(r[1].shape[0] == num_experts for r in rows), ( | ||||||||||||||||||||||
| "All MoE layers must have the same number of experts" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
Comment on lines
+42
to
+44
|
||||||||||||||||||||||
| assert all(r[1].shape[0] == num_experts for r in rows), ( | |
| "All MoE layers must have the same number of experts" | |
| ) | |
| if not all(r[1].shape[0] == num_experts for r in rows): | |
| layer_names = [r[0] for r in rows] | |
| expert_counts = [r[1].shape[0] for r in rows] | |
| raise ValueError( | |
| "All MoE layers must have the same number of experts; " | |
| f"found expert counts {expert_counts} for layers {layer_names}." | |
| ) |
Copilot
AI
Feb 19, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The module name is interpolated into HTML without escaping. For models that come from remote/custom code, a module name containing </& could inject markup into the report. Escape name (and any other dynamic text) before writing HTML (e.g., via html.escape).
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -508,14 +508,26 @@ def enable_quantizer(model: nn.Module, wildcard_or_filter_func: str | Callable): | |
|
|
||
|
|
||
| @atomic_print | ||
| def print_quant_summary(model: nn.Module): | ||
| def print_quant_summary(model: nn.Module, output_dir: str | None = None): | ||
| """Print summary of all quantizer modules in the model.""" | ||
| count = 0 | ||
| for name, mod in model.named_modules(): | ||
| if isinstance(mod, TensorQuantizer): | ||
| print(f"{name:80} {mod}") | ||
| count += 1 | ||
| print(f"{count} TensorQuantizers found in model") | ||
| lines = [ | ||
| f"{name:80} {mod}" | ||
| for name, mod in model.named_modules() | ||
| if isinstance(mod, TensorQuantizer) | ||
| ] | ||
| lines.append(f"{len(lines)} TensorQuantizers found in model") | ||
|
|
||
| if output_dir: | ||
| path = ( | ||
| output_dir.joinpath(".quant_summary.txt") | ||
| if hasattr(output_dir, "joinpath") | ||
| else f"{output_dir}/.quant_summary.txt" | ||
| ) | ||
|
Comment on lines
+511
to
+525
|
||
| with open(path, "w", encoding="utf-8") as f: | ||
| f.write("\n".join(lines) + "\n") | ||
| print(f"\033[1mQuant summary saved to {path}\033[0m") | ||
| else: | ||
| print("\n".join(lines)) | ||
|
|
||
|
|
||
| def fold_weight(model: nn.Module): | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -450,20 +450,56 @@ class _QuantSparseMoe(QuantModule): | |||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def _setup(self): | ||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||
| num_experts = 0 | ||||||||||||||||||||||||||||||||||||
| if hasattr(self, "gate") and hasattr(self.gate, "num_experts"): | ||||||||||||||||||||||||||||||||||||
| num_experts = self.gate.num_experts | ||||||||||||||||||||||||||||||||||||
| elif hasattr(self, "num_experts"): | ||||||||||||||||||||||||||||||||||||
| num_experts = self.num_experts | ||||||||||||||||||||||||||||||||||||
| elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"): | ||||||||||||||||||||||||||||||||||||
| num_experts = self.experts.num_experts | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu") | ||||||||||||||||||||||||||||||||||||
| self._count_expert_tokens = False | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if num_experts == 0: | ||||||||||||||||||||||||||||||||||||
| warnings.warn( | ||||||||||||||||||||||||||||||||||||
| f"{self.__class__.__name__}: could not resolve num_experts; " | ||||||||||||||||||||||||||||||||||||
| "expert routing will not be tracked for this layer." | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if hasattr(self, "gate"): | ||||||||||||||||||||||||||||||||||||
| self.gate.register_forward_hook(self._gate_forward_hook) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def _gate_forward_hook(self, module, input, output): | ||||||||||||||||||||||||||||||||||||
| if not self._count_expert_tokens: | ||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||
| with torch.no_grad(): | ||||||||||||||||||||||||||||||||||||
| if isinstance(output, tuple) and len(output) >= 3: | ||||||||||||||||||||||||||||||||||||
| # v5.x TopKRouter: returns (logits, scores, indices) | ||||||||||||||||||||||||||||||||||||
| indices = output[2] | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| # v4.x nn.Linear gate: returns logits tensor | ||||||||||||||||||||||||||||||||||||
| logits = output if not isinstance(output, tuple) else output[0] | ||||||||||||||||||||||||||||||||||||
| top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k | ||||||||||||||||||||||||||||||||||||
| _, indices = torch.topk(logits.float(), top_k, dim=-1) | ||||||||||||||||||||||||||||||||||||
| counts = torch.bincount( | ||||||||||||||||||||||||||||||||||||
| indices.reshape(-1).cpu(), minlength=len(self.expert_token_count) | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| self.expert_token_count += counts | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||
| if any(getattr(m, "_if_calib", False) for m in self.experts.modules()): | ||||||||||||||||||||||||||||||||||||
| is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules()) | ||||||||||||||||||||||||||||||||||||
| if is_calib: | ||||||||||||||||||||||||||||||||||||
| # If any of the experts are in calibration mode, we will forward all tokens to all experts | ||||||||||||||||||||||||||||||||||||
| # This is used only for calibration, we need to re-calculate the actual outputs again using | ||||||||||||||||||||||||||||||||||||
| # the original top_k | ||||||||||||||||||||||||||||||||||||
| if TRANSFORMERS_VERSION_GE_5_0: | ||||||||||||||||||||||||||||||||||||
| assert hasattr(self, "gate") | ||||||||||||||||||||||||||||||||||||
| # Path for transformers >= 5.0 | ||||||||||||||||||||||||||||||||||||
| original_top_k = self.gate.topk | ||||||||||||||||||||||||||||||||||||
| self.gate.topk = self.gate.num_experts | ||||||||||||||||||||||||||||||||||||
| assert hasattr(self, "gate") and hasattr(self.gate, "top_k") | ||||||||||||||||||||||||||||||||||||
| original_top_k = self.gate.top_k | ||||||||||||||||||||||||||||||||||||
| self.gate.top_k = self.gate.num_experts | ||||||||||||||||||||||||||||||||||||
| super().forward(hidden_states) | ||||||||||||||||||||||||||||||||||||
| self.gate.topk = original_top_k | ||||||||||||||||||||||||||||||||||||
| self.gate.top_k = original_top_k | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| # Path for transformers < 5.0 | ||||||||||||||||||||||||||||||||||||
| original_top_k = self.top_k | ||||||||||||||||||||||||||||||||||||
|
|
@@ -475,7 +511,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | |||||||||||||||||||||||||||||||||||
| raise ValueError(f"Could not find num_experts in module {self}") | ||||||||||||||||||||||||||||||||||||
| super().forward(hidden_states) | ||||||||||||||||||||||||||||||||||||
| self.top_k = original_top_k | ||||||||||||||||||||||||||||||||||||
| return super().forward(hidden_states) | ||||||||||||||||||||||||||||||||||||
| # Enable counting only for the real-routing forward during calibration | ||||||||||||||||||||||||||||||||||||
| self._count_expert_tokens = is_calib | ||||||||||||||||||||||||||||||||||||
| output = super().forward(hidden_states) | ||||||||||||||||||||||||||||||||||||
| self._count_expert_tokens = False | ||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| class _QuantLlama4TextExperts(QuantModule): | ||||||||||||||||||||||||||||||||||||
|
|
@@ -765,10 +805,7 @@ def unpack_weight(self): | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if Llama4TextMoe not in QuantModuleRegistry: | ||||||||||||||||||||||||||||||||||||
| QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe) | ||||||||||||||||||||||||||||||||||||
| from transformers.models.llama4.modeling_llama4 import Llama4TextExperts | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if Llama4TextExperts not in QuantModuleRegistry: | ||||||||||||||||||||||||||||||||||||
| QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})( | ||||||||||||||||||||||||||||||||||||
|
|
@@ -791,16 +828,6 @@ def unpack_weight(self): | |||||||||||||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if MixtralSparseMoeBlock not in QuantModuleRegistry: | ||||||||||||||||||||||||||||||||||||
| QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})( | ||||||||||||||||||||||||||||||||||||
| _QuantSparseMoe | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| from transformers.models.falcon.modeling_falcon import FalconLinear | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
@@ -809,36 +836,6 @@ def unpack_weight(self): | |||||||||||||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry: | ||||||||||||||||||||||||||||||||||||
| QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})( | ||||||||||||||||||||||||||||||||||||
| _QuantSparseMoe | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry: | ||||||||||||||||||||||||||||||||||||
| QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})( | ||||||||||||||||||||||||||||||||||||
| _QuantSparseMoe | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if Qwen3NextSparseMoeBlock not in QuantModuleRegistry: | ||||||||||||||||||||||||||||||||||||
| QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})( | ||||||||||||||||||||||||||||||||||||
| _QuantSparseMoe | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| from compressed_tensors.linear.compressed_linear import CompressedLinear | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
@@ -850,15 +847,7 @@ def unpack_weight(self): | |||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( | ||||||||||||||||||||||||||||||||||||
| Qwen3VLMoeTextExperts, | ||||||||||||||||||||||||||||||||||||
| Qwen3VLMoeTextSparseMoeBlock, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry: | ||||||||||||||||||||||||||||||||||||
| QuantModuleRegistry.register( | ||||||||||||||||||||||||||||||||||||
| {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"} | ||||||||||||||||||||||||||||||||||||
| )(_QuantSparseMoe) | ||||||||||||||||||||||||||||||||||||
| from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if Qwen3VLMoeTextExperts not in QuantModuleRegistry: | ||||||||||||||||||||||||||||||||||||
| QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})( | ||||||||||||||||||||||||||||||||||||
|
|
@@ -989,15 +978,55 @@ def register_falcon_linears_on_the_fly(model): | |||||||||||||||||||||||||||||||||||
| QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def register_minimax_m2_moe_on_the_fly(model): | ||||||||||||||||||||||||||||||||||||
| """Register MiniMax M2 MoE modules as a QUANT_MODULE. | ||||||||||||||||||||||||||||||||||||
| def _is_sparse_moe_block(module): | ||||||||||||||||||||||||||||||||||||
| """Check if a module is structurally a sparse MoE block compatible with _QuantSparseMoe. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| MiniMax M2 MoE modules are defined in the model card, so we need to register them on the fly. | ||||||||||||||||||||||||||||||||||||
| All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax, etc.) | ||||||||||||||||||||||||||||||||||||
| share a common structural pattern: a ``gate`` (TopKRouter) sub-module with routing attributes | ||||||||||||||||||||||||||||||||||||
| (``top_k`` and ``num_experts``), and an ``experts`` sub-module. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| This function detects that pattern instead of relying on class names, making it forward-compatible | ||||||||||||||||||||||||||||||||||||
| with new MoE architectures. Some MoE models (e.g. Glm4MoeMoE) have ``gate`` and ``experts`` but | ||||||||||||||||||||||||||||||||||||
| use a different routing interface (``n_routed_experts`` instead of ``num_experts``, custom | ||||||||||||||||||||||||||||||||||||
| ``route_tokens_to_experts``), so we require ``num_experts`` to be present to avoid false positives. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| if type(model).__name__ in ["MiniMaxM2ForCausalLM"]: | ||||||||||||||||||||||||||||||||||||
| moe_type = type(model.model.layers[0].block_sparse_moe) | ||||||||||||||||||||||||||||||||||||
| if QuantModuleRegistry.get(moe_type) is None: | ||||||||||||||||||||||||||||||||||||
| QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe) | ||||||||||||||||||||||||||||||||||||
| if not hasattr(module, "experts"): | ||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Primary: gate sub-module has topk/top_k + num_experts (standard TopKRouter pattern) | ||||||||||||||||||||||||||||||||||||
| if hasattr(module, "gate"): | ||||||||||||||||||||||||||||||||||||
| gate = module.gate | ||||||||||||||||||||||||||||||||||||
| has_topk = hasattr(gate, "top_k") | ||||||||||||||||||||||||||||||||||||
| has_num_experts = hasattr(gate, "num_experts") | ||||||||||||||||||||||||||||||||||||
| if has_topk and has_num_experts: | ||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Fallback: top_k + num_experts on the block itself (older transformers, e.g. v4.x Qwen3Next) | ||||||||||||||||||||||||||||||||||||
| return hasattr(module, "top_k") and hasattr(module, "num_experts") | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def register_sparse_moe_on_the_fly(model): | ||||||||||||||||||||||||||||||||||||
| """Auto-detect and register MOE modules as _QuantSparseMoe. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Walks the model tree, identifies MoE blocks by their structural attributes | ||||||||||||||||||||||||||||||||||||
| (``gate`` + ``experts``), and registers unregistered ones with ``_QuantSparseMoe``. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| registered_types = set() | ||||||||||||||||||||||||||||||||||||
| for name, module in model.named_modules(): | ||||||||||||||||||||||||||||||||||||
| mod_type = type(module) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Avoid duplicate registration: skip if we already processed this type | ||||||||||||||||||||||||||||||||||||
| # in this walk, or if it was previously registered in the QuantModuleRegistry. | ||||||||||||||||||||||||||||||||||||
| if mod_type in registered_types or QuantModuleRegistry.get(mod_type) is not None: | ||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1014
to
+1021
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we also avoid duplication of checking the same mod_type?, e.g.,
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if _is_sparse_moe_block(module): | ||||||||||||||||||||||||||||||||||||
| print( | ||||||||||||||||||||||||||||||||||||
| f"\033[1mDetected MOE module '{name}' of type {mod_type.__name__}, " | ||||||||||||||||||||||||||||||||||||
| f"registering with _QuantSparseMoe.\033[0m" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1024
to
+1027
|
||||||||||||||||||||||||||||||||||||
| QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantSparseMoe) | ||||||||||||||||||||||||||||||||||||
| registered_types.add(mod_type) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def _is_supported_hf_model(model): | ||||||||||||||||||||||||||||||||||||
|
|
@@ -1065,7 +1094,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model): | |||||||||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||||||||||
| register_falcon_linears_on_the_fly, | ||||||||||||||||||||||||||||||||||||
| register_dbrx_moe_on_the_fly, | ||||||||||||||||||||||||||||||||||||
| register_minimax_m2_moe_on_the_fly, | ||||||||||||||||||||||||||||||||||||
| register_sparse_moe_on_the_fly, | ||||||||||||||||||||||||||||||||||||
| register_hf_attentions_on_the_fly, | ||||||||||||||||||||||||||||||||||||
| convert_hf_parallel_linears_on_the_fly, | ||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The exception handler message says "Error saving quant summary" even if the failure came from saving the MoE HTML report. Consider making the message accurate (or logging which step failed) to avoid confusion when debugging export failures.