diff --git a/modelopt/torch/quantization/calib/max.py b/modelopt/torch/quantization/calib/max.py
index 94cee406e..4373fa69d 100644
--- a/modelopt/torch/quantization/calib/max.py
+++ b/modelopt/torch/quantization/calib/max.py
@@ -66,15 +66,15 @@ def collect(self, x):
         if x.device.type == "meta":
             self._calib_amax = local_amax
             return
+        assert not torch.any(torch.isnan(local_amax)), (
+            f"detected nan values in amax. nan in original tensor: {torch.any(torch.isnan(x))}"
+        )
         assert torch.all(local_amax >= 0), (
             "detected negative values after abs, could be torch or cuda bug"
         )
         assert not torch.any(torch.isinf(local_amax)), (
             f"detected inf values in amax. inf in original tensor: {torch.any(torch.isinf(x))}"
         )
-        assert not torch.any(torch.isnan(local_amax)), (
-            f"detected nan values in amax. nan in original tensor: {torch.any(torch.isnan(x))}"
-        )
         if self._calib_amax is None:
             self._calib_amax = local_amax
         else:
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index e1b48ee60..291acba03 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -156,12 +156,21 @@
     "*mlp.gate.*": {"enable": False},  # Skip the MOE router
     "*mlp.shared_expert_gate.*": {"enable": False},  # Skip the MOE router
     "*linear_attn.conv1d*": {"enable": False},
-    "*mixer.conv1d*": {"enable": False},
+    "*mixer.conv1d*": {"enable": False},  # Skip mamba conv1d
     "*output_layer*": {"enable": False},
     "output.*": {"enable": False},
     "default": {"enable": False},
 }
 
+_mamba_moe_disabled_quantizer_cfg = {
+    "*fc1_latent_proj*": {"enable": False},  # Skip Latent MOE
+    "*fc2_latent_proj*": {"enable": False},  # Skip Latent MOE
+    "*q_proj*": {"enable": False},  # Skip QKV Linear
+    "*k_proj*": {"enable": False},  # Skip QKV Linear
+    "*v_proj*": {"enable": False},  # Skip QKV Linear
+    "*o_proj*": {"enable": False},  # Skip QKV Output Projection
+}
+
 INT8_DEFAULT_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {"num_bits": 8, "axis": 0},
@@ -198,6 +207,28 @@
     "algorithm": "max",
 }
 
+MAMBA_MOE_FP8_AGGRESSIVE_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
+        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
+        **_default_disabled_quantizer_cfg,
+        **_mamba_moe_disabled_quantizer_cfg,
+    },
+    "algorithm": "max",
+}
+
+MAMBA_MOE_FP8_CONSERVATIVE_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
+        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
+        **_default_disabled_quantizer_cfg,
+        **_mamba_moe_disabled_quantizer_cfg,
+        "*mixer.in_proj*": {"enable": False},  # Skip mamba linear
+        "*mixer.out_proj*": {"enable": False},  # Skip mamba linear
+    },
+    "algorithm": "max",
+}
+
 FP8_PER_CHANNEL_PER_TOKEN_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {"num_bits": (4, 3), "axis": 0},
@@ -388,6 +419,49 @@
     "algorithm": "max",
 }
 
+
+MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        **_default_disabled_quantizer_cfg,
+        **_mamba_moe_disabled_quantizer_cfg,
+    },
+    "algorithm": "max",
+}
+MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        **_default_disabled_quantizer_cfg,
+        **_mamba_moe_disabled_quantizer_cfg,
+        "*mixer.in_proj*": {"enable": False},  # Skip mamba linear
+        "*mixer.out_proj*": {"enable": False},  # Skip mamba linear
+    },
+    "algorithm": "max",
+}
+
+
 NVFP4_AWQ_LITE_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {
@@ -652,6 +726,10 @@
     "NVFP4_MLP_WEIGHT_ONLY_CFG",
     "MXFP4_MLP_WEIGHT_ONLY_CFG",
     "NVFP4_MLP_ONLY_CFG",
+    "MAMBA_MOE_NVFP4_CONSERVATIVE_CFG",
+    "MAMBA_MOE_NVFP4_AGGRESSIVE_CFG",
+    "MAMBA_MOE_FP8_CONSERVATIVE_CFG",
+    "MAMBA_MOE_FP8_AGGRESSIVE_CFG",
 }
 
 BiasType = Literal["static", "dynamic"]