Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions modelopt/torch/export/plugins/mcore_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,18 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any]
)


class MambaConv1dMapping(CustomModuleMapping):
    """A custom module mapping for Mamba conv1d weights."""

    def __init__(
        self,
        target_name_or_prefix: str = "",
        func_kwargs: dict[str, Any] | None = None,
    ):
        """Create a mapping for old Conv1d modules and new direct conv1d parameters.

        Args:
            target_name_or_prefix: Target name or prefix forwarded unchanged to
                ``CustomModuleMapping``.
            func_kwargs: Extra keyword arguments forwarded to
                ``CustomModuleMapping``. ``None`` (the default) means "no extra
                arguments" and yields a fresh empty dict for this instance.
        """
        # A ``{}`` default would be created once at definition time and shared by
        # every instance that omits the argument; build a new dict per call instead.
        super().__init__(
            func_name="mamba_conv1d_remapping",
            target_name_or_prefix=target_name_or_prefix,
            func_kwargs={} if func_kwargs is None else func_kwargs,
        )


class QKVMerging(CustomModuleMapping):
"""A custom module mapping that merges Q, K, V."""

Expand Down
3 changes: 2 additions & 1 deletion modelopt/torch/export/plugins/mcore_nemotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
CustomModuleMapping,
GroupedMLPMerging,
GroupedMLPSlicing,
MambaConv1dMapping,
NameRemapping,
QKVMerging,
QKVSlicing,
Expand Down Expand Up @@ -123,7 +124,7 @@
"A_log": NameRemapping("backbone.layers.{}.mixer.A_log"),
"D": NameRemapping("backbone.layers.{}.mixer.D"),
"dt_bias": NameRemapping("backbone.layers.{}.mixer.dt_bias"),
"conv1d": NameRemapping("backbone.layers.{}.mixer.conv1d."),
"conv1d": MambaConv1dMapping("backbone.layers.{}.mixer.conv1d."),
"in_proj": NameRemapping("backbone.layers.{}.mixer.in_proj."),
"out_proj": NameRemapping("backbone.layers.{}.mixer.out_proj."),
"fused_norm": NameRemapping("backbone.layers.{}.norm.weight"),
Expand Down
49 changes: 48 additions & 1 deletion modelopt/torch/export/unified_export_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import json
import os
import tempfile
import warnings
from collections import OrderedDict
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -102,6 +103,21 @@
]


class _MambaConv1dParamView(torch.nn.Module):
    """Module view exposing direct Mamba conv parameters with standard weight names.

    The view re-exposes attributes of ``mixer`` under shorter names:

    * ``mixer.conv1d_weight`` as ``weight`` (always),
    * ``mixer.conv1d_bias`` as ``bias`` (only when present and not ``None``),
    * ``mixer.conv1d_weight_weight_quantizer`` as ``weight_quantizer`` (only when
      present, not ``None`` and ``include_weight_quantizer`` is true).

    The objects are assigned as-is, not copied.
    """

    def __init__(self, mixer: torch.nn.Module, include_weight_quantizer: bool = True):
        super().__init__()
        self.weight = mixer.conv1d_weight

        conv_bias = getattr(mixer, "conv1d_bias", None)
        if conv_bias is not None:
            self.bias = conv_bias

        # Callers can opt out of the quantizer to get a weight-only view.
        if not include_weight_quantizer:
            return

        quantizer = getattr(mixer, "conv1d_weight_weight_quantizer", None)
        if quantizer is not None:
            self.weight_quantizer = quantizer


