diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 0b6ce7a35df..6b7841034f6 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -91,6 +91,18 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] ) +class MambaConv1dMapping(CustomModuleMapping): + """A custom module mapping for Mamba conv1d weights.""" + + def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}): + """Create a mapping for old Conv1d modules and new direct conv1d parameters.""" + super().__init__( + func_name="mamba_conv1d_remapping", + target_name_or_prefix=target_name_or_prefix, + func_kwargs=func_kwargs, + ) + + class QKVMerging(CustomModuleMapping): """A custom module mapping that merges Q, K, V.""" diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 24bd8144055..5df0af871c9 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -25,6 +25,7 @@ CustomModuleMapping, GroupedMLPMerging, GroupedMLPSlicing, + MambaConv1dMapping, NameRemapping, QKVMerging, QKVSlicing, @@ -123,7 +124,7 @@ "A_log": NameRemapping("backbone.layers.{}.mixer.A_log"), "D": NameRemapping("backbone.layers.{}.mixer.D"), "dt_bias": NameRemapping("backbone.layers.{}.mixer.dt_bias"), - "conv1d": NameRemapping("backbone.layers.{}.mixer.conv1d."), + "conv1d": MambaConv1dMapping("backbone.layers.{}.mixer.conv1d."), "in_proj": NameRemapping("backbone.layers.{}.mixer.in_proj."), "out_proj": NameRemapping("backbone.layers.{}.mixer.out_proj."), "fused_norm": NameRemapping("backbone.layers.{}.norm.weight"), diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index cabb0d77580..6c11ce94b71 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ 
b/modelopt/torch/export/unified_export_megatron.py @@ -21,6 +21,7 @@ import json import os import tempfile +import warnings from collections import OrderedDict from pathlib import Path from typing import Any @@ -102,6 +103,21 @@ ] +class _MambaConv1dParamView(torch.nn.Module): + """Module view exposing direct Mamba conv parameters with standard weight names.""" + + def __init__(self, mixer: torch.nn.Module, include_weight_quantizer: bool = True): + super().__init__() + self.weight = mixer.conv1d_weight + bias = getattr(mixer, "conv1d_bias", None) + if bias is not None: + self.bias = bias + + weight_quantizer = getattr(mixer, "conv1d_weight_weight_quantizer", None) + if include_weight_quantizer and weight_quantizer is not None: + self.weight_quantizer = weight_quantizer + + class GPTModelExporter: """Megatron Core GPTModel Exporter. @@ -664,7 +680,7 @@ def _get_mamba_layer_state_dict(self, layer, layer_id): self.rules["D"](layer.mixer.D, layer_id) self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) - self.rules["conv1d"](layer.mixer.conv1d, layer_id) + self.rules["conv1d"](layer.mixer, layer_id) self.rules["in_proj"](layer.mixer.in_proj, layer_id) self.rules["out_proj"](layer.mixer.out_proj, layer_id) @@ -787,6 +803,7 @@ def _custom_mapping_to_lambda(mapping): "grouped_mlp_slicing": self._grouped_mlp_slicing, "pack_name_remapping": self._pack_name_remapping, "pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss, + "mamba_conv1d_remapping": self._mamba_conv1d_remapping, } func = method_map[mapping.func_name] prefix = mapping.target_name_or_prefix @@ -927,6 +944,36 @@ def _record_excluded_module(self, prefix: str): if layer_name not in self.exclude_modules: self.exclude_modules.append(layer_name) + def _mamba_conv1d_remapping(self, module: torch.nn.Module, prefix: str, **kwargs): + """Export Mamba conv1d from either old Conv1d modules or new direct parameters.""" + conv1d = getattr(module, "conv1d", None) + if conv1d is not None: + 
self._name_remapping(conv1d, prefix, **kwargs) + return + + if not hasattr(module, "conv1d_weight"): + raise AttributeError( + f"{type(module).__name__} has neither conv1d nor conv1d_weight for export" + ) + + conv1d_view = _MambaConv1dParamView(module) + qformat = self._get_quantization_format(conv1d_view) + block_size = get_weight_block_size(conv1d_view) + if ( + qformat not in (None, QUANTIZATION_NONE) + and block_size + and conv1d_view.weight.shape[-1] % block_size != 0 + ): + warnings.warn( + f"Exporting direct Mamba conv1d {prefix} in {self.dtype} because " + f"weight shape {tuple(conv1d_view.weight.shape)} is not divisible by " + f"block size {block_size} for {qformat} packing.", + stacklevel=2, + ) + conv1d_view = _MambaConv1dParamView(module, include_weight_quantizer=False) + + self._name_remapping(conv1d_view, prefix, **kwargs) + def _name_remapping( self, module: torch.nn.Module | torch.Tensor, diff --git a/modelopt/torch/quantization/backends/nvfp4_gemm.py b/modelopt/torch/quantization/backends/nvfp4_gemm.py index fdf6babb695..eef505e96f0 100644 --- a/modelopt/torch/quantization/backends/nvfp4_gemm.py +++ b/modelopt/torch/quantization/backends/nvfp4_gemm.py @@ -21,7 +21,6 @@ import modelopt.torch.quantization as mtq from modelopt.torch.quantization.backends.gemm_registry import gemm_registry from modelopt.torch.quantization.backends.utils import fp4_compatible -from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear from modelopt.torch.quantization.qtensor import NVFP4QTensor, QTensorWrapper from modelopt.torch.quantization.utils import reduce_amax @@ -193,6 +192,10 @@ def apply(cls, *args, **kwargs): def _nvfp4_availability_check(module, input, args, kwargs): """Comprehensive check for FP4 GEMM availability.""" + # Imported lazily to avoid an import cycle: + # quant_linear -> backends -> nvfp4_gemm -> quant_linear. 
+ from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear + # NOTE: Having the import at the top causes mpirun commands inside pytest (vlm_ptq) to fail without any error try: import tensorrt_llm # noqa: F401 diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index e18e3cd064b..2f1778ca4ae 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -15,6 +15,7 @@ """Support quantization for megatron linear layers.""" +import contextlib import types from typing import Any @@ -39,13 +40,19 @@ from modelopt.torch.utils import warn_rank_0 from modelopt.torch.utils.distributed import ParallelState +from .. import tensor_quant from ..conversion import maybe_promote_nvfp4_static_quantizer from ..nn import QuantModule, QuantModuleRegistry, SequentialQuantizer, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper -from ..utils import sync_moe_expert_amax +from ..utils import is_torch_export_mode, sync_moe_expert_amax from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear +try: + from megatron.core.ssm.mamba_mixer import MambaMixer +except ImportError: + MambaMixer = None + try: from megatron.core.extensions.transformer_engine import ( TEColumnParallelGroupedLinear, @@ -232,6 +239,160 @@ def quant_module_set_extra_state(self, state: Any): self.allow_post_restore = False +if MambaMixer is not None: + + def _save_mamba_extra_state(self, prefix, sharded_offsets, sharded_state_dict): + with contextlib.suppress(RuntimeError): + extra_state = self.get_extra_state() + if extra_state is not None: + sharded_state_dict.update( + **make_sharded_tensors_for_checkpoint( + {"_extra_state": extra_state}, prefix, {}, sharded_offsets + ) + ) + return sharded_state_dict + + def _save_mamba_conv1d_quantizer_amax(self, prefix, sharded_offsets, sharded_state_dict): + quantizer = getattr(self, 
"conv1d_weight_weight_quantizer", None) + if quantizer is None: + return sharded_state_dict + + # Direct Mamba conv1d stores its quantizer as a ModelOpt temp attribute. + # MCore's native MambaMixer sharded state saves conv1d_weight/bias but + # does not know that this extra quantizer amax must round-trip for export. + quantizer_state_dict = { + k: v + for k, v in quantizer.state_dict( + prefix="conv1d_weight_weight_quantizer.", keep_vars=True + ).items() + if k.endswith(("_amax", "_global_amax")) + } + if "conv1d_weight_weight_quantizer._amax" not in quantizer_state_dict: + calibrator = getattr(quantizer, "_calibrator", None) + if calibrator is not None: + calib_amax = calibrator.compute_amax() + if calib_amax is not None: + quantizer.amax = calib_amax.detach() + quantizer_state_dict["conv1d_weight_weight_quantizer._amax"] = quantizer.amax + if ( + "conv1d_weight_weight_quantizer._amax" not in quantizer_state_dict + and quantizer.is_enabled + ): + # After checkpoint reload, the direct temp quantizer may no longer + # have calibrator statistics. For weight max calibration, the source + # is the conv1d weight itself, so recompute the same scalar before + # saving/exporting rather than dropping the required amax. + quantizer.amax = quantizer._get_amax(self.conv1d_weight.detach().float()).detach() + quantizer_state_dict["conv1d_weight_weight_quantizer._amax"] = quantizer.amax + + for key, value in quantizer_state_dict.items(): + if value.numel() != 1: + raise AssertionError( + f"Only scalar direct Mamba conv1d quantizer amax is supported, got " + f"{key} with shape {tuple(value.shape)}." 
+ ) + + sharded_state_dict.update( + **make_sharded_tensors_for_checkpoint(quantizer_state_dict, prefix, {}, sharded_offsets) + ) + return sharded_state_dict + + @QuantModuleRegistry.register({MambaMixer: "megatron_MambaMixer"}) + class _QuantMambaMixer(QuantModule): + """Quantize new Megatron Mamba direct conv1d parameters.""" + + default_quant_desc_weight = tensor_quant.QUANT_DESC_8BIT_CONV1D_WEIGHT_PER_CHANNEL + + @contextlib.contextmanager + def quantize_conv1d_weight(self): + """Context in which ``self.conv1d_weight`` is quantized.""" + if not hasattr(self, "conv1d_weight_weight_quantizer"): + yield + return + + self._enable_conv1d_weight_quantization = True + try: + yield + finally: + self._enable_conv1d_weight_quantization = False + + @staticmethod + def _get_quantized_conv1d_weight( + module: "_QuantMambaMixer", weight: torch.Tensor + ) -> torch.Tensor: + if module._enable_conv1d_weight_quantization or is_torch_export_mode(): + return module.conv1d_weight_weight_quantizer(weight) + return weight + + def forward(self, *args, **kwargs): + """Quantize the direct conv1d weight before calling MambaMixer.forward.""" + if is_torch_export_mode() or not hasattr(self, "conv1d_weight_weight_quantizer"): + return super().forward(*args, **kwargs) + + with self.quantize_conv1d_weight(): + return super().forward(*args, **kwargs) + + def iter_weights_for_calibration(self): + """Yield direct conv1d weights for weight-only max calibration.""" + seen_quantizers = set() + for weight, weight_quantizer in super().iter_weights_for_calibration(): + seen_quantizers.add(id(weight_quantizer)) + yield weight, weight_quantizer + + weight_quantizer = getattr(self, "conv1d_weight_weight_quantizer", None) + if weight_quantizer is not None and id(weight_quantizer) not in seen_quantizers: + yield self.conv1d_weight, weight_quantizer + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, 
metadata) + sharded_state_dict = _save_mamba_conv1d_quantizer_amax( + self, prefix, sharded_offsets, sharded_state_dict + ) + return _save_mamba_extra_state(self, prefix, sharded_offsets, sharded_state_dict) + + def fold_weight(self, keep_attrs: bool = False): + # NVFP4 export still needs the calibrated scalar amax after the direct + # conv1d weight is folded. + super().fold_weight(keep_attrs=True) + weight_quantizer = getattr(self, "conv1d_weight_weight_quantizer", None) + if ( + weight_quantizer is not None + and weight_quantizer.fake_quant + and weight_quantizer.is_enabled + ): + self.conv1d_weight.data.copy_( + weight_quantizer(self.conv1d_weight.float()).to(self.conv1d_weight.dtype) + ) + weight_quantizer.disable() + + def _setup(self): + if not hasattr(self, "conv1d_weight"): + return + + if not hasattr(self, "parallel_state") or self.parallel_state is None: + data_parallel_group = None + try: + data_parallel_group = get_data_parallel_group(with_context_parallel=True) + except AssertionError: + warn_rank_0( + "Context parallel group is not initialized, using data parallel group" + ) + data_parallel_group = get_data_parallel_group() + self.parallel_state = ParallelState( + data_parallel_group, + mcore_parallel.get_tensor_model_parallel_group(), + ) + + self._register_temp_attribute( + # Generic weight discovery maps parameter ``{name}`` to + # ``{name}_weight_quantizer``. 
+ "conv1d_weight_weight_quantizer", + TensorQuantizer(self.default_quant_desc_weight), + ) + self._register_temp_attribute("_enable_conv1d_weight_quantization", False) + self._register_dynamic_attribute("conv1d_weight", self._get_quantized_conv1d_weight) + + def _create_incompatible_method(method_name: str): """Create a method that raises an error for incompatible flash decode methods.""" diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index f818cb3594c..612425e5afd 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -17,6 +17,7 @@ from copy import deepcopy from functools import partial from pathlib import Path +from types import SimpleNamespace import pytest import torch @@ -390,6 +391,80 @@ def test_qkv_slicing_records_hf_excludes_for_unquantized_fused_qkv(): assert "backbone.layers.0.mixer.v_proj" in exporter.exclude_modules +def _make_minimal_exporter() -> GPTModelExporter: + exporter = object.__new__(GPTModelExporter) + exporter.dtype = torch.bfloat16 + exporter.exclude_modules = [] + exporter.layer_config_dict = {} + exporter._state_dict = {} + return exporter + + +def test_mamba_conv1d_remapping_exports_direct_params(): + """New MCore Mamba direct conv params should export to HF conv1d keys.""" + exporter = _make_minimal_exporter() + mixer = torch.nn.Module() + mixer.conv1d_weight = torch.nn.Parameter(torch.arange(12, dtype=torch.float32).reshape(3, 1, 4)) + mixer.conv1d_bias = torch.nn.Parameter(torch.arange(3, dtype=torch.float32)) + + exporter._mamba_conv1d_remapping(mixer, "backbone.layers.0.mixer.conv1d.") + + torch.testing.assert_close( + exporter._state_dict["backbone.layers.0.mixer.conv1d.weight"], + mixer.conv1d_weight.to(torch.bfloat16).cpu(), + ) + torch.testing.assert_close( + exporter._state_dict["backbone.layers.0.mixer.conv1d.bias"], + 
mixer.conv1d_bias.to(torch.bfloat16).cpu(), + ) + assert exporter.exclude_modules == ["backbone.layers.0.mixer.conv1d"] + + +def test_mamba_conv1d_remapping_preserves_old_conv_module_export(): + """Old MCore Mamba nn.Conv1d modules should still use normal name remapping.""" + exporter = _make_minimal_exporter() + mixer = torch.nn.Module() + mixer.conv1d = torch.nn.Conv1d(3, 3, 4, groups=3) + + exporter._mamba_conv1d_remapping(mixer, "backbone.layers.0.mixer.conv1d.") + + torch.testing.assert_close( + exporter._state_dict["backbone.layers.0.mixer.conv1d.weight"], + mixer.conv1d.weight.to(torch.bfloat16).cpu(), + ) + torch.testing.assert_close( + exporter._state_dict["backbone.layers.0.mixer.conv1d.bias"], + mixer.conv1d.bias.to(torch.bfloat16).cpu(), + ) + + +def test_mamba_layer_state_dict_routes_conv1d_mapping_to_mixer(): + """Mamba layer export should route the full mixer into the conv1d rule.""" + exporter = _make_minimal_exporter() + exporter.rules = exporter._populate_rule_book()["NemotronHForCausalLM"] + + layer = SimpleNamespace( + norm=torch.nn.LayerNorm(3), + mixer=SimpleNamespace( + norm=torch.nn.LayerNorm(3), + A_log=torch.ones(2), + D=torch.ones(2), + dt_bias=torch.ones(2), + conv1d_weight=torch.nn.Parameter( + torch.arange(12, dtype=torch.float32).reshape(3, 1, 4) + ), + conv1d_bias=torch.nn.Parameter(torch.arange(3, dtype=torch.float32)), + in_proj=torch.nn.Linear(3, 3), + out_proj=torch.nn.Linear(3, 3), + ), + ) + + exporter._get_mamba_layer_state_dict(layer, layer_id=0) + + assert "backbone.layers.0.mixer.conv1d.weight" in exporter._state_dict + assert "backbone.layers.0.mixer.conv1d.bias" in exporter._state_dict + + def _make_exporter_for_mtp(model_dir: Path) -> GPTModelExporter: """Create a minimal GPTModelExporter instance for testing _get_mtp_state_dict.""" exporter = object.__new__(GPTModelExporter) diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py index 
b32ba692771..7883bc5506c 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py @@ -53,8 +53,10 @@ import modelopt import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.model_calib import max_calibrate from modelopt.torch.quantization.nn import QuantModuleRegistry from modelopt.torch.quantization.plugins.megatron import _QuantTEMCoreRowParallelLinear +from modelopt.torch.quantization.utils import export_torch_mode, weight_attr_names try: from megatron.core.extensions.transformer_engine import TERowParallelLinear @@ -496,6 +498,95 @@ def test_homogeneous_sharded_state_dict_hybrid(dist_workers, tmp_path, config): ) +def test_mamba_direct_conv1d_weight_quantizer_calibrates(distributed_setup_size_1): + """New MCore direct Mamba conv1d params should get a calibratable weight quantizer.""" + if not HAS_MAMBA: + pytest.skip("Mamba not installed") + + initialize_for_megatron(tensor_model_parallel_size=1, seed=SEED) + model = _gpt_model_provider( + tp_size=1, + hidden_size=256, + vocab_size=256, + is_hybrid=True, + hybrid_override_pattern="M", + mamba_head_dim=16, + ) + + mtq.replace_quant_module(model) + mixers = [module for module in model.modules() if hasattr(module, "conv1d_weight")] + assert mixers + mixer = mixers[0] + + assert hasattr(mixer, "conv1d_weight_weight_quantizer") + assert list(weight_attr_names(mixer)) == ["conv1d_weight"] + weights_for_calibration = list(mixer.iter_weights_for_calibration()) + assert len(weights_for_calibration) == 1 + weight, weight_quantizer = weights_for_calibration[0] + assert weight is mixer.conv1d_weight + assert weight_quantizer is mixer.conv1d_weight_weight_quantizer + + weight_quantizer.set_from_attribute_config( + { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + } + ) + mtq.disable_quantizer(model, "*") + 
mixer.conv1d_weight_weight_quantizer.enable() + max_calibrate(model, lambda model: None, distributed_sync=False) + + assert mixer.conv1d_weight_weight_quantizer.amax is not None + assert mixer.conv1d_weight_weight_quantizer.amax.numel() == 1 + sharded_state_dict = mixer.sharded_state_dict(prefix="layers.0.mixer.") + assert "layers.0.mixer._extra_state" in sharded_state_dict + extra_state = sharded_state_dict["layers.0.mixer._extra_state"].data + quantizer_state = extra_state["modelopt_quantizer_state"] + assert "conv1d_weight_weight_quantizer" in quantizer_state + assert ( + "_amax" + in quantizer_state["conv1d_weight_weight_quantizer"]["_pytorch_state_metadata"]["buffers"] + ) + conv1d_amax_key = "layers.0.mixer.conv1d_weight_weight_quantizer._amax" + assert conv1d_amax_key in sharded_state_dict + assert torch.equal( + sharded_state_dict[conv1d_amax_key].data, + mixer.conv1d_weight_weight_quantizer.amax, + ) + calibrated_amax = mixer.conv1d_weight_weight_quantizer.amax.clone() + delattr(mixer.conv1d_weight_weight_quantizer, "_amax") + sharded_state_dict = mixer.sharded_state_dict(prefix="layers.0.mixer.") + assert torch.equal(mixer.conv1d_weight_weight_quantizer.amax, calibrated_amax) + assert conv1d_amax_key in sharded_state_dict + assert torch.equal(sharded_state_dict[conv1d_amax_key].data, calibrated_amax) + + quantizer_calls = [] + original_forward = mixer.conv1d_weight_weight_quantizer.forward + + def counted_forward(inputs): + quantizer_calls.append(inputs) + return original_forward(inputs) + + mixer.conv1d_weight_weight_quantizer.forward = counted_forward + _ = mixer.conv1d_weight + assert len(quantizer_calls) == 0 + + with mixer.quantize_conv1d_weight(): + _ = mixer.conv1d_weight + assert len(quantizer_calls) == 1 + + with export_torch_mode(): + _ = mixer.conv1d_weight + assert len(quantizer_calls) == 2 + + mixer.fold_weight() + assert mixer.conv1d_weight_weight_quantizer.amax is not None + assert torch.equal(mixer.conv1d_weight_weight_quantizer.amax, 
calibrated_amax) + sharded_state_dict = mixer.sharded_state_dict(prefix="layers.0.mixer.") + assert conv1d_amax_key in sharded_state_dict + assert torch.equal(sharded_state_dict[conv1d_amax_key].data, calibrated_amax) + + @pytest.mark.parametrize( "config", [