From 9f49dcd9ef27ecaf14d6d558248b1af6f7d57dee Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 12 Jun 2026 07:59:00 -0700 Subject: [PATCH 1/4] Add NVFP4 Four-Over-Six weight quantization support Add the NVFP4_FOUR_OVER_SIX_CFG preset (scoped to max calibration) and implement 4/6 scale selection in NVFP4QTensor, normalizing the selected per-block scale with F8_E4M3_MAX_46. Wire the supporting changes through the fp4 kernels, nvfp4_gemm backend, tensor_quant, tensor_quantizer, config, and layer_utils. Add recipe/preset YAMLs and the Megatron PTQ launcher example, plus unit tests covering 4/6 quantization and config registration. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Jennifer Chen --- modelopt/torch/export/layer_utils.py | 3 +- .../kernels/quantization/gemm/fp4_kernel.py | 17 +- .../quantization/gemm/fp4_kernel_hopper.py | 5 +- .../torch/quantization/backends/nvfp4_gemm.py | 13 +- modelopt/torch/quantization/config.py | 6 +- .../nn/modules/tensor_quantizer.py | 13 +- .../quantization/qtensor/nvfp4_tensor.py | 182 ++++++++++++++++-- modelopt/torch/quantization/tensor_quant.py | 4 +- .../configs/numerics/nvfp4_four_over_six.yaml | 24 +++ .../presets/model/nvfp4_four_over_six.yaml | 28 +++ .../units/w4a4_nvfp4_nvfp4_four_over_six.yaml | 29 +++ .../ptq/nvfp4-46-max.yaml | 133 +++++++++++++ .../quantization/test_nvfp4_four_over_six.py | 174 +++++++++++++++++ .../megatron_lm_ptq.yaml | 17 +- 14 files changed, 615 insertions(+), 33 deletions(-) create mode 100644 modelopt_recipes/configs/numerics/nvfp4_four_over_six.yaml create mode 100644 modelopt_recipes/configs/ptq/presets/model/nvfp4_four_over_six.yaml create mode 100644 modelopt_recipes/configs/ptq/units/w4a4_nvfp4_nvfp4_four_over_six.yaml create mode 100644 modelopt_recipes/huggingface/models/nvidia/Nemotron-3-Ultra-550B-A55B/ptq/nvfp4-46-max.yaml create mode 100644 tests/unit/torch/quantization/test_nvfp4_four_over_six.py diff --git a/modelopt/torch/export/layer_utils.py 
b/modelopt/torch/export/layer_utils.py index e0c78f42def..5258c243d20 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -1198,7 +1198,8 @@ def sync_moe_gate_up_amax(model: nn.Module) -> int: """Take element-wise max of gate and up weight quantizer amaxes per expert. Serving engines fuse gate_proj and up_proj into a single gate_up_proj and - require a single weight_scale_2. Since weight_scale_2 = amax / (6 * 448), + require a single weight_scale_2. Since weight_scale_2 = amax / (6 * m_fp8) + (m_fp8=448 normally, 256 for NVFP4 4/6 mode), syncing amaxes before quantization ensures the per-block weight_scale values are computed against a consistent global scale. diff --git a/modelopt/torch/kernels/quantization/gemm/fp4_kernel.py b/modelopt/torch/kernels/quantization/gemm/fp4_kernel.py index 0e6874ab575..d43ad8f36cc 100644 --- a/modelopt/torch/kernels/quantization/gemm/fp4_kernel.py +++ b/modelopt/torch/kernels/quantization/gemm/fp4_kernel.py @@ -37,6 +37,8 @@ torch.bfloat16: tl.bfloat16, } +FP8_E4M3_MAX = 448.0 + def _torch_dtype_to_tl(dtype: torch.dtype): if dtype not in _TORCH_TO_TL_DTYPE: @@ -211,6 +213,7 @@ def compute_fp4_scales( amax: torch.Tensor, global_amax: torch.Tensor | None = None, quantize_block_scales: bool = True, + fp8_max_for_normalization: float = FP8_E4M3_MAX, ) -> torch.Tensor: """Compute per-block FP4 scales from amax values. @@ -220,6 +223,8 @@ def compute_fp4_scales( amax: Per-block amax values (any shape). global_amax: Global amax for FP8 two-level scaling. Computed from *amax* if None. quantize_block_scales: If True, quantize scales to FP8 E4M3. + fp8_max_for_normalization: FP8 max value used to normalize per-block scales + before FP8 quantization (default 448.0; use 256.0 for NVFP4 4/6 mode). Returns: Per-block scales (same shape as *amax*), float32. 
@@ -235,7 +240,7 @@ def compute_fp4_scales( global_amax = reduce_amax(amax, axis=None, keepdims=False, squeeze_scalar=True) global_amax = global_amax.float() - scale_fp8_quant_amax = global_amax / 6.0 + scale_fp8_quant_amax = global_amax * (FP8_E4M3_MAX / fp8_max_for_normalization) / 6.0 scale = scaled_e4m3_impl(scale, scale_fp8_quant_amax) return scale @@ -246,6 +251,7 @@ def static_blockwise_fp4_fake_quant( amax: torch.Tensor, global_amax: torch.Tensor | None = None, quantize_block_scales: bool = True, + fp8_max_for_normalization: float = FP8_E4M3_MAX, out_dtype: torch.dtype | None = None, ): """Static blockwise FP4 fake quantization using Triton kernel. @@ -261,6 +267,8 @@ def static_blockwise_fp4_fake_quant( consumes it as a flat 1-D buffer of length ``NUM_FP4_BLOCKS``. global_amax: FP32 scalar global amax. If provided, used to compute scale_fp8_quant_amax. quantize_block_scales: If True, quantize block scales to FP8. + fp8_max_for_normalization: FP8 max value used to normalize per-block scales + before FP8 quantization (default 448.0; use 256.0 for NVFP4 4/6 mode). out_dtype: Output dtype. Defaults to x.dtype if None. 
""" original_shape = x.shape @@ -275,7 +283,12 @@ def static_blockwise_fp4_fake_quant( if out_dtype is None: out_dtype = x.dtype - scale = compute_fp4_scales(amax, global_amax, quantize_block_scales) + scale = compute_fp4_scales( + amax, + global_amax, + quantize_block_scales, + fp8_max_for_normalization=fp8_max_for_normalization, + ) x_flat = x.contiguous().view(-1) y_flat = torch.empty_like(x_flat, dtype=out_dtype) diff --git a/modelopt/torch/kernels/quantization/gemm/fp4_kernel_hopper.py b/modelopt/torch/kernels/quantization/gemm/fp4_kernel_hopper.py index 624e723b957..9b9adaeb33c 100644 --- a/modelopt/torch/kernels/quantization/gemm/fp4_kernel_hopper.py +++ b/modelopt/torch/kernels/quantization/gemm/fp4_kernel_hopper.py @@ -103,6 +103,7 @@ def fp4_fake_quant_block( x: torch.Tensor, global_amax: torch.Tensor, block_size: int = 16, + fp8_max_for_normalization: float = 448.0, tile_rows: int = 16, tile_cols: int = 64, num_warps: int | None = None, @@ -114,6 +115,8 @@ def fp4_fake_quant_block( x (torch.Tensor): Input tensor of shape ``(M, N)`` or higher. global_amax (torch.Tensor): Global maximum value tensor for scaling. block_size (int): Number of elements per FP4 block. + fp8_max_for_normalization (float): FP8 max value used to normalize per-block + scales before FP8 quantization (default 448.0; use 256.0 for NVFP4 4/6 mode). tile_rows (int, optional): Row tile size. Defaults to 16. tile_cols (int, optional): Column tile size. Defaults to 64. Rounded up to the nearest multiple of ``block_size`` internally. 
@@ -137,7 +140,7 @@ def fp4_fake_quant_block( tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size num_fp4_blocks = tile_cols_aligned // block_size - global_scale = (global_amax.float() / (6.0 * 448.0)).to(x.device) + global_scale = (global_amax.float() / (6.0 * fp8_max_for_normalization)).to(x.device) grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned)) diff --git a/modelopt/torch/quantization/backends/nvfp4_gemm.py b/modelopt/torch/quantization/backends/nvfp4_gemm.py index fdf6babb695..3ace55b1e75 100644 --- a/modelopt/torch/quantization/backends/nvfp4_gemm.py +++ b/modelopt/torch/quantization/backends/nvfp4_gemm.py @@ -79,6 +79,8 @@ def _fp4_linear( quant_module._input_global_scale = 448.0 * 6.0 / input_amax.float() weight = quant_module.weight + is_four_over_six = bool(quant_module.weight_quantizer.block_sizes.get("four_over_six", False)) + weight_fp8_max = 256.0 if is_four_over_six else 448.0 cached_weight_global_scale = hasattr(quant_module, "_weight_global_scale") if isinstance(weight, QTensorWrapper): # weight is already compressed. @@ -102,7 +104,7 @@ def _fp4_linear( if not cached_weight_global_scale: weight_amax = quant_module.weight_quantizer.amax or reduce_amax(weight) assert weight_amax != 0 - quant_module._weight_global_scale = 448.0 * 6.0 / weight_amax.float() + quant_module._weight_global_scale = weight_fp8_max * 6.0 / weight_amax.float() quant_module._weight_amax = weight_amax weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize( @@ -211,6 +213,15 @@ def _nvfp4_availability_check(module, input, args, kwargs): if not hasattr(module, "input_quantizer") or not hasattr(module, "weight_quantizer"): return False + # 4/6 relies on adaptive per-block M=4 vs M=6 selection. TRT on-the-fly quantization + # from FP16/BF16 weights only has a global scale and cannot reproduce that selection. 
+ # Keep backend available for pre-quantized weights (QTensorWrapper) where per-block + # scales are already materialized. + if module.weight_quantizer.block_sizes.get("four_over_six", False) and not isinstance( + module.weight, QTensorWrapper + ): + return False + quant_cfg_list: list = mtq.NVFP4_DEFAULT_CFG["quant_cfg"] # Quantizer configs input_cfg = mtq.config.find_quant_cfg_entry_by_path(quant_cfg_list, "*input_quantizer").get( diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 09d3437e595..ae842d77850 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -523,7 +523,7 @@ def validate_block_sizes(cls, v, info: ValidationInfo): ) for _k, _v in v.items(): if isinstance(_k, str): - assert _k in ["type", "scale_bits", "scale_block_sizes"] + assert _k in ["type", "scale_bits", "scale_block_sizes", "four_over_six"] else: assert isinstance(_k, int) and (_v is None or isinstance(_v, int)) return v @@ -1328,6 +1328,9 @@ def _load_quantizer_cfg_dict_list(config_path: str) -> list[dict[str, Any]]: FP8_AFFINE_KV_CFG: dict[str, Any] = _load_quantize_config_dict("configs/ptq/presets/kv/fp8_affine") NVFP4_DEFAULT_CFG: dict[str, Any] = _load_quantize_config_dict("configs/ptq/presets/model/nvfp4") +NVFP4_FOUR_OVER_SIX_CFG: dict[str, Any] = _load_quantize_config_dict( + "configs/ptq/presets/model/nvfp4_four_over_six" +) NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG: dict[str, Any] = _load_quantize_config_dict( "configs/ptq/presets/model/nvfp4_w4a4_weight_mse_fp8_sweep" ) @@ -1406,6 +1409,7 @@ def _load_quantizer_cfg_dict_list(config_path: str) -> list[dict[str, Any]]: "NVFP4_AWQ_FULL_CFG", "NVFP4_AWQ_LITE_CFG", "NVFP4_DEFAULT_CFG", + "NVFP4_FOUR_OVER_SIX_CFG", "NVFP4_FP8_MHA_CONFIG", "NVFP4_KV_CFG", "NVFP4_KV_ROTATE_CFG", diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index b96f146b237..7fef3c09221 100644 --- 
a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -787,10 +787,13 @@ def _real_quantize(self, inputs): outputs, _weights_scaling_factor, _weights_scaling_factor_2 = NVFP4QTensor.quantize( inputs, self._block_sizes[-1], - weights_scaling_factor_2=self.amax.float() / (448.0 * 6.0) - if self.amax is not None - else None, + weights_scaling_factor_2=( + NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(self) + if self.amax is not None + else None + ), try_tensorrt=True, + four_over_six=bool(self._block_sizes.get("four_over_six", False)), ) buffer_to_register["_scale"] = _weights_scaling_factor buffer_to_register["_double_scale"] = _weights_scaling_factor_2 @@ -1449,11 +1452,15 @@ def _apply(self, fn, recurse=True): def _fake_quantize(self, inputs): """Fake quantization using two-level scaling with _amax and _global_amax.""" if self.amax is not None: + fp8_max_for_normalization = ( + 256.0 if self.block_sizes.get("four_over_six", False) else 448.0 + ) return static_blockwise_fp4_fake_quant( inputs, self.amax, self.global_amax, # Can be None, will be computed internally True, # quantize_block_scales + fp8_max_for_normalization, inputs.dtype, self._pass_through_bwd, ) diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py index 0f84cdea3c1..648b946b6e9 100644 --- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py @@ -25,22 +25,30 @@ e2m1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5]) e2m1_values = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6]) +# Four Over Six (4/6) adaptive block scaling — paper arXiv:2512.02010v5. +# Each block targets an FP4 max of 4 or 6; 4/6 mode normalizes FP8 block scales by 256, not 448.
+FP4_E2M1_MAX = 6.0 +FP4_E2M1_MAX_M4 = 4.0 +F8_E4M3_MAX = 448.0 +F8_E4M3_MAX_46 = 256.0 + __all__ = ["NVFP4QTensor"] def _cast_per_block_scale_to_fp8( per_block_scale: torch.Tensor, per_block_scale_max: torch.Tensor | None = None, + fp8_max_for_normalization: float = F8_E4M3_MAX, ) -> torch.Tensor: """Clamp to FP8 E4M3FN range [2**-9, 448] and cast — avoids underflow→0 / overflow→NaN. When ``per_block_scale_max`` is provided, first rescales as - ``per_block_scale.float() * 448 / per_block_scale_max`` — the static-export + ``per_block_scale.float() * fp8_max_for_normalization / per_block_scale_max`` — the static-export path needs this because the ``[==0]=1.0`` safety net combined with a small ``global_amax`` can drive the rescaled value above 448 (see PR #1397). """ if per_block_scale_max is not None: - per_block_scale = per_block_scale.float() * 448.0 / per_block_scale_max + per_block_scale = per_block_scale.float() * fp8_max_for_normalization / per_block_scale_max return per_block_scale.clamp(min=2**-9, max=448.0).to(torch.float8_e4m3fn) @@ -82,6 +90,12 @@ def _get_static_global_amax(cls, weight_quantizer): global_amax = getattr(weight_quantizer, "_global_amax", None) return global_amax + @classmethod + def _is_four_over_six(cls, weight_quantizer) -> bool: + """Return True if 4/6 adaptive block scaling is enabled on this quantizer.""" + bs = getattr(weight_quantizer, "block_sizes", None) or {} + return bool(bs.get("four_over_six", False)) + @classmethod def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer): """Returns per tensor weight scaling factor from the weight_quantizer. @@ -95,14 +109,15 @@ def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer): Returns: The global scaling factor as a float tensor. 
""" + m_fp8 = F8_E4M3_MAX_46 if cls._is_four_over_six(weight_quantizer) else F8_E4M3_MAX global_amax = cls._get_static_global_amax(weight_quantizer) if global_amax is not None: - return global_amax.float() / (6.0 * 448.0) + return global_amax.float() / (FP4_E2M1_MAX * m_fp8) else: assert hasattr(weight_quantizer, "_amax"), ( "Weight quantizer does not have attribute amax" ) - return weight_quantizer._amax.float() / (6.0 * 448.0) + return weight_quantizer._amax.float() / (FP4_E2M1_MAX * m_fp8) @classmethod def get_weights_scaling_factor_from_quantizer( @@ -133,14 +148,17 @@ def get_weights_scaling_factor_from_quantizer( weight_quantizer ) + is_four_over_six = cls._is_four_over_six(weight_quantizer) + fp8_max_for_normalization = F8_E4M3_MAX_46 if is_four_over_six else F8_E4M3_MAX + if cls._is_static_quantizer(weight_quantizer): # Static path: use pre-computed per-block amax values from quantizer global_amax = cls._get_static_global_amax(weight_quantizer).float() per_block_amax = weight_quantizer._amax.float() # Compute scales in float - per_block_scale_max = global_amax / 6.0 - per_block_scale = per_block_amax / 6.0 + per_block_scale_max = global_amax / FP4_E2M1_MAX + per_block_scale = per_block_amax / FP4_E2M1_MAX per_block_scale[per_block_scale == 0] = 1.0 # Reshape per_block_scale to match weight's block structure @@ -148,13 +166,31 @@ def get_weights_scaling_factor_from_quantizer( expected_shape = (*weight.shape[:-1], num_blocks_per_row) per_block_scale = per_block_scale.view(expected_shape) + if is_four_over_six: + per_block_scale = cls._select_four_over_six_scale( + weight, + per_block_scale, + weights_scaling_factor_2, + block_size, + per_block_scale_max, + fp8_max_for_normalization=fp8_max_for_normalization, + ) + if not keep_high_precision: - per_block_scale = _cast_per_block_scale_to_fp8(per_block_scale, per_block_scale_max) + per_block_scale = _cast_per_block_scale_to_fp8( + per_block_scale, + per_block_scale_max, + 
fp8_max_for_normalization=fp8_max_for_normalization, + ) return per_block_scale, weights_scaling_factor_2 else: # Dynamic path: compute from weight tensor return cls.get_weights_scaling_factor( - weight, block_size, weights_scaling_factor_2, keep_high_precision + weight, + block_size, + weights_scaling_factor_2, + keep_high_precision, + four_over_six=is_four_over_six, ) @classmethod @@ -164,6 +200,7 @@ def get_weights_scaling_factor( block_size: int, weights_scaling_factor_2: torch.Tensor | None = None, keep_high_precision: bool = False, + four_over_six: bool = False, ): """Returns quantized per block weight scaling factor from weight tensor. @@ -171,7 +208,9 @@ def get_weights_scaling_factor( For quantizers with pre-computed amax, use get_weights_scaling_factor_from_quantizer. """ if weights_scaling_factor_2 is None: - weights_scaling_factor_2 = cls.get_weights_scaling_factor_2(input) + weights_scaling_factor_2 = cls.get_weights_scaling_factor_2( + input, four_over_six=four_over_six + ) # Get per_block amax assert block_size != 0, "Block size is zero. Cannot return per_block amax for given input." 
@@ -182,20 +221,125 @@ def get_weights_scaling_factor( # Get per block amax per_block_amax = reduce_block_amax(input, block_sizes={-1: block_size}).float() - # Get per-block-scale + # Get per-block-scale (default M=6) per_block_scale = per_block_amax / ( - 6.0 * weights_scaling_factor_2.to(per_block_amax.device) + FP4_E2M1_MAX * weights_scaling_factor_2.to(per_block_amax.device) ) # Set all zero values in scale to 1.0 per_block_scale[per_block_scale == 0] = 1.0 + + if four_over_six: + per_block_scale = cls._select_four_over_six_scale( + input, per_block_scale, weights_scaling_factor_2, block_size + ) + if not keep_high_precision: per_block_scale = _cast_per_block_scale_to_fp8(per_block_scale) return per_block_scale, weights_scaling_factor_2 @classmethod - def get_weights_scaling_factor_2(cls, input: torch.Tensor): + def get_weights_scaling_factor_2(cls, input: torch.Tensor, four_over_six: bool = False): """Returns per tensor weight scaling factor.""" - return reduce_amax(input).float() / (6.0 * 448.0) + m_fp8 = F8_E4M3_MAX_46 if four_over_six else F8_E4M3_MAX + return reduce_amax(input).float() / (FP4_E2M1_MAX * m_fp8) + + @classmethod + def _select_four_over_six_scale( + cls, + weight: torch.Tensor, + per_block_scale_m6: torch.Tensor, + weights_scaling_factor_2: torch.Tensor, + block_size: int, + per_block_scale_max: torch.Tensor | None = None, + fp8_max_for_normalization: float = F8_E4M3_MAX, + ) -> torch.Tensor: + """Pick M=4 or M=6 per block by per-block MSE (paper §3.1, arXiv:2512.02010v5). + + Both candidates share the per-block amax: the M=4 scale equals the M=6 scale times 6/4. + We round both candidates onto the E2M1 grid (after F8 quantization of the block scale) + and pick whichever yields lower per-block MSE against the BF16/F32 weight values. + + Inputs: + weight: original weight tensor [..., features], features divisible by block_size. + per_block_scale_m6: F32 per-block scale under the default M=6 rule. + Shape [..., num_blocks]. 
+ weights_scaling_factor_2: per-tensor F32 alpha. Must already use the 4/6-adjusted + denominator (FP4_E2M1_MAX * F8_E4M3_MAX_46), set by get_weights_scaling_factor_2*. + block_size: block length (16 for NVFP4). + per_block_scale_max: optional max scale value for the static-export F8 rescale + (see _cast_per_block_scale_to_fp8). Pass-through only. + + Returns the per-block scale in F32, with M=4 blocks scaled by 6/4 vs M=6 blocks. + Same shape as per_block_scale_m6. The caller is responsible for the subsequent + F8_E4M3 cast. + """ + ratio = FP4_E2M1_MAX / FP4_E2M1_MAX_M4 # 1.5 + per_block_scale_m4 = per_block_scale_m6 * ratio + + # Round candidate per-block scales to F8_E4M3 precision before scoring — the saved scales + # are F8 quantized, so MSE under F8-rounded scales is what eventually gets deployed. + scale_m6_f8 = _cast_per_block_scale_to_fp8( + per_block_scale_m6, + per_block_scale_max, + fp8_max_for_normalization=fp8_max_for_normalization, + ).to(torch.float32) + scale_m4_f8 = _cast_per_block_scale_to_fp8( + per_block_scale_m4, + per_block_scale_max, + fp8_max_for_normalization=fp8_max_for_normalization, + ).to(torch.float32) + + # Quantize-then-dequantize both candidates on the actual weight, compare per-block MSE. 
+ alpha = weights_scaling_factor_2.to(weight.device).to(torch.float32) + deq_m6 = cls._fake_quant_to_e2m1(weight, scale_m6_f8, alpha, block_size) + deq_m4 = cls._fake_quant_to_e2m1(weight, scale_m4_f8, alpha, block_size) + + w_blocks = weight.to(torch.float32).view(*weight.shape[:-1], -1, block_size) + mse_m6 = ((w_blocks - deq_m6) ** 2).mean(dim=-1) + mse_m4 = ((w_blocks - deq_m4) ** 2).mean(dim=-1) + chose_m4 = mse_m4 < mse_m6 + + return torch.where(chose_m4, per_block_scale_m4, per_block_scale_m6) + + @classmethod + def _fake_quant_to_e2m1( + cls, + weight: torch.Tensor, + per_block_scale_f32: torch.Tensor, + alpha: torch.Tensor, + block_size: int, + ) -> torch.Tensor: + """Round-trip quantize one candidate (scale_block ⊗ alpha) and return dequantized blocks. + + Returns shape [..., num_blocks, block_size] in float32. + """ + device = weight.device + w_blocks = weight.to(torch.float32).view(*weight.shape[:-1], -1, block_size) + scale = per_block_scale_f32.view(*per_block_scale_f32.shape, 1).to(device) + alpha_v = alpha.to(torch.float32) + if alpha_v.dim() == 0: + divisor = scale * alpha_v + else: + divisor = scale * alpha_v.view(*alpha_v.shape, *([1] * (scale.dim() - alpha_v.dim()))) + scaled = w_blocks / divisor + + # Sign + abs, then round abs to E2M1 grid using the same bounds as _cast_fp4. Values + # whose magnitude exceeds the implicit grid max (6.0) are clamped before rounding. 
+ sign = torch.sign(scaled) + abs_v = scaled.abs().clamp_(max=FP4_E2M1_MAX) + bounds = cls.get_e2m1_bounds(device) + ord_ = torch.searchsorted(bounds, abs_v, out_int32=True) + # Mirror the equals-bound nudge in _cast_fp4 (round-half-up at odd-indexed bounds) + odd_bounds = bounds[[1, 3, 5]] + nudge = torch.any(abs_v.unsqueeze(-1) == odd_bounds, dim=-1).to(ord_.dtype) + ord_ = ord_ + nudge + # Map ordinal → magnitude + e2m1_pos = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=device, dtype=torch.float32 + ) + ord_ = ord_.clamp_(0, 7) + mag = e2m1_pos[ord_.long()] + return sign * mag * divisor @classmethod def get_activation_scaling_factor(cls, quantizer): @@ -249,6 +393,7 @@ def quantize( weights_scaling_factor_2: torch.Tensor | None = None, keep_high_precision: bool = False, try_tensorrt: bool = False, + four_over_six: bool = False, ): """Converting a tensor to a quantized format based on NVFP4 quantization. @@ -258,6 +403,10 @@ def quantize( weights_scaling_factor (torch.Tensor): The scaling factor for the weights. weights_scaling_factor_2 (torch.Tensor): The scaling factor for the weights. keep_high_precision (bool): Whether to keep output scales at high precision. + four_over_six (bool): Enable per-block M=4 vs M=6 adaptive selection + (paper arXiv:2512.02010v5 §3.1). Only consulted when + ``weights_scaling_factor`` is None — otherwise the caller-provided scale + already encodes the M choice. Returns: tuple: Contains quantized data, quantized per block scaling factor, and per tensor scaling factor. 
@@ -270,13 +419,16 @@ def quantize( input = reduce_block_padding(input, block_sizes={-1: block_size}) if weights_scaling_factor_2 is None: - weights_scaling_factor_2 = cls.get_weights_scaling_factor_2(input) + weights_scaling_factor_2 = cls.get_weights_scaling_factor_2( + input, four_over_six=four_over_six + ) # try call trtllm fp4 quantization if possible if ( fp4_compatible() and weights_scaling_factor is None and try_tensorrt + and not four_over_six and block_size == 16 and input.is_cuda and input.dtype in [torch.half, torch.bfloat16] @@ -305,7 +457,7 @@ def quantize( if weights_scaling_factor is None: weights_scaling_factor, _ = cls.get_weights_scaling_factor( - input, block_size, weights_scaling_factor_2 + input, block_size, weights_scaling_factor_2, four_over_six=four_over_six ) # Reshape the weight and scale factors diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 15d782c4a79..45a106070ab 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -577,6 +577,7 @@ def forward( amax, global_amax=None, quantize_block_scales=True, + fp8_max_for_normalization=448.0, out_dtype=None, pass_through_bwd=False, ): @@ -592,13 +593,14 @@ def forward( amax, global_amax, quantize_block_scales, + fp8_max_for_normalization, out_dtype, ) @staticmethod def backward(ctx, grad_outputs): """Implements straight through estimation with clipping.""" - return _fake_quant_backward_function(ctx, grad_outputs, num_args=6) + return _fake_quant_backward_function(ctx, grad_outputs, num_args=len(ctx.needs_input_grad)) def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True): diff --git a/modelopt_recipes/configs/numerics/nvfp4_four_over_six.yaml b/modelopt_recipes/configs/numerics/nvfp4_four_over_six.yaml new file mode 100644 index 00000000000..a078328dcb4 --- /dev/null +++ b/modelopt_recipes/configs/numerics/nvfp4_four_over_six.yaml @@ -0,0 +1,24 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Static (max-calibrated) NVFP4 E2M1 quantization with Four-Over-Six (4/6) adaptive per-block scale selection enabled. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig +num_bits: e2m1 +block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + four_over_six: true diff --git a/modelopt_recipes/configs/ptq/presets/model/nvfp4_four_over_six.yaml b/modelopt_recipes/configs/ptq/presets/model/nvfp4_four_over_six.yaml new file mode 100644 index 00000000000..5a46a58ea4c --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/nvfp4_four_over_six.yaml @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# QuantizeConfig preset for NVFP4 W4A4 quantization with static (max-calibrated) Four-Over-Six (4/6) adaptive scaling on weights. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + w4a4_nvfp4_nvfp4_four_over_six: configs/ptq/units/w4a4_nvfp4_nvfp4_four_over_six + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + +algorithm: max +quant_cfg: + - $import: base_disable_all + - $import: w4a4_nvfp4_nvfp4_four_over_six + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/units/w4a4_nvfp4_nvfp4_four_over_six.yaml b/modelopt_recipes/configs/ptq/units/w4a4_nvfp4_nvfp4_four_over_six.yaml new file mode 100644 index 00000000000..768c09a6184 --- /dev/null +++ b/modelopt_recipes/configs/ptq/units/w4a4_nvfp4_nvfp4_four_over_six.yaml @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizerCfgList snippet that enables NVFP4 on weight and input quantizers, +# with static Four-Over-Six (4/6) adaptive scaling on the weight quantizers only.
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig +imports: + nvfp4: configs/numerics/nvfp4 + nvfp4_four_over_six: configs/numerics/nvfp4_four_over_six +--- + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4_four_over_six + - quantizer_name: '*input_quantizer' + cfg: + $import: nvfp4 diff --git a/modelopt_recipes/huggingface/models/nvidia/Nemotron-3-Ultra-550B-A55B/ptq/nvfp4-46-max.yaml b/modelopt_recipes/huggingface/models/nvidia/Nemotron-3-Ultra-550B-A55B/ptq/nvfp4-46-max.yaml new file mode 100644 index 00000000000..efee9f5349b --- /dev/null +++ b/modelopt_recipes/huggingface/models/nvidia/Nemotron-3-Ultra-550B-A55B/ptq/nvfp4-46-max.yaml @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +# Ultra NVFP4 mixed-precision recipe with Four-Over-Six (4/6) adaptive block scaling +# on routed-expert weights. Paper: "Four Over Six: More Accurate NVFP4 Quantization +# with Adaptive Block Scaling" (arXiv:2512.02010v5). +# +# Layout (unchanged from super-nvfp4): +# - MoE routed experts: NVFP4 W4A4 weight, block 16 +# + 4/6 adaptive block scaling on weights only +# HF names: mixer.experts.*.{up,down}_proj +# Megatron-Core names: mlp.experts.local_experts.*.linear_fc{1,2} +# - MoE shared experts: FP8 per-tensor +# - Mamba mixer linears (mixer.{in,out}_proj):FP8 per-tensor +# - KV cache: FP8 +# - Attention linears, MTP, lm_head, latent MOE, mamba conv1d: BF16 +# Calibration: weight amax (max).
+# +# What 4/6 changes at export time: +# - Per-tensor F32 scale (weight_scale_2) uses denominator 6*256 instead of 6*448 +# for routed-expert weights, so blocks selecting M=4 don't overflow the F8_E4M3 +# block-scale ceiling. +# - Per F8_E4M3 block scale (weight_scale_1) is selected per block: M=4 or M=6 +# based on lower per-block quantization MSE against the BF16 weight. +# - Storage layout (U8-packed FP4 + F8_E4M3 block scales + F32 per-tensor scale) +# is byte-identical to standard NVFP4. Native FP4 tensorcore path intact. +# - Activation quantizer cfg is UNCHANGED (standard W4A4 dynamic, no 4/6). +metadata: + recipe_type: ptq + description: Ultra NVFP4 mixed precision with Four-Over-Six adaptive block scaling on routed-expert weights (W4A4, block 16); shared experts + mamba in/out_proj + FP8 per-tensor; FP8 KV cache; everything else BF16. Amax calibration variant. +quantize: + algorithm: + method: max + quant_cfg: + # Disable all layers by default so that these layers stay in original BF16 precision: + # lm_head, output projection, MoE routers/gates, Latent MOE, MTP head, mamba conv1d. + - quantizer_name: '*' + enable: false + + # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale, 4/6 adaptive block scaling. + # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj. + - quantizer_name: '*mixer.experts.*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + four_over_six: true + num_bits: e2m1 + - quantizer_name: '*mixer.experts.*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + # Megatron-Core/PTQ names: decoder.layers.*.mlp.experts.local_experts.*.linear_fc{1,2}. 
+ - quantizer_name: '*mlp.experts*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + four_over_six: true + num_bits: e2m1 + - quantizer_name: '*mlp.experts*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + + # MoE shared experts -> FP8 per-tensor. + - quantizer_name: '*mixer.shared_experts.*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.shared_experts.*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mlp.shared_experts*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mlp.shared_experts*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # Mamba mixer linears -> FP8 per-tensor. + - quantizer_name: '*mixer.in_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.in_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # KV cache -> FP8. + - quantizer_name: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 diff --git a/tests/unit/torch/quantization/test_nvfp4_four_over_six.py b/tests/unit/torch/quantization/test_nvfp4_four_over_six.py new file mode 100644 index 00000000000..3708f50fd13 --- /dev/null +++ b/tests/unit/torch/quantization/test_nvfp4_four_over_six.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU tests for NVFP4 Four-Over-Six (4/6) adaptive weight scaling. + +4/6 is a weight-only refinement applied under max calibration: it uses an FP8 +normalization max of 256 (instead of 448) and, per block, picks an M=4 scale +candidate (the M=6 scale times 6/4) when it lowers per-block reconstruction MSE +(arXiv:2512.02010). It is enabled via ``block_sizes={"four_over_six": True}`` on +(static, max-calibrated) weight quantizers. +""" + +from types import SimpleNamespace + +import torch + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.config import choices +from modelopt.torch.quantization.qtensor.nvfp4_tensor import ( + F8_E4M3_MAX, + F8_E4M3_MAX_46, + FP4_E2M1_MAX, + FP4_E2M1_MAX_M4, + NVFP4QTensor, + _cast_per_block_scale_to_fp8, +) + +BLOCK_SIZE = 16 + + +def _per_block_amax(weight: torch.Tensor, block_size: int) -> torch.Tensor: + """Per-block amax via plain reshape, matching reduce_block_amax on the last axis.""" + blocks = weight.abs().view(*weight.shape[:-1], -1, block_size) + return blocks.amax(dim=-1).float() + + +class TestConstants: + def test_fp8_and_e2m1_constants(self): + assert F8_E4M3_MAX == 448.0 + assert F8_E4M3_MAX_46 == 256.0 + assert FP4_E2M1_MAX == 6.0 + assert FP4_E2M1_MAX_M4 == 4.0 + + def test_m4_over_m6_ratio_is_1_5(self): + assert FP4_E2M1_MAX / FP4_E2M1_MAX_M4 == 1.5 + + +class TestIsFourOverSix: + def test_flag_true(self): + q = SimpleNamespace(block_sizes={-1: BLOCK_SIZE, "four_over_six": True}) + assert NVFP4QTensor._is_four_over_six(q) is True + + def test_flag_false(self): + q = 
SimpleNamespace(block_sizes={-1: BLOCK_SIZE, "four_over_six": False}) + assert NVFP4QTensor._is_four_over_six(q) is False + + def test_flag_absent_defaults_false(self): + q = SimpleNamespace(block_sizes={-1: BLOCK_SIZE}) + assert NVFP4QTensor._is_four_over_six(q) is False + + def test_missing_block_sizes_defaults_false(self): + assert NVFP4QTensor._is_four_over_six(SimpleNamespace()) is False + assert NVFP4QTensor._is_four_over_six(SimpleNamespace(block_sizes=None)) is False + + +class TestScalingFactor2: + def test_256_vs_448_denominator(self): + torch.manual_seed(0) + w = torch.randn(8, 4 * BLOCK_SIZE) + wsf2_default = NVFP4QTensor.get_weights_scaling_factor_2(w, four_over_six=False) + wsf2_46 = NVFP4QTensor.get_weights_scaling_factor_2(w, four_over_six=True) + # wsf2 = amax / (6 * m_fp8); only m_fp8 differs (448 vs 256). + assert torch.allclose(wsf2_46 / wsf2_default, torch.tensor(448.0 / 256.0), rtol=1e-6) + + +class TestSelectFourOverSixScale: + def _setup(self, seed=0, rows=8, n_blocks=4): + torch.manual_seed(seed) + weight = torch.randn(rows, n_blocks * BLOCK_SIZE) + wsf2 = NVFP4QTensor.get_weights_scaling_factor_2(weight, four_over_six=True) + per_block_amax = _per_block_amax(weight, BLOCK_SIZE) + per_block_scale_m6 = per_block_amax / (FP4_E2M1_MAX * wsf2) + per_block_scale_m6[per_block_scale_m6 == 0] = 1.0 + return weight, wsf2, per_block_scale_m6 + + def test_returns_m6_or_m4_candidate(self): + weight, wsf2, m6 = self._setup() + selected = NVFP4QTensor._select_four_over_six_scale(weight, m6, wsf2, BLOCK_SIZE) + m4 = m6 * (FP4_E2M1_MAX / FP4_E2M1_MAX_M4) + is_m6 = torch.isclose(selected, m6) + is_m4 = torch.isclose(selected, m4) + assert (is_m6 | is_m4).all(), "Selected scale is neither the M=6 nor the M=4 candidate." 
+ + def test_selection_never_increases_block_mse(self): + """Adaptive M=4/M=6 selection must not raise per-block MSE vs M=6 only (same alpha).""" + weight, wsf2, _ = self._setup(seed=3, rows=16, n_blocks=8) + # Both candidates share the same per-tensor alpha (wsf2); only per-block scale differs. + m6_scale, _ = NVFP4QTensor.get_weights_scaling_factor( + weight, BLOCK_SIZE, wsf2, keep_high_precision=True, four_over_six=False + ) + sel_scale, _ = NVFP4QTensor.get_weights_scaling_factor( + weight, BLOCK_SIZE, wsf2, keep_high_precision=True, four_over_six=True + ) + alpha = wsf2.float() + deq_m6 = NVFP4QTensor._fake_quant_to_e2m1( + weight, _cast_per_block_scale_to_fp8(m6_scale).float(), alpha, BLOCK_SIZE + ) + deq_sel = NVFP4QTensor._fake_quant_to_e2m1( + weight, + _cast_per_block_scale_to_fp8( + sel_scale, fp8_max_for_normalization=F8_E4M3_MAX_46 + ).float(), + alpha, + BLOCK_SIZE, + ) + w_blocks = weight.float().view(*weight.shape[:-1], -1, BLOCK_SIZE) + mse_m6 = ((w_blocks - deq_m6) ** 2).mean(dim=-1) + mse_sel = ((w_blocks - deq_sel) ** 2).mean(dim=-1) + assert (mse_sel <= mse_m6 + 1e-12).all(), "4/6 selection increased per-block MSE." + + def test_chooses_m4_when_strictly_better(self): + """At least one block should pick M=4 on random data (else selection is a no-op).""" + weight, wsf2, _ = self._setup(seed=7, rows=32, n_blocks=8) + m6_scale, _ = NVFP4QTensor.get_weights_scaling_factor( + weight, BLOCK_SIZE, wsf2, keep_high_precision=True, four_over_six=False + ) + sel_scale, _ = NVFP4QTensor.get_weights_scaling_factor( + weight, BLOCK_SIZE, wsf2, keep_high_precision=True, four_over_six=True + ) + assert not torch.allclose(sel_scale, m6_scale), "Expected some blocks to switch to M=4." 
+ + +class TestRoundTripScales: + def test_no_zero_or_nan_scales(self): + torch.manual_seed(1) + weight = torch.cat([torch.randn(4, BLOCK_SIZE), torch.full((4, BLOCK_SIZE), 1e-12)], dim=0) + per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor( + weight, BLOCK_SIZE, four_over_six=True + ) + s = per_block_scale.float() + assert torch.isfinite(s).all(), f"Non-finite 4/6 scales: {s.tolist()}" + assert (s > 0).all(), f"Zero 4/6 scales: {s.tolist()}" + + +class TestNVFP4FourOverSixConfig: + @staticmethod + def _block_sizes(cfg, name): + entry = next(e for e in cfg["quant_cfg"] if e["quantizer_name"] == name) + return entry["cfg"]["block_sizes"] + + def test_weight_quantizer_is_static_with_four_over_six(self): + bs = self._block_sizes(mtq.NVFP4_FOUR_OVER_SIX_CFG, "*weight_quantizer") + assert bs.get("type") == "static" + # Schema coerces the bool to int 1; the feature reads it truthily. + assert bs.get("four_over_six") + + def test_input_quantizer_unchanged(self): + bs = self._block_sizes(mtq.NVFP4_FOUR_OVER_SIX_CFG, "*input_quantizer") + assert not bs.get("four_over_six", False) + + def test_registered_in_choices(self): + assert "NVFP4_FOUR_OVER_SIX_CFG" in choices diff --git a/tools/launcher/examples/nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16/megatron_lm_ptq.yaml b/tools/launcher/examples/nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16/megatron_lm_ptq.yaml index b39697ed7e0..cfcbee327ae 100644 --- a/tools/launcher/examples/nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16/megatron_lm_ptq.yaml +++ b/tools/launcher/examples/nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16/megatron_lm_ptq.yaml @@ -1,6 +1,7 @@ # Nemotron-3-Ultra-550B-A55B-BF16 PTQ quantization + export + vLLM generation test. -# Tested on B200 Blackwell GPUs. Uses Super NVFP4 mixed-FP8 max calibration recipe, similar to published NVFP4 checkpoint (which is Four Over Six scales). 
-# +# NVFP4 mixed-FP8 Four-Over-Six max calibration recipe used to create the NVFP4 checkpoint +# Tested on B200 Blackwell GPUs. + # Pipeline: # task_0 (quantize): 4 nodes x 4 GPUs = 16 ranks, TP=1 PP=1 EP=16 ETP=1. # Loads HF weights from /hf-local, saves PTQ ckpt to /cicd. @@ -18,7 +19,7 @@ job_name: Nemotron-3-Ultra_PTQ pipeline: skip: false allow_to_fail: false - note: "PTQ on Nemotron-3-Ultra-550B-A55B-BF16 (super-nvfp4-max-calib): quantize @ 4 nodes, export @ 3 nodes, vLLM generation test@ 1 node" + note: "PTQ on Nemotron-3-Ultra-550B-A55B-BF16 (nvfp4-46-max): quantize @ 4 nodes, export @ 3 nodes, vLLM generation test@ 1 node" task_0: script: common/megatron_lm/quantize/quantize.sh @@ -29,7 +30,7 @@ pipeline: - --calib-size 32 environment: - MLM_MODEL_CFG: nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16 - - QUANT_CFG: models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib + - QUANT_CFG: huggingface/models/nvidia/Nemotron-3-Ultra-550B-A55B/ptq/nvfp4-46-max - HF_MODEL_CKPT: /hf-local/nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16 # MMLU + Export run as separate tasks; quantize.sh does quantize only. 
- RUN_MMLU: "false" @@ -52,7 +53,7 @@ pipeline: script: common/megatron_lm/export/export.sh environment: - MLM_MODEL_CFG: nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16 - - QUANT_CFG: models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib + - QUANT_CFG: huggingface/models/nvidia/Nemotron-3-Ultra-550B-A55B/ptq/nvfp4-46-max - HF_MODEL_CKPT: /hf-local/nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16 - TP: "1" - PP: "12" @@ -73,17 +74,17 @@ pipeline: task_2: script: common/vllm/query.sh args: - - --model /cicd/export/nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16_super-nvfp4-max-calib + - --model /cicd/export/nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16_nvfp4-46-max - --tensor-parallel-size 4 - --trust-remote-code - -- - --data common/vllm/gpqa_sample.jsonl - --max-tokens 256 - --num-shards 1 - - --save /cicd/vllm/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16_super-nvfp4-max-calib + - --save /cicd/vllm/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16_nvfp4-46-max slurm_config: _factory_: "slurm_factory" - container: vllm/vllm-openai:v0.21.0 + container: vllm/vllm-openai:v0.22.0 partition: batch nodes: 1 ntasks_per_node: 1 From db5497e2b1a5ab15a65aca9c2f157a56d5d6a276 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 12 Jun 2026 07:59:12 -0700 Subject: [PATCH 2/4] Set mypy python_version to 3.10 mypy defaulted to the running interpreter's version and failed to parse 3.10 match/case syntax (e.g. precisionconverter.py). Pin python_version to 3.10 so mypy parses modern syntax. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Jennifer Chen --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 28ca051d9e7..01bde2da4ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -253,6 +253,7 @@ split-on-trailing-comma = false [tool.mypy] files = "." 
+python_version = "3.10" install_types = true non_interactive = true show_error_codes = true From 90b1c76c954b58eced55a9a78a227b8b50f4decc Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 12 Jun 2026 12:18:59 -0700 Subject: [PATCH 3/4] update unit tests Signed-off-by: Jennifer Chen --- .../test_nvfp4_fp8_sweep_kernel.py | 3 +- .../quantization/test_config_validation.py | 45 +++++++++++ .../torch/quantization/test_mse_calibrator.py | 45 +++++++++++ .../quantization/test_nvfp4_four_over_six.py | 10 +-- .../test_nvfp4_static_export_cpu.py | 79 +++++++++++++++++++ 5 files changed, 176 insertions(+), 6 deletions(-) diff --git a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py index d1eba4987d3..ac2beb44853 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py +++ b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py @@ -37,6 +37,7 @@ from modelopt.torch.quantization.calib import NVFP4MSECalibrator from modelopt.torch.quantization.extensions import get_cuda_ext_mx from modelopt.torch.quantization.nn import TensorQuantizer +from modelopt.torch.quantization.qtensor.nvfp4_tensor import FP8_E4M3_MAX from modelopt.torch.quantization.tensor_quant import static_blockwise_fp4_fake_quant BLOCK_SIZE = 16 @@ -173,7 +174,7 @@ def test_sweep_stores_fp32_amax_and_preserves_output_dtype(dtype, triton_enabled amax = cal.compute_amax() assert amax.dtype == torch.float32 - xq = static_blockwise_fp4_fake_quant(x, amax, global_amax, True, x.dtype) + xq = static_blockwise_fp4_fake_quant(x, amax, global_amax, True, FP8_E4M3_MAX, x.dtype) assert xq.dtype == x.dtype diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 93a60924792..c7ad571dce9 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -18,6 +18,7 @@ import pytest from 
pydantic import ValidationError +import modelopt.torch.quantization as mtq from modelopt.torch.quantization.algorithms import _match_quantizer_cfg from modelopt.torch.quantization.config import ( FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, @@ -28,6 +29,7 @@ W4A8_AWQ_BETA_CFG, MaxCalibConfig, QuantizeConfig, + QuantizerAttributeConfig, find_quant_cfg_entry_by_path, need_calibration, normalize_quant_cfg_list, @@ -603,3 +605,46 @@ def test_unknown_field_still_rejected(self): """extra='forbid' must still reject unrelated unknown fields.""" with pytest.raises(ValidationError): MaxCalibConfig(not_a_real_field=True) + + +class TestFourOverSixBlockSizes: + """`four_over_six` is an accepted block_sizes key for NVFP4 4/6 adaptive weight scaling. + + The block_sizes validator only permits a fixed set of string keys + (``type``, ``scale_bits``, ``scale_block_sizes``, ``four_over_six``); any other + string key is rejected. See QuantizerAttributeConfig.validate_block_sizes. + """ + + def test_four_over_six_true_accepted(self): + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: 16, "type": "static", "scale_bits": (4, 3), "four_over_six": True}, + ) + # The schema coerces the bool to int 1; the feature reads it truthily. + assert cfg.block_sizes["four_over_six"] + + def test_four_over_six_false_accepted(self): + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: 16, "type": "static", "four_over_six": False}, + ) + # Coerced to int 0; must read falsy. 
+ assert not cfg.block_sizes["four_over_six"] + + def test_unknown_block_sizes_string_key_rejected(self): + """A string key outside the allow-list is rejected by the validator.""" + with pytest.raises(ValidationError): + QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: 16, "not_a_real_key": True}, + ) + + def test_nvfp4_four_over_six_cfg_validates(self): + """The shipped NVFP4_FOUR_OVER_SIX_CFG preset validates as a QuantizeConfig.""" + cfg = QuantizeConfig(**mtq.NVFP4_FOUR_OVER_SIX_CFG) + assert isinstance(cfg.quant_cfg, list) + assert len(cfg.quant_cfg) > 0 + + def test_nvfp4_four_over_six_cfg_needs_calibration(self): + """The 4/6 preset is statically calibrated, so it requires calibration.""" + assert need_calibration(mtq.NVFP4_FOUR_OVER_SIX_CFG) diff --git a/tests/unit/torch/quantization/test_mse_calibrator.py b/tests/unit/torch/quantization/test_mse_calibrator.py index df63e51de20..7613f87711e 100644 --- a/tests/unit/torch/quantization/test_mse_calibrator.py +++ b/tests/unit/torch/quantization/test_mse_calibrator.py @@ -780,3 +780,48 @@ def test_grouped_static_nvfp4_quantizers_share_global_amax(self): for child in (model.q_proj, model.k_proj, model.v_proj): assert isinstance(child.weight_quantizer, NVFP4StaticQuantizer) assert torch.equal(child.weight_quantizer.global_amax, torch.tensor(6.0)) + + +class TestFourOverSixNormalizationThreading: + """4/6 is calibrator-agnostic: the MSE calibrator finds per-block amaxes the usual way, + and the 256-vs-448 FP8 normalization is applied later, at static fake-quant time. + + ``NVFP4StaticQuantizer._fake_quantize`` selects ``fp8_max_for_normalization`` from the + ``four_over_six`` block_sizes flag and threads it into ``static_blockwise_fp4_fake_quant``. 
+ """ + + @staticmethod + def _make_static_quantizer(four_over_six: bool) -> NVFP4StaticQuantizer: + block_sizes = {-1: 16, "type": "static", "scale_bits": (4, 3)} + if four_over_six: + block_sizes["four_over_six"] = True + cfg = QuantizerAttributeConfig(num_bits=(2, 1), block_sizes=block_sizes) + q = NVFP4StaticQuantizer(quant_attribute_cfg=cfg) + q.amax = torch.full((1, 4), 0.5) + q.global_amax = torch.tensor(2.0) + return q + + def _captured_fp8_max(self, monkeypatch, four_over_six: bool) -> float: + import modelopt.torch.quantization.nn.modules.tensor_quantizer as tqm + + captured = {} + + def spy(*args, **kwargs): + # Call site: (inputs, amax, global_amax, quantize_block_scales, + # fp8_max_for_normalization, dtype, pass_through_bwd). + # The 4/6 → 256 vs 448 selection happens before this call, so capturing the + # threaded value is enough; we return a passthrough to avoid the triton kernel + # (unavailable on CPU) — this is a unit test of the threading, not the kernel. + captured["fp8_max"] = args[4] + return args[0] + + monkeypatch.setattr(tqm, "static_blockwise_fp4_fake_quant", spy) + q = self._make_static_quantizer(four_over_six) + q._fake_quantize(torch.randn(1, 64)) + return captured["fp8_max"] + + def test_four_over_six_threads_256(self, monkeypatch): + assert self._captured_fp8_max(monkeypatch, four_over_six=True) == 256.0 + + def test_default_threads_448(self, monkeypatch): + assert self._captured_fp8_max(monkeypatch, four_over_six=False) == 448.0 diff --git a/tests/unit/torch/quantization/test_nvfp4_four_over_six.py b/tests/unit/torch/quantization/test_nvfp4_four_over_six.py index 3708f50fd13..53255afea59 100644 --- a/tests/unit/torch/quantization/test_nvfp4_four_over_six.py +++ b/tests/unit/torch/quantization/test_nvfp4_four_over_six.py @@ -29,10 +29,10 @@ import modelopt.torch.quantization as mtq from modelopt.torch.quantization.config import choices from modelopt.torch.quantization.qtensor.nvfp4_tensor import ( - F8_E4M3_MAX, - F8_E4M3_MAX_46, 
FP4_E2M1_MAX, FP4_E2M1_MAX_M4, + FP8_E4M3_MAX, + FP8_E4M3_MAX_46, NVFP4QTensor, _cast_per_block_scale_to_fp8, ) @@ -48,8 +48,8 @@ def _per_block_amax(weight: torch.Tensor, block_size: int) -> torch.Tensor: class TestConstants: def test_fp8_and_e2m1_constants(self): - assert F8_E4M3_MAX == 448.0 - assert F8_E4M3_MAX_46 == 256.0 + assert FP8_E4M3_MAX == 448.0 + assert FP8_E4M3_MAX_46 == 256.0 assert FP4_E2M1_MAX == 6.0 assert FP4_E2M1_MAX_M4 == 4.0 @@ -120,7 +120,7 @@ def test_selection_never_increases_block_mse(self): deq_sel = NVFP4QTensor._fake_quant_to_e2m1( weight, _cast_per_block_scale_to_fp8( - sel_scale, fp8_max_for_normalization=F8_E4M3_MAX_46 + sel_scale, fp8_max_for_normalization=FP8_E4M3_MAX_46 ).float(), alpha, BLOCK_SIZE, diff --git a/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py b/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py index dfb776a0484..1c3d86f3fdf 100644 --- a/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py +++ b/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py @@ -24,6 +24,7 @@ from modelopt.torch.quantization.config import QuantizerAttributeConfig from modelopt.torch.quantization.nn import NVFP4StaticQuantizer from modelopt.torch.quantization.qtensor import NVFP4QTensor +from modelopt.torch.quantization.qtensor.nvfp4_tensor import FP4_E2M1_MAX, FP4_E2M1_MAX_M4 BLOCK_SIZE = 16 FP4_VALUES = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6]) @@ -318,3 +319,81 @@ def test_ultra_v3_layer1_distribution_byte_distribution_sane(self): # FP8 e4m3fn NaN bytes are 0x7F (127) and 0xFF (255). 
nan_count = int(((ws_bytes == 127) | (ws_bytes == 255)).sum().item()) assert nan_count == 0, f"static export emitted {nan_count} NaN FP8 weight_scale bytes" + + +def _make_static_quantizer_46( + per_block_amax: torch.Tensor, global_amax: torch.Tensor +) -> NVFP4StaticQuantizer: + """Static NVFP4 quantizer with 4/6 adaptive block scaling enabled.""" + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={ + -1: BLOCK_SIZE, + "type": "static", + "scale_bits": (4, 3), + "four_over_six": True, + }, + ) + q = NVFP4StaticQuantizer(quant_attribute_cfg=cfg) + q.amax = per_block_amax.clone() + q.global_amax = global_amax.clone() + return q + + +class TestNVFP4StaticFourOverSixExport: + """4/6 static export normalizes block scales by 256 (not 448) and stays finite. + + The only difference from the default NVFP4 export is the FP8 normalization max: + weight_scale_2 = global_amax / (6 * 256) instead of / (6 * 448), and per block the + export picks the M=4 scale (M=6 scale * 6/4) when it lowers reconstruction MSE. + """ + + def test_scale_2_normalizes_by_256(self): + weight = _layer1_routed_expert_like(32, 128, n_outliers=4, seed=5) + block_max = _per_block_max(weight) + global_amax = block_max.max() + amax = block_max.clamp(min=1e-30) + + ws2_46 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer( + _make_static_quantizer_46(amax, global_amax) + ) + ws2_default = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer( + _make_static_quantizer(amax, global_amax) + ) + + # weight_scale_2 = global_amax / (6 * m_fp8); only m_fp8 differs (256 vs 448). 
+ assert torch.allclose(ws2_46, global_amax.float() / (6.0 * 256.0), rtol=1e-6) + assert torch.allclose(ws2_46 / ws2_default, torch.tensor(448.0 / 256.0), rtol=1e-5) + + def test_export_round_trip_finite(self): + weight = _layer1_routed_expert_like(64, 256, n_outliers=4, seed=6) + block_max = _per_block_max(weight) + global_amax = block_max.max() + amax = block_max.clamp(min=1e-30) + q = _make_static_quantizer_46(amax, global_amax) + + ws, ws2, deq = _export_round_trip(weight, q) + + assert torch.isfinite(ws.float()).all(), "weight_scale (FP8) must be finite" + assert torch.isfinite(ws2).all(), "weight_scale_2 (FP32) must be finite" + assert torch.isfinite(deq.float()).all(), "dequantized weight must be finite" + + def test_selected_high_precision_scale_is_m6_or_m4(self): + """Each high-precision per-block scale is either the M=6 candidate or its 6/4 multiple.""" + weight = _layer1_routed_expert_like(32, 128, n_outliers=6, seed=7) + block_max = _per_block_max(weight) + global_amax = block_max.max() + amax = block_max.clamp(min=1e-30) + q = _make_static_quantizer_46(amax, global_amax) + + ws2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(q) + selected, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer( + q, weight, ws2, keep_high_precision=True + ) + + # Static-path M=6 baseline per-block scale = per_block_amax / 6 (no zero blocks here). + m6 = (amax.float() / FP4_E2M1_MAX).view_as(selected) + m4 = m6 * (FP4_E2M1_MAX / FP4_E2M1_MAX_M4) + is_m6 = torch.isclose(selected, m6, rtol=1e-5) + is_m4 = torch.isclose(selected, m4, rtol=1e-5) + assert (is_m6 | is_m4).all(), "Selected scale is neither the M=6 nor the M=4 candidate." 
From 86ac97ba73e4315ced5d0be74db8b36930827557 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 12 Jun 2026 12:41:56 -0700 Subject: [PATCH 4/4] use cast_fp4 in e2m1 fake quant Signed-off-by: Jennifer Chen --- .../torch/quantization/backends/nvfp4_gemm.py | 5 +- .../quantization/qtensor/nvfp4_tensor.py | 48 +++++++------------ 2 files changed, 21 insertions(+), 32 deletions(-) diff --git a/modelopt/torch/quantization/backends/nvfp4_gemm.py b/modelopt/torch/quantization/backends/nvfp4_gemm.py index 3ace55b1e75..b1e6cbe05a8 100644 --- a/modelopt/torch/quantization/backends/nvfp4_gemm.py +++ b/modelopt/torch/quantization/backends/nvfp4_gemm.py @@ -23,6 +23,7 @@ from modelopt.torch.quantization.backends.utils import fp4_compatible from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear from modelopt.torch.quantization.qtensor import NVFP4QTensor, QTensorWrapper +from modelopt.torch.quantization.qtensor.nvfp4_tensor import FP8_E4M3_MAX, FP8_E4M3_MAX_46 from modelopt.torch.quantization.utils import reduce_amax @@ -76,11 +77,11 @@ def _fp4_linear( if not cached_input_global_scale: input_amax = quant_module.input_quantizer.amax or reduce_amax(input_tensor) assert input_amax != 0 - quant_module._input_global_scale = 448.0 * 6.0 / input_amax.float() + quant_module._input_global_scale = FP8_E4M3_MAX * 6.0 / input_amax.float() weight = quant_module.weight is_four_over_six = bool(quant_module.weight_quantizer.block_sizes.get("four_over_six", False)) - weight_fp8_max = 256.0 if is_four_over_six else 448.0 + weight_fp8_max = FP8_E4M3_MAX_46 if is_four_over_six else FP8_E4M3_MAX cached_weight_global_scale = hasattr(quant_module, "_weight_global_scale") if isinstance(weight, QTensorWrapper): # weight is already compressed. 
diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py index 648b946b6e9..e6432be231d 100644 --- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py @@ -29,8 +29,8 @@ # scales to either 4 or 6 per block, therefore the FP8 block scales are either 448 or 256. FP4_E2M1_MAX = 6.0 FP4_E2M1_MAX_M4 = 4.0 -F8_E4M3_MAX = 448.0 -F8_E4M3_MAX_46 = 256.0 +FP8_E4M3_MAX = 448.0 +FP8_E4M3_MAX_46 = 256.0 __all__ = ["NVFP4QTensor"] @@ -38,7 +38,7 @@ def _cast_per_block_scale_to_fp8( per_block_scale: torch.Tensor, per_block_scale_max: torch.Tensor | None = None, - fp8_max_for_normalization: float = F8_E4M3_MAX, + fp8_max_for_normalization: float = FP8_E4M3_MAX, ) -> torch.Tensor: """Clamp to FP8 E4M3FN range [2**-9, 448] and cast — avoids underflow→0 / overflow→NaN. @@ -49,7 +49,7 @@ def _cast_per_block_scale_to_fp8( """ if per_block_scale_max is not None: per_block_scale = per_block_scale.float() * fp8_max_for_normalization / per_block_scale_max - return per_block_scale.clamp(min=2**-9, max=448.0).to(torch.float8_e4m3fn) + return per_block_scale.clamp(min=2**-9, max=FP8_E4M3_MAX).to(torch.float8_e4m3fn) class NVFP4QTensor(BaseQuantizedTensor): @@ -109,7 +109,7 @@ def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer): Returns: The global scaling factor as a float tensor. 
""" - m_fp8 = F8_E4M3_MAX_46 if cls._is_four_over_six(weight_quantizer) else F8_E4M3_MAX + m_fp8 = FP8_E4M3_MAX_46 if cls._is_four_over_six(weight_quantizer) else FP8_E4M3_MAX global_amax = cls._get_static_global_amax(weight_quantizer) if global_amax is not None: return global_amax.float() / (FP4_E2M1_MAX * m_fp8) @@ -149,7 +149,7 @@ def get_weights_scaling_factor_from_quantizer( ) is_four_over_six = cls._is_four_over_six(weight_quantizer) - fp8_max_for_normalization = F8_E4M3_MAX_46 if is_four_over_six else F8_E4M3_MAX + fp8_max_for_normalization = FP8_E4M3_MAX_46 if is_four_over_six else FP8_E4M3_MAX if cls._is_static_quantizer(weight_quantizer): # Static path: use pre-computed per-block amax values from quantizer @@ -240,7 +240,7 @@ def get_weights_scaling_factor( @classmethod def get_weights_scaling_factor_2(cls, input: torch.Tensor, four_over_six: bool = False): """Returns per tensor weight scaling factor.""" - m_fp8 = F8_E4M3_MAX_46 if four_over_six else F8_E4M3_MAX + m_fp8 = FP8_E4M3_MAX_46 if four_over_six else FP8_E4M3_MAX return reduce_amax(input).float() / (FP4_E2M1_MAX * m_fp8) @classmethod @@ -251,7 +251,7 @@ def _select_four_over_six_scale( weights_scaling_factor_2: torch.Tensor, block_size: int, per_block_scale_max: torch.Tensor | None = None, - fp8_max_for_normalization: float = F8_E4M3_MAX, + fp8_max_for_normalization: float = FP8_E4M3_MAX, ) -> torch.Tensor: """Pick M=4 or M=6 per block by per-block MSE (paper §3.1, arXiv:2512.02010v5). @@ -264,7 +264,7 @@ def _select_four_over_six_scale( per_block_scale_m6: F32 per-block scale under the default M=6 rule. Shape [..., num_blocks]. weights_scaling_factor_2: per-tensor F32 alpha. Must already use the 4/6-adjusted - denominator (FP4_E2M1_MAX * F8_E4M3_MAX_46), set by get_weights_scaling_factor_2*. + denominator (FP4_E2M1_MAX * FP8_E4M3_MAX_46), set by get_weights_scaling_factor_2*. block_size: block length (16 for NVFP4). 
per_block_scale_max: optional max scale value for the static-export F8 rescale (see _cast_per_block_scale_to_fp8). Pass-through only. @@ -311,7 +311,8 @@ def _fake_quant_to_e2m1( ) -> torch.Tensor: """Round-trip quantize one candidate (scale_block ⊗ alpha) and return dequantized blocks. - Returns shape [..., num_blocks, block_size] in float32. + Reuses ``_cast_fp4`` for the E2M1 rounding so the per-block MSE scoring matches the + deployed NVFP4 quantization exactly. Returns shape [..., num_blocks, block_size] in float32. """ device = weight.device w_blocks = weight.to(torch.float32).view(*weight.shape[:-1], -1, block_size) @@ -321,25 +322,12 @@ def _fake_quant_to_e2m1( divisor = scale * alpha_v else: divisor = scale * alpha_v.view(*alpha_v.shape, *([1] * (scale.dim() - alpha_v.dim()))) - scaled = w_blocks / divisor - - # Sign + abs, then round abs to E2M1 grid using the same bounds as _cast_fp4. Values - # whose magnitude exceeds the implicit grid max (6.0) are clamped before rounding. - sign = torch.sign(scaled) - abs_v = scaled.abs().clamp_(max=FP4_E2M1_MAX) - bounds = cls.get_e2m1_bounds(device) - ord_ = torch.searchsorted(bounds, abs_v, out_int32=True) - # Mirror the equals-bound nudge in _cast_fp4 (round-half-up at odd-indexed bounds) - odd_bounds = bounds[[1, 3, 5]] - nudge = torch.any(abs_v.unsqueeze(-1) == odd_bounds, dim=-1).to(ord_.dtype) - ord_ = ord_ + nudge - # Map ordinal → magnitude - e2m1_pos = torch.tensor( - [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=device, dtype=torch.float32 - ) - ord_ = ord_.clamp_(0, 7) - mag = e2m1_pos[ord_.long()] - return sign * mag * divisor + # Quantize to signed E2M1 codes, then map each code back to its FP4 value via + # e2m1_values (code = (sign << 3) + magnitude ordinal). searchsorted in _cast_fp4 + # saturates at ordinal 7 (= 6.0), so out-of-range magnitudes clamp naturally. 
+ codes = cls._cast_fp4(w_blocks / divisor) + fp4_vals = cls.get_e2m1_values(device)[codes.long()] + return fp4_vals * divisor @classmethod def get_activation_scaling_factor(cls, quantizer): @@ -353,7 +341,7 @@ def get_activation_scaling_factor(cls, quantizer): if amax is None: return None - activation_scaling_factor = amax.float() / (quantizer.maxbound * 448.0) + activation_scaling_factor = amax.float() / (quantizer.maxbound * FP8_E4M3_MAX) assert torch.all(activation_scaling_factor > 0), ( f" activation scaling factor {activation_scaling_factor} not positive."