class GPTModelExporter:
"""Megatron Core GPTModel Exporter.

Expand Down Expand Up @@ -664,7 +680,7 @@ def _get_mamba_layer_state_dict(self, layer, layer_id):
self.rules["D"](layer.mixer.D, layer_id)
self.rules["dt_bias"](layer.mixer.dt_bias, layer_id)

self.rules["conv1d"](layer.mixer.conv1d, layer_id)
self.rules["conv1d"](layer.mixer, layer_id)
self.rules["in_proj"](layer.mixer.in_proj, layer_id)
self.rules["out_proj"](layer.mixer.out_proj, layer_id)

Expand Down Expand Up @@ -787,6 +803,7 @@ def _custom_mapping_to_lambda(mapping):
"grouped_mlp_slicing": self._grouped_mlp_slicing,
"pack_name_remapping": self._pack_name_remapping,
"pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss,
"mamba_conv1d_remapping": self._mamba_conv1d_remapping,
}
func = method_map[mapping.func_name]
prefix = mapping.target_name_or_prefix
Expand Down Expand Up @@ -927,6 +944,36 @@ def _record_excluded_module(self, prefix: str):
if layer_name not in self.exclude_modules:
self.exclude_modules.append(layer_name)

def _mamba_conv1d_remapping(self, module: torch.nn.Module, prefix: str, **kwargs):
    """Export Mamba conv1d from either old Conv1d modules or new direct parameters.

    Args:
        module: The Mamba mixer. It must expose either a ``conv1d`` attribute or a
            ``conv1d_weight`` attribute.
        prefix: Target name prefix handed to ``_name_remapping``.
        **kwargs: Forwarded unchanged to ``_name_remapping``.

    Raises:
        AttributeError: If ``module`` has neither ``conv1d`` nor ``conv1d_weight``.
    """
    # Old layout: the mixer owns a conv1d submodule that can be exported directly.
    legacy_conv1d = getattr(module, "conv1d", None)
    if legacy_conv1d is not None:
        self._name_remapping(legacy_conv1d, prefix, **kwargs)
        return

    if not hasattr(module, "conv1d_weight"):
        raise AttributeError(
            f"{type(module).__name__} has neither conv1d nor conv1d_weight for export"
        )

    # New layout: wrap the direct parameters so they look like a regular module.
    view = _MambaConv1dParamView(module)
    qformat = self._get_quantization_format(view)
    block_size = get_weight_block_size(view)

    # A quantized format with a block size that does not divide the last weight
    # dimension cannot be packed; in that case export the weight without its
    # quantizer and tell the user.
    cannot_pack = (
        qformat not in (None, QUANTIZATION_NONE)
        and block_size
        and view.weight.shape[-1] % block_size != 0
    )
    if cannot_pack:
        warnings.warn(
            f"Exporting direct Mamba conv1d {prefix} in {self.dtype} because "
            f"weight shape {tuple(view.weight.shape)} is not divisible by "
            f"block size {block_size} for {qformat} packing.",
            stacklevel=2,
        )
        view = _MambaConv1dParamView(module, include_weight_quantizer=False)

    self._name_remapping(view, prefix, **kwargs)

def _name_remapping(
self,
module: torch.nn.Module | torch.Tensor,
Expand Down
5 changes: 4 additions & 1 deletion modelopt/torch/quantization/backends/nvfp4_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.backends.gemm_registry import gemm_registry
from modelopt.torch.quantization.backends.utils import fp4_compatible
from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear
from modelopt.torch.quantization.qtensor import NVFP4QTensor, QTensorWrapper
from modelopt.torch.quantization.utils import reduce_amax

Expand Down Expand Up @@ -193,6 +192,10 @@ def apply(cls, *args, **kwargs):

def _nvfp4_availability_check(module, input, args, kwargs):
"""Comprehensive check for FP4 GEMM availability."""
# Imported lazily to avoid an import cycle:
# quant_linear -> backends -> nvfp4_gemm -> quant_linear.
from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear

# NOTE: Having the import at the top causes mpirun commands inside pytest (vlm_ptq) to fail without any error
try:
import tensorrt_llm # noqa: F401
Expand Down
163 changes: 162 additions & 1 deletion modelopt/torch/quantization/plugins/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""Support quantization for megatron linear layers."""

import contextlib
import types
from typing import Any

Expand All @@ -39,13 +40,19 @@
from modelopt.torch.utils import warn_rank_0
from modelopt.torch.utils.distributed import ParallelState

from .. import tensor_quant
from ..conversion import maybe_promote_nvfp4_static_quantizer
from ..nn import QuantModule, QuantModuleRegistry, SequentialQuantizer, TensorQuantizer
from ..nn.modules.quant_linear import RealQuantLinear
from ..qtensor import QTensorWrapper
from ..utils import sync_moe_expert_amax
from ..utils import is_torch_export_mode, sync_moe_expert_amax
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear

try:
from megatron.core.ssm.mamba_mixer import MambaMixer
except ImportError:
MambaMixer = None

try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelGroupedLinear,
Expand Down Expand Up @@ -232,6 +239,160 @@ def quant_module_set_extra_state(self, state: Any):
self.allow_post_restore = False


if MambaMixer is not None:

