diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 97e13f419f9..a888f50319a 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -1471,7 +1471,8 @@ def _is_fused_experts_module(module): Detects the standardized HuggingFace transformers 5.0+ fused expert pattern: ``gate_up_proj`` (3-D parameter), ``down_proj`` (3-D parameter), ``num_experts``, - and ``act_fn``. Matches ``MixtralExperts``, ``Qwen2MoeExperts``, + and ``act_fn`` (or ``_apply_gate`` for clamped-swiglu experts such as + ``MiniMaxM3VLExperts``). Matches ``MixtralExperts``, ``Qwen2MoeExperts``, ``Qwen3MoeExperts``, ``Qwen3_5MoeExperts``, ``DeepseekV3NaiveMoe``, ``JambaExperts``, ``OlmoeExperts``, etc. @@ -1480,7 +1481,9 @@ def _is_fused_experts_module(module): """ if not hasattr(module, "gate_up_proj") or not hasattr(module, "down_proj"): return False - if not hasattr(module, "num_experts") or not hasattr(module, "act_fn"): + if not hasattr(module, "num_experts") or not ( + hasattr(module, "act_fn") or hasattr(module, "_apply_gate") + ): return False gate_up = getattr(module, "gate_up_proj") down = getattr(module, "down_proj") diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index ce23f7a51d5..ff9fd48fa26 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -152,6 +152,21 @@ def test_module_missing_act_fn_not_detected(self): module.num_experts = 4 assert _is_fused_experts_module(module) is False + def test_module_with_apply_gate_detected(self): + """Clamped-swiglu experts (e.g. MiniMaxM3VLExperts) use _apply_gate instead of act_fn.""" + + class _ApplyGateExperts(nn.Module): + def __init__(self): + super().__init__() + self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8)) + self.down_proj = nn.Parameter(torch.randn(4, 8, 16)) + self.num_experts = 4 + + def _apply_gate(self, gate, up): + return up * torch.sigmoid(gate) + + assert _is_fused_experts_module(_ApplyGateExperts()) is True + def test_sparse_moe_block_not_detected_as_fused(self): block = _SyntheticSparseMoeBlock() assert _is_fused_experts_module(block) is False