Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions modelopt/torch/quantization/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,7 +1471,8 @@ def _is_fused_experts_module(module):

Detects the standardized HuggingFace transformers 5.0+ fused expert pattern:
``gate_up_proj`` (3-D parameter), ``down_proj`` (3-D parameter), ``num_experts``,
and ``act_fn``. Matches ``MixtralExperts``, ``Qwen2MoeExperts``,
and ``act_fn`` (or ``_apply_gate`` for clamped-swiglu experts such as
``MiniMaxM3VLExperts``). Matches ``MixtralExperts``, ``Qwen2MoeExperts``,
``Qwen3MoeExperts``, ``Qwen3_5MoeExperts``, ``DeepseekV3NaiveMoe``,
``JambaExperts``, ``OlmoeExperts``, etc.

Expand All @@ -1480,7 +1481,9 @@ def _is_fused_experts_module(module):
"""
if not hasattr(module, "gate_up_proj") or not hasattr(module, "down_proj"):
return False
if not hasattr(module, "num_experts") or not hasattr(module, "act_fn"):
if not hasattr(module, "num_experts") or not (
hasattr(module, "act_fn") or hasattr(module, "_apply_gate")
):
return False
gate_up = getattr(module, "gate_up_proj")
down = getattr(module, "down_proj")
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/torch/quantization/plugins/test_fused_experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,21 @@ def test_module_missing_act_fn_not_detected(self):
module.num_experts = 4
assert _is_fused_experts_module(module) is False

def test_module_with_apply_gate_detected(self):
"""Clamped-swiglu experts (e.g. MiniMaxM3VLExperts) use _apply_gate instead of act_fn."""

class _ApplyGateExperts(nn.Module):
def __init__(self):
super().__init__()
self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8))
self.down_proj = nn.Parameter(torch.randn(4, 8, 16))
self.num_experts = 4

def _apply_gate(self, gate, up):
return up * torch.sigmoid(gate)

assert _is_fused_experts_module(_ApplyGateExperts()) is True

def test_sparse_moe_block_not_detected_as_fused(self):
block = _SyntheticSparseMoeBlock()
assert _is_fused_experts_module(block) is False
Expand Down