def _save_mamba_extra_state(self, prefix, sharded_offsets, sharded_state_dict):
    """Add this module's extra state to ``sharded_state_dict`` when available.

    Calls ``self.get_extra_state()``; a non-``None`` result is wrapped with
    ``make_sharded_tensors_for_checkpoint`` under the ``"_extra_state"`` key and
    merged into ``sharded_state_dict`` in place. A ``RuntimeError`` raised while
    doing so is suppressed and the dict is left as it was at that point.

    Returns:
        The same ``sharded_state_dict`` object that was passed in.
    """
    with contextlib.suppress(RuntimeError):
        state = self.get_extra_state()
        if state is not None:
            extra_entries = make_sharded_tensors_for_checkpoint(
                {"_extra_state": state}, prefix, {}, sharded_offsets
            )
            sharded_state_dict.update(**extra_entries)
    return sharded_state_dict

def _save_mamba_conv1d_quantizer_amax(self, prefix, sharded_offsets, sharded_state_dict):
    """Add the direct conv1d weight quantizer's amax entries to the sharded state.

    Does nothing when ``self`` has no ``conv1d_weight_weight_quantizer``.
    Otherwise the quantizer's ``*_amax`` / ``*_global_amax`` state entries are
    collected. If ``_amax`` is missing it is first taken from the quantizer's
    calibrator, and failing that (for an enabled quantizer) recomputed from
    ``self.conv1d_weight``; in both cases it is also written to ``quantizer.amax``.

    Returns:
        The same ``sharded_state_dict`` object, updated in place.

    Raises:
        AssertionError: If any collected amax entry has more than one element.
    """
    quantizer = getattr(self, "conv1d_weight_weight_quantizer", None)
    if quantizer is None:
        return sharded_state_dict

    amax_key = "conv1d_weight_weight_quantizer._amax"

    # Direct Mamba conv1d stores its quantizer as a ModelOpt temp attribute.
    # MCore's native MambaMixer sharded state saves conv1d_weight/bias but
    # does not know that this extra quantizer amax must round-trip for export.
    full_state = quantizer.state_dict(prefix="conv1d_weight_weight_quantizer.", keep_vars=True)
    amax_entries = {
        name: tensor
        for name, tensor in full_state.items()
        if name.endswith(("_amax", "_global_amax"))
    }

    # First fallback: statistics still held by the calibrator, if there is one.
    if amax_key not in amax_entries:
        calibrator = getattr(quantizer, "_calibrator", None)
        calib_amax = calibrator.compute_amax() if calibrator is not None else None
        if calib_amax is not None:
            quantizer.amax = calib_amax.detach()
            amax_entries[amax_key] = quantizer.amax

    # Second fallback: after checkpoint reload, the direct temp quantizer may no
    # longer have calibrator statistics. For weight max calibration, the source
    # is the conv1d weight itself, so recompute the same scalar before
    # saving/exporting rather than dropping the required amax.
    if amax_key not in amax_entries and quantizer.is_enabled:
        quantizer.amax = quantizer._get_amax(self.conv1d_weight.detach().float()).detach()
        amax_entries[amax_key] = quantizer.amax

    for name, tensor in amax_entries.items():
        if tensor.numel() != 1:
            raise AssertionError(
                f"Only scalar direct Mamba conv1d quantizer amax is supported, got "
                f"{name} with shape {tuple(tensor.shape)}."
            )

    sharded_state_dict.update(
        **make_sharded_tensors_for_checkpoint(amax_entries, prefix, {}, sharded_offsets)
    )
    return sharded_state_dict

