From fc15b2fd1e0fbcd875c959d0c30999218aebe78c Mon Sep 17 00:00:00 2001 From: Meng Xin Date: Sat, 13 Jun 2026 12:13:32 -0700 Subject: [PATCH 1/2] Support Mamba direct conv1d quant export Megatron now stores Mamba conv1d as direct conv1d_weight and conv1d_bias parameters instead of an nn.Conv1d child. ModelOpt export still needs to emit HF conv1d.weight and conv1d.bias keys, and quantization needs the direct weight to participate in the normal weight-only calibration path. The quantizer follows the existing per-parameter naming contract, so conv1d_weight pairs with conv1d_weight_weight_quantizer. The export adapter exposes that internal pair as a standard weight/weight_quantizer view for the existing name remapping and quantized export logic. Constraint: New Megatron Mamba no longer exposes layer.mixer.conv1d. Rejected: Add a Mamba-specific calibration iterator | standard per-weight quantizer naming keeps generic discovery working. Rejected: Change shared weight discovery helpers | unnecessary once the quantizer follows existing naming convention. Confidence: high Scope-risk: moderate Tested: python -m compileall on touched files Tested: git diff --check Tested: uvx ruff@0.12.11 check on touched files Tested: uvx ruff@0.12.11 format --check on touched files Tested: Slurm smoke 215464 passed in container, including export remap, generic weight_attr_names discovery, max_calibrate amax, and quant/export context invocation Not-tested: Full end-to-end HF artifact export/load for a real Nano3 checkpoint Signed-off-by: Meng Xin --- modelopt/torch/export/plugins/mcore_custom.py | 12 +++ .../torch/export/plugins/mcore_nemotron.py | 3 +- .../torch/export/unified_export_megatron.py | 32 +++++++- .../torch/quantization/plugins/megatron.py | 74 +++++++++++++++++- .../export/test_unified_export_megatron.py | 75 +++++++++++++++++++ .../quantization/plugins/test_megatron.py | 61 +++++++++++++++ 6 files changed, 254 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 0b6ce7a35df..6b7841034f6 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -91,6 +91,18 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] ) +class MambaConv1dMapping(CustomModuleMapping): + """A custom module mapping for Mamba conv1d weights.""" + + def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}): + """Create a mapping for old Conv1d modules and new direct conv1d parameters.""" + super().__init__( + func_name="mamba_conv1d_remapping", + target_name_or_prefix=target_name_or_prefix, + func_kwargs=func_kwargs, + ) + + class QKVMerging(CustomModuleMapping): """A custom module mapping that merges Q, K, V.""" diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 24bd8144055..5df0af871c9 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -25,6 +25,7 @@ CustomModuleMapping, GroupedMLPMerging, GroupedMLPSlicing, + MambaConv1dMapping, NameRemapping, QKVMerging, QKVSlicing, @@ -123,7 +124,7 @@ "A_log": NameRemapping("backbone.layers.{}.mixer.A_log"), "D": NameRemapping("backbone.layers.{}.mixer.D"), "dt_bias": NameRemapping("backbone.layers.{}.mixer.dt_bias"), - "conv1d": NameRemapping("backbone.layers.{}.mixer.conv1d."), + "conv1d": MambaConv1dMapping("backbone.layers.{}.mixer.conv1d."), "in_proj": NameRemapping("backbone.layers.{}.mixer.in_proj."), "out_proj": NameRemapping("backbone.layers.{}.mixer.out_proj."), "fused_norm": NameRemapping("backbone.layers.{}.norm.weight"), diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index cabb0d77580..79b9ac6c50f 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -102,6 +102,21 @@ ] +class _MambaConv1dParamView(torch.nn.Module): + """Module view exposing direct Mamba conv parameters with standard weight names.""" + + def __init__(self, mixer: torch.nn.Module): + super().__init__() + self.weight = mixer.conv1d_weight + bias = getattr(mixer, "conv1d_bias", None) + if bias is not None: + self.bias = bias + + weight_quantizer = getattr(mixer, "conv1d_weight_weight_quantizer", None) + if weight_quantizer is not None: + self.weight_quantizer = weight_quantizer + + class GPTModelExporter: """Megatron Core GPTModel Exporter. @@ -664,7 +679,7 @@ def _get_mamba_layer_state_dict(self, layer, layer_id): self.rules["D"](layer.mixer.D, layer_id) self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) - self.rules["conv1d"](layer.mixer.conv1d, layer_id) + self.rules["conv1d"](layer.mixer, layer_id) self.rules["in_proj"](layer.mixer.in_proj, layer_id) self.rules["out_proj"](layer.mixer.out_proj, layer_id) @@ -787,6 +802,7 @@ def _custom_mapping_to_lambda(mapping): "grouped_mlp_slicing": self._grouped_mlp_slicing, "pack_name_remapping": self._pack_name_remapping, "pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss, + "mamba_conv1d_remapping": self._mamba_conv1d_remapping, } func = method_map[mapping.func_name] prefix = mapping.target_name_or_prefix @@ -927,6 +943,20 @@ def _record_excluded_module(self, prefix: str): if layer_name not in self.exclude_modules: self.exclude_modules.append(layer_name) + def _mamba_conv1d_remapping(self, module: torch.nn.Module, prefix: str, **kwargs): + """Export Mamba conv1d from either old Conv1d modules or new direct parameters.""" + conv1d = getattr(module, "conv1d", None) + if conv1d is not None: + self._name_remapping(conv1d, prefix, **kwargs) + return + + if not hasattr(module, "conv1d_weight"): + raise AttributeError( + f"{type(module).__name__} has neither conv1d nor conv1d_weight for export" + ) + + self._name_remapping(_MambaConv1dParamView(module), prefix, **kwargs) + def _name_remapping( self, module: torch.nn.Module | torch.Tensor, diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index e18e3cd064b..bdb24700804 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -15,6 +15,7 @@ """Support quantization for megatron linear layers.""" +import contextlib import types from typing import Any @@ -39,13 +40,19 @@ from modelopt.torch.utils import warn_rank_0 from modelopt.torch.utils.distributed import ParallelState +from .. import tensor_quant from ..conversion import maybe_promote_nvfp4_static_quantizer from ..nn import QuantModule, QuantModuleRegistry, SequentialQuantizer, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper -from ..utils import sync_moe_expert_amax +from ..utils import is_torch_export_mode, sync_moe_expert_amax from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear +try: + from megatron.core.ssm.mamba_mixer import MambaMixer +except ImportError: + MambaMixer = None + try: from megatron.core.extensions.transformer_engine import ( TEColumnParallelGroupedLinear, @@ -232,6 +239,71 @@ def quant_module_set_extra_state(self, state: Any): self.allow_post_restore = False +if MambaMixer is not None: + + @QuantModuleRegistry.register({MambaMixer: "megatron_MambaMixer"}) + class _QuantMambaMixer(QuantModule): + """Quantize new Megatron Mamba direct conv1d parameters.""" + + default_quant_desc_weight = tensor_quant.QUANT_DESC_8BIT_CONV1D_WEIGHT_PER_CHANNEL + + @contextlib.contextmanager + def quantize_conv1d_weight(self): + """Context in which ``self.conv1d_weight`` is quantized.""" + if not hasattr(self, "conv1d_weight_weight_quantizer"): + yield + return + + self._enable_conv1d_weight_quantization = True + try: + yield + finally: + self._enable_conv1d_weight_quantization = False + + @staticmethod + def _get_quantized_conv1d_weight( + module: "_QuantMambaMixer", weight: torch.Tensor + ) -> torch.Tensor: + if module._enable_conv1d_weight_quantization or is_torch_export_mode(): + return module.conv1d_weight_weight_quantizer(weight) + return weight + + def forward(self, *args, **kwargs): + """Quantize the direct conv1d weight before calling MambaMixer.forward.""" + if is_torch_export_mode() or not hasattr(self, "conv1d_weight_weight_quantizer"): + return super().forward(*args, **kwargs) + + with self.quantize_conv1d_weight(): + return super().forward(*args, **kwargs) + + def _setup(self): + if not hasattr(self, "conv1d_weight"): + return + + if not hasattr(self, "parallel_state") or self.parallel_state is None: + data_parallel_group = None + try: + data_parallel_group = get_data_parallel_group(with_context_parallel=True) + except AssertionError: + warn_rank_0( + "Context parallel group is not initialized, using data parallel group" + ) + data_parallel_group = get_data_parallel_group() + self.parallel_state = ParallelState( + data_parallel_group, + mcore_parallel.get_tensor_model_parallel_group(), + ) + + self._register_temp_attribute( + # Generic weight discovery maps parameter ```` to + # ``_weight_quantizer``. + "conv1d_weight_weight_quantizer", + TensorQuantizer(self.default_quant_desc_weight), + ) + self._register_temp_attribute("_enable_conv1d_weight_quantization", False) + self._register_dynamic_attribute("conv1d_weight", self._get_quantized_conv1d_weight) + + def _create_incompatible_method(method_name: str): """Create a method that raises an error for incompatible flash decode methods.""" diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index f818cb3594c..612425e5afd 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -17,6 +17,7 @@ from copy import deepcopy from functools import partial from pathlib import Path +from types import SimpleNamespace import pytest import torch @@ -390,6 +391,80 @@ def test_qkv_slicing_records_hf_excludes_for_unquantized_fused_qkv(): assert "backbone.layers.0.mixer.v_proj" in exporter.exclude_modules +def _make_minimal_exporter() -> GPTModelExporter: + exporter = object.__new__(GPTModelExporter) + exporter.dtype = torch.bfloat16 + exporter.exclude_modules = [] + exporter.layer_config_dict = {} + exporter._state_dict = {} + return exporter + + +def test_mamba_conv1d_remapping_exports_direct_params(): + """New MCore Mamba direct conv params should export to HF conv1d keys.""" + exporter = _make_minimal_exporter() + mixer = torch.nn.Module() + mixer.conv1d_weight = torch.nn.Parameter(torch.arange(12, dtype=torch.float32).reshape(3, 1, 4)) + mixer.conv1d_bias = torch.nn.Parameter(torch.arange(3, dtype=torch.float32)) + + exporter._mamba_conv1d_remapping(mixer, "backbone.layers.0.mixer.conv1d.") + + torch.testing.assert_close( + exporter._state_dict["backbone.layers.0.mixer.conv1d.weight"], + mixer.conv1d_weight.to(torch.bfloat16).cpu(), + ) + torch.testing.assert_close( + exporter._state_dict["backbone.layers.0.mixer.conv1d.bias"], + mixer.conv1d_bias.to(torch.bfloat16).cpu(), + ) + assert exporter.exclude_modules == ["backbone.layers.0.mixer.conv1d"] + + +def test_mamba_conv1d_remapping_preserves_old_conv_module_export(): + """Old MCore Mamba nn.Conv1d modules should still use normal name remapping.""" + exporter = _make_minimal_exporter() + mixer = torch.nn.Module() + mixer.conv1d = torch.nn.Conv1d(3, 3, 4, groups=3) + + exporter._mamba_conv1d_remapping(mixer, "backbone.layers.0.mixer.conv1d.") + + torch.testing.assert_close( + exporter._state_dict["backbone.layers.0.mixer.conv1d.weight"], + mixer.conv1d.weight.to(torch.bfloat16).cpu(), + ) + torch.testing.assert_close( + exporter._state_dict["backbone.layers.0.mixer.conv1d.bias"], + mixer.conv1d.bias.to(torch.bfloat16).cpu(), + ) + + +def test_mamba_layer_state_dict_routes_conv1d_mapping_to_mixer(): + """Mamba layer export should route the full mixer into the conv1d rule.""" + exporter = _make_minimal_exporter() + exporter.rules = exporter._populate_rule_book()["NemotronHForCausalLM"] + + layer = SimpleNamespace( + norm=torch.nn.LayerNorm(3), + mixer=SimpleNamespace( + norm=torch.nn.LayerNorm(3), + A_log=torch.ones(2), + D=torch.ones(2), + dt_bias=torch.ones(2), + conv1d_weight=torch.nn.Parameter( + torch.arange(12, dtype=torch.float32).reshape(3, 1, 4) + ), + conv1d_bias=torch.nn.Parameter(torch.arange(3, dtype=torch.float32)), + in_proj=torch.nn.Linear(3, 3), + out_proj=torch.nn.Linear(3, 3), + ), + ) + + exporter._get_mamba_layer_state_dict(layer, layer_id=0) + + assert "backbone.layers.0.mixer.conv1d.weight" in exporter._state_dict + assert "backbone.layers.0.mixer.conv1d.bias" in exporter._state_dict + + def _make_exporter_for_mtp(model_dir: Path) -> GPTModelExporter: """Create a minimal GPTModelExporter instance for testing _get_mtp_state_dict.""" exporter = object.__new__(GPTModelExporter) diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py index b32ba692771..d23d3508833 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py @@ -53,8 +53,10 @@ import modelopt import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.model_calib import max_calibrate from modelopt.torch.quantization.nn import QuantModuleRegistry from modelopt.torch.quantization.plugins.megatron import _QuantTEMCoreRowParallelLinear +from modelopt.torch.quantization.utils import export_torch_mode, weight_attr_names try: from megatron.core.extensions.transformer_engine import TERowParallelLinear @@ -496,6 +498,65 @@ def test_homogeneous_sharded_state_dict_hybrid(dist_workers, tmp_path, config): ) +def test_mamba_direct_conv1d_weight_quantizer_calibrates(distributed_setup_size_1): + """New MCore direct Mamba conv1d params should get a calibratable weight quantizer.""" + if not HAS_MAMBA: + pytest.skip("Mamba not installed") + + initialize_for_megatron(tensor_model_parallel_size=1, seed=SEED) + model = _gpt_model_provider( + tp_size=1, + hidden_size=256, + vocab_size=256, + is_hybrid=True, + hybrid_override_pattern="M", + mamba_head_dim=16, + ) + + mtq.replace_quant_module(model) + mixers = [module for module in model.modules() if hasattr(module, "conv1d_weight")] + assert mixers + mixer = mixers[0] + + assert hasattr(mixer, "conv1d_weight_weight_quantizer") + assert list(weight_attr_names(mixer)) == ["conv1d_weight"] + weights_for_calibration = list(mixer.iter_weights_for_calibration()) + assert len(weights_for_calibration) == 1 + weight, weight_quantizer = weights_for_calibration[0] + assert weight is mixer.conv1d_weight + assert weight_quantizer is mixer.conv1d_weight_weight_quantizer + + mtq.disable_quantizer(model, "*") + mixer.conv1d_weight_weight_quantizer.enable() + max_calibrate(model, lambda model: None, distributed_sync=False) + + assert mixer.conv1d_weight_weight_quantizer.amax is not None + assert mixer.conv1d_weight_weight_quantizer.amax.shape == ( + mixer.conv1d_weight.shape[0], + 1, + 1, + ) + + quantizer_calls = [] + original_forward = mixer.conv1d_weight_weight_quantizer.forward + + def counted_forward(inputs): + quantizer_calls.append(inputs) + return original_forward(inputs) + + mixer.conv1d_weight_weight_quantizer.forward = counted_forward + _ = mixer.conv1d_weight + assert len(quantizer_calls) == 0 + + with mixer.quantize_conv1d_weight(): + _ = mixer.conv1d_weight + assert len(quantizer_calls) == 1 + + with export_torch_mode(): + _ = mixer.conv1d_weight + assert len(quantizer_calls) == 2 + + @pytest.mark.parametrize( "config", [ From d9b871864902599c7c27f1e5aaadcaa9f05707b3 Mon Sep 17 00:00:00 2001 From: Meng Xin Date: Sun, 14 Jun 2026 03:16:15 -0700 Subject: [PATCH 2/2] Preserve Mamba Conv1d quantizer state for export New Megatron stores Mamba conv1d as direct parameters. ModelOpt adds the direct Conv1d weight quantizer as a dynamic-module attribute, so MCore's native Mamba sharded-state path did not reliably preserve the scalar amax that NVFP4 export needs after checkpoint reload. The Mamba quant module now saves the direct Conv1d quantizer amax in distributed checkpoints, recomputing the weight max from conv1d_weight when reload has dropped calibrator statistics. The Megatron exporter also keeps direct Mamba Conv1d in dtype when its kernel dimension cannot be packed by the configured block quantization format. Constraint: Nano3 uses dynamic block-scale NVFP4; export still requires a calibrated scalar amax. Rejected: Let export tolerate missing amax | missing amax means checkpoint save/restore dropped quantizer state. Rejected: Pack direct Conv1d NVFP4 with kernel size 4 and block size 16 | packed block format requires divisibility. Confidence: medium Scope-risk: moderate Tested: py_compile on touched files; ruff check/format on touched files; file-scoped pre-commit hooks; Nano3 PTQ/train/export smoke job 215752. Not-tested: full ModelOpt GPU test suite. Signed-off-by: Meng Xin --- .../torch/export/unified_export_megatron.py | 23 ++++- .../torch/quantization/backends/nvfp4_gemm.py | 5 +- .../torch/quantization/plugins/megatron.py | 89 +++++++++++++++++++ .../quantization/plugins/test_megatron.py | 38 +++++++- 4 files changed, 147 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 79b9ac6c50f..6c11ce94b71 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -21,6 +21,7 @@ import json import os import tempfile +import warnings from collections import OrderedDict from pathlib import Path from typing import Any @@ -105,7 +106,7 @@ class _MambaConv1dParamView(torch.nn.Module): """Module view exposing direct Mamba conv parameters with standard weight names.""" - def __init__(self, mixer: torch.nn.Module): + def __init__(self, mixer: torch.nn.Module, include_weight_quantizer: bool = True): super().__init__() self.weight = mixer.conv1d_weight bias = getattr(mixer, "conv1d_bias", None) @@ -113,7 +114,7 @@ def __init__(self, mixer: torch.nn.Module): self.bias = bias weight_quantizer = getattr(mixer, "conv1d_weight_weight_quantizer", None) - if weight_quantizer is not None: + if include_weight_quantizer and weight_quantizer is not None: self.weight_quantizer = weight_quantizer @@ -955,7 +956,23 @@ def _mamba_conv1d_remapping(self, module: torch.nn.Module, prefix: str, **kwargs f"{type(module).__name__} has neither conv1d nor conv1d_weight for export" ) - self._name_remapping(_MambaConv1dParamView(module), prefix, **kwargs) + conv1d_view = _MambaConv1dParamView(module) + qformat = self._get_quantization_format(conv1d_view) + block_size = get_weight_block_size(conv1d_view) + if ( + qformat not in (None, QUANTIZATION_NONE) + and block_size + and conv1d_view.weight.shape[-1] % block_size != 0 + ): + warnings.warn( + f"Exporting direct Mamba conv1d {prefix} in {self.dtype} because " + f"weight shape {tuple(conv1d_view.weight.shape)} is not divisible by " + f"block size {block_size} for {qformat} packing.", + stacklevel=2, + ) + conv1d_view = _MambaConv1dParamView(module, include_weight_quantizer=False) + + self._name_remapping(conv1d_view, prefix, **kwargs) def _name_remapping( self, diff --git a/modelopt/torch/quantization/backends/nvfp4_gemm.py b/modelopt/torch/quantization/backends/nvfp4_gemm.py index fdf6babb695..eef505e96f0 100644 --- a/modelopt/torch/quantization/backends/nvfp4_gemm.py +++ b/modelopt/torch/quantization/backends/nvfp4_gemm.py @@ -21,7 +21,6 @@ import modelopt.torch.quantization as mtq from modelopt.torch.quantization.backends.gemm_registry import gemm_registry from modelopt.torch.quantization.backends.utils import fp4_compatible -from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear from modelopt.torch.quantization.qtensor import NVFP4QTensor, QTensorWrapper from modelopt.torch.quantization.utils import reduce_amax @@ -193,6 +192,10 @@ def apply(cls, *args, **kwargs): def _nvfp4_availability_check(module, input, args, kwargs): """Comprehensive check for FP4 GEMM availability.""" + # Imported lazily to avoid an import cycle: + # quant_linear -> backends -> nvfp4_gemm -> quant_linear. + from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear + # NOTE: Having the import at the top causes mpirun commands inside pytest (vlm_ptq) to fail without any error try: import tensorrt_llm # noqa: F401 diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index bdb24700804..2f1778ca4ae 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -241,6 +241,62 @@ def quant_module_set_extra_state(self, state: Any): if MambaMixer is not None: + def _save_mamba_extra_state(self, prefix, sharded_offsets, sharded_state_dict): + with contextlib.suppress(RuntimeError): + extra_state = self.get_extra_state() + if extra_state is not None: + sharded_state_dict.update( + **make_sharded_tensors_for_checkpoint( + {"_extra_state": extra_state}, prefix, {}, sharded_offsets + ) + ) + return sharded_state_dict + + def _save_mamba_conv1d_quantizer_amax(self, prefix, sharded_offsets, sharded_state_dict): + quantizer = getattr(self, "conv1d_weight_weight_quantizer", None) + if quantizer is None: + return sharded_state_dict + + # Direct Mamba conv1d stores its quantizer as a ModelOpt temp attribute. + # MCore's native MambaMixer sharded state saves conv1d_weight/bias but + # does not know that this extra quantizer amax must round-trip for export. + quantizer_state_dict = { + k: v + for k, v in quantizer.state_dict( + prefix="conv1d_weight_weight_quantizer.", keep_vars=True + ).items() + if k.endswith(("_amax", "_global_amax")) + } + if "conv1d_weight_weight_quantizer._amax" not in quantizer_state_dict: + calibrator = getattr(quantizer, "_calibrator", None) + if calibrator is not None: + calib_amax = calibrator.compute_amax() + if calib_amax is not None: + quantizer.amax = calib_amax.detach() + quantizer_state_dict["conv1d_weight_weight_quantizer._amax"] = quantizer.amax + if ( + "conv1d_weight_weight_quantizer._amax" not in quantizer_state_dict + and quantizer.is_enabled + ): + # After checkpoint reload, the direct temp quantizer may no longer + # have calibrator statistics. For weight max calibration, the source + # is the conv1d weight itself, so recompute the same scalar before + # saving/exporting rather than dropping the required amax. + quantizer.amax = quantizer._get_amax(self.conv1d_weight.detach().float()).detach() + quantizer_state_dict["conv1d_weight_weight_quantizer._amax"] = quantizer.amax + + for key, value in quantizer_state_dict.items(): + if value.numel() != 1: + raise AssertionError( + f"Only scalar direct Mamba conv1d quantizer amax is supported, got " + f"{key} with shape {tuple(value.shape)}." + ) + + sharded_state_dict.update( + **make_sharded_tensors_for_checkpoint(quantizer_state_dict, prefix, {}, sharded_offsets) + ) + return sharded_state_dict + @QuantModuleRegistry.register({MambaMixer: "megatron_MambaMixer"}) class _QuantMambaMixer(QuantModule): """Quantize new Megatron Mamba direct conv1d parameters.""" @@ -276,6 +332,39 @@ def forward(self, *args, **kwargs): with self.quantize_conv1d_weight(): return super().forward(*args, **kwargs) + def iter_weights_for_calibration(self): + """Yield direct conv1d weights for weight-only max calibration.""" + seen_quantizers = set() + for weight, weight_quantizer in super().iter_weights_for_calibration(): + seen_quantizers.add(id(weight_quantizer)) + yield weight, weight_quantizer + + weight_quantizer = getattr(self, "conv1d_weight_weight_quantizer", None) + if weight_quantizer is not None and id(weight_quantizer) not in seen_quantizers: + yield self.conv1d_weight, weight_quantizer + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + sharded_state_dict = _save_mamba_conv1d_quantizer_amax( + self, prefix, sharded_offsets, sharded_state_dict + ) + return _save_mamba_extra_state(self, prefix, sharded_offsets, sharded_state_dict) + + def fold_weight(self, keep_attrs: bool = False): + # NVFP4 export still needs the calibrated scalar amax after the direct + # conv1d weight is folded. + super().fold_weight(keep_attrs=True) + weight_quantizer = getattr(self, "conv1d_weight_weight_quantizer", None) + if ( + weight_quantizer is not None + and weight_quantizer.fake_quant + and weight_quantizer.is_enabled + ): + self.conv1d_weight.data.copy_( + weight_quantizer(self.conv1d_weight.float()).to(self.conv1d_weight.dtype) + ) + weight_quantizer.disable() + def _setup(self): if not hasattr(self, "conv1d_weight"): return diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py index d23d3508833..7883bc5506c 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py @@ -526,16 +526,39 @@ def test_mamba_direct_conv1d_weight_quantizer_calibrates(distributed_setup_size_ assert weight is mixer.conv1d_weight assert weight_quantizer is mixer.conv1d_weight_weight_quantizer + weight_quantizer.set_from_attribute_config( + { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + } + ) mtq.disable_quantizer(model, "*") mixer.conv1d_weight_weight_quantizer.enable() max_calibrate(model, lambda model: None, distributed_sync=False) assert mixer.conv1d_weight_weight_quantizer.amax is not None - assert mixer.conv1d_weight_weight_quantizer.amax.shape == ( - mixer.conv1d_weight.shape[0], - 1, - 1, + assert mixer.conv1d_weight_weight_quantizer.amax.numel() == 1 + sharded_state_dict = mixer.sharded_state_dict(prefix="layers.0.mixer.") + assert "layers.0.mixer._extra_state" in sharded_state_dict + extra_state = sharded_state_dict["layers.0.mixer._extra_state"].data + quantizer_state = extra_state["modelopt_quantizer_state"] + assert "conv1d_weight_weight_quantizer" in quantizer_state + assert ( + "_amax" + in quantizer_state["conv1d_weight_weight_quantizer"]["_pytorch_state_metadata"]["buffers"] + ) + conv1d_amax_key = "layers.0.mixer.conv1d_weight_weight_quantizer._amax" + assert conv1d_amax_key in sharded_state_dict + assert torch.equal( + sharded_state_dict[conv1d_amax_key].data, + mixer.conv1d_weight_weight_quantizer.amax, ) + calibrated_amax = mixer.conv1d_weight_weight_quantizer.amax.clone() + delattr(mixer.conv1d_weight_weight_quantizer, "_amax") + sharded_state_dict = mixer.sharded_state_dict(prefix="layers.0.mixer.") + assert torch.equal(mixer.conv1d_weight_weight_quantizer.amax, calibrated_amax) + assert conv1d_amax_key in sharded_state_dict + assert torch.equal(sharded_state_dict[conv1d_amax_key].data, calibrated_amax) quantizer_calls = [] original_forward = mixer.conv1d_weight_weight_quantizer.forward @@ -556,6 +579,13 @@ def counted_forward(inputs): _ = mixer.conv1d_weight assert len(quantizer_calls) == 2 + mixer.fold_weight() + assert mixer.conv1d_weight_weight_quantizer.amax is not None + assert torch.equal(mixer.conv1d_weight_weight_quantizer.amax, calibrated_amax) + sharded_state_dict = mixer.sharded_state_dict(prefix="layers.0.mixer.") + assert conv1d_amax_key in sharded_state_dict + assert torch.equal(sharded_state_dict[conv1d_amax_key].data, calibrated_amax) + @pytest.mark.parametrize( "config",