@QuantModuleRegistry.register({MambaMixer: "megatron_MambaMixer"})
class _QuantMambaMixer(QuantModule):
    """Quantize new Megatron Mamba direct conv1d parameters.

    When the wrapped mixer exposes ``conv1d_weight``, ``_setup`` attaches a
    ``conv1d_weight_weight_quantizer`` and turns ``conv1d_weight`` into a dynamic
    attribute that is passed through that quantizer while
    ``_enable_conv1d_weight_quantization`` is set or torch export mode is active.
    Mixers without ``conv1d_weight`` are left untouched by ``_setup``.
    """

    default_quant_desc_weight = tensor_quant.QUANT_DESC_8BIT_CONV1D_WEIGHT_PER_CHANNEL

    @contextlib.contextmanager
    def quantize_conv1d_weight(self):
        """Context in which ``self.conv1d_weight`` is quantized.

        A no-op when the module has no ``conv1d_weight_weight_quantizer``. On exit
        the enable flag is restored to the value it had on entry, so the context
        can be nested without the inner exit disabling the outer context.
        """
        if not hasattr(self, "conv1d_weight_weight_quantizer"):
            yield
            return

        previous = self._enable_conv1d_weight_quantization
        self._enable_conv1d_weight_quantization = True
        try:
            yield
        finally:
            # Restore instead of forcing False: an enclosing context (or a caller
            # that set the flag) must stay in effect after this one exits.
            self._enable_conv1d_weight_quantization = previous

    @staticmethod
    def _get_quantized_conv1d_weight(
        module: "_QuantMambaMixer", weight: torch.Tensor
    ) -> torch.Tensor:
        """Dynamic-attribute callback for ``conv1d_weight``.

        Returns ``weight`` run through the module's conv1d weight quantizer when
        quantization is enabled on ``module`` or torch export mode is active, and
        ``weight`` unchanged otherwise.
        """
        if module._enable_conv1d_weight_quantization or is_torch_export_mode():
            return module.conv1d_weight_weight_quantizer(weight)
        return weight

    def forward(self, *args, **kwargs):
        """Quantize the direct conv1d weight before calling MambaMixer.forward."""
        # In export mode the dynamic attribute already quantizes on its own, and
        # without a quantizer there is nothing to enable, so skip the context.
        if is_torch_export_mode() or not hasattr(self, "conv1d_weight_weight_quantizer"):
            return super().forward(*args, **kwargs)

        with self.quantize_conv1d_weight():
            return super().forward(*args, **kwargs)

    def iter_weights_for_calibration(self):
        """Yield direct conv1d weights for weight-only max calibration.

        Yields every ``(weight, weight_quantizer)`` pair from the parent class,
        then the ``(conv1d_weight, conv1d_weight_weight_quantizer)`` pair unless
        the parent already yielded that quantizer.
        """
        seen_quantizers = set()
        for weight, weight_quantizer in super().iter_weights_for_calibration():
            seen_quantizers.add(id(weight_quantizer))
            yield weight, weight_quantizer

        weight_quantizer = getattr(self, "conv1d_weight_weight_quantizer", None)
        if weight_quantizer is not None and id(weight_quantizer) not in seen_quantizers:
            yield self.conv1d_weight, weight_quantizer

    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
        """Return the parent's sharded state dict plus conv1d amax and extra state."""
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        sharded_state_dict = _save_mamba_conv1d_quantizer_amax(
            self, prefix, sharded_offsets, sharded_state_dict
        )
        return _save_mamba_extra_state(self, prefix, sharded_offsets, sharded_state_dict)

    def fold_weight(self, keep_attrs: bool = False):
        """Fold the quantizers into the weights, including the direct conv1d weight.

        ``keep_attrs`` is accepted for interface compatibility, but the parent is
        always called with ``keep_attrs=True`` (see the comment below).
        """
        # NVFP4 export still needs the calibrated scalar amax after the direct
        # conv1d weight is folded.
        super().fold_weight(keep_attrs=True)
        weight_quantizer = getattr(self, "conv1d_weight_weight_quantizer", None)
        if (
            weight_quantizer is not None
            and weight_quantizer.fake_quant
            and weight_quantizer.is_enabled
        ):
            # Quantize in float32, write the result back in the weight's own dtype,
            # then disable the quantizer so the folded weight is not quantized again.
            self.conv1d_weight.data.copy_(
                weight_quantizer(self.conv1d_weight.float()).to(self.conv1d_weight.dtype)
            )
            weight_quantizer.disable()

    def _setup(self):
        """Attach the conv1d weight quantizer and make ``conv1d_weight`` dynamic."""
        # Nothing to set up for mixers that do not expose a direct conv1d weight.
        if not hasattr(self, "conv1d_weight"):
            return

        if not hasattr(self, "parallel_state") or self.parallel_state is None:
            data_parallel_group = None
            try:
                data_parallel_group = get_data_parallel_group(with_context_parallel=True)
            except AssertionError:
                warn_rank_0(
                    "Context parallel group is not initialized, using data parallel group"
                )
                data_parallel_group = get_data_parallel_group()
            self.parallel_state = ParallelState(
                data_parallel_group,
                mcore_parallel.get_tensor_model_parallel_group(),
            )

        self._register_temp_attribute(
            # Generic weight discovery maps parameter ``<name>`` to
            # ``<name>_weight_quantizer``.
            "conv1d_weight_weight_quantizer",
            TensorQuantizer(self.default_quant_desc_weight),
        )
        self._register_temp_attribute("_enable_conv1d_weight_quantization", False)
        self._register_dynamic_attribute("conv1d_weight", self._get_quantized_conv1d_weight)


def _create_incompatible_method(method_name: str):
"""Create a method that raises an error for incompatible flash decode methods."""

Expand Down
Loading
Loading