Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion modelopt/torch/export/layer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,7 +1198,8 @@ def sync_moe_gate_up_amax(model: nn.Module) -> int:
"""Take element-wise max of gate and up weight quantizer amaxes per expert.

Serving engines fuse gate_proj and up_proj into a single gate_up_proj and
require a single weight_scale_2. Since weight_scale_2 = amax / (6 * 448),
require a single weight_scale_2. Since weight_scale_2 = amax / (6 * m_fp8)
(m_fp8=448 normally, 256 for NVFP4 4/6 mode),
syncing amaxes before quantization ensures the per-block weight_scale values
are computed against a consistent global scale.

Expand Down
17 changes: 15 additions & 2 deletions modelopt/torch/kernels/quantization/gemm/fp4_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
torch.bfloat16: tl.bfloat16,
}

FP8_E4M3_MAX = 448.0


def _torch_dtype_to_tl(dtype: torch.dtype):
if dtype not in _TORCH_TO_TL_DTYPE:
Expand Down Expand Up @@ -211,6 +213,7 @@ def compute_fp4_scales(
amax: torch.Tensor,
global_amax: torch.Tensor | None = None,
quantize_block_scales: bool = True,
fp8_max_for_normalization: float = FP8_E4M3_MAX,
) -> torch.Tensor:
"""Compute per-block FP4 scales from amax values.

Expand All @@ -220,6 +223,8 @@ def compute_fp4_scales(
amax: Per-block amax values (any shape).
global_amax: Global amax for FP8 two-level scaling. Computed from *amax* if None.
quantize_block_scales: If True, quantize scales to FP8 E4M3.
fp8_max_for_normalization: FP8 max value used to normalize per-block scales
before FP8 quantization (default 448.0; use 256.0 for NVFP4 4/6 mode).

Returns:
Per-block scales (same shape as *amax*), float32.
Expand All @@ -235,7 +240,7 @@ def compute_fp4_scales(
global_amax = reduce_amax(amax, axis=None, keepdims=False, squeeze_scalar=True)

global_amax = global_amax.float()
scale_fp8_quant_amax = global_amax / 6.0
scale_fp8_quant_amax = global_amax * (FP8_E4M3_MAX / fp8_max_for_normalization) / 6.0
scale = scaled_e4m3_impl(scale, scale_fp8_quant_amax)

return scale
Expand All @@ -246,6 +251,7 @@ def static_blockwise_fp4_fake_quant(
amax: torch.Tensor,
global_amax: torch.Tensor | None = None,
quantize_block_scales: bool = True,
fp8_max_for_normalization: float = FP8_E4M3_MAX,
out_dtype: torch.dtype | None = None,
):
"""Static blockwise FP4 fake quantization using Triton kernel.
Expand All @@ -261,6 +267,8 @@ def static_blockwise_fp4_fake_quant(
consumes it as a flat 1-D buffer of length ``NUM_FP4_BLOCKS``.
global_amax: FP32 scalar global amax. If provided, used to compute scale_fp8_quant_amax.
quantize_block_scales: If True, quantize block scales to FP8.
fp8_max_for_normalization: FP8 max value used to normalize per-block scales
before FP8 quantization (default 448.0; use 256.0 for NVFP4 4/6 mode).
out_dtype: Output dtype. Defaults to x.dtype if None.
"""
original_shape = x.shape
Expand All @@ -275,7 +283,12 @@ def static_blockwise_fp4_fake_quant(
if out_dtype is None:
out_dtype = x.dtype

scale = compute_fp4_scales(amax, global_amax, quantize_block_scales)
scale = compute_fp4_scales(
amax,
global_amax,
quantize_block_scales,
fp8_max_for_normalization=fp8_max_for_normalization,
)

x_flat = x.contiguous().view(-1)
y_flat = torch.empty_like(x_flat, dtype=out_dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def fp4_fake_quant_block(
x: torch.Tensor,
global_amax: torch.Tensor,
block_size: int = 16,
fp8_max_for_normalization: float = 448.0,
tile_rows: int = 16,
tile_cols: int = 64,
num_warps: int | None = None,
Expand All @@ -114,6 +115,8 @@ def fp4_fake_quant_block(
x (torch.Tensor): Input tensor of shape ``(M, N)`` or higher.
global_amax (torch.Tensor): Global maximum value tensor for scaling.
block_size (int): Number of elements per FP4 block.
fp8_max_for_normalization (float): FP8 max value used to normalize per-block
scales before FP8 quantization (default 448.0; use 256.0 for NVFP4 4/6 mode).
tile_rows (int, optional): Row tile size. Defaults to 16.
tile_cols (int, optional): Column tile size. Defaults to 64. Rounded up to
the nearest multiple of ``block_size`` internally.
Expand All @@ -137,7 +140,7 @@ def fp4_fake_quant_block(
tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size
num_fp4_blocks = tile_cols_aligned // block_size

global_scale = (global_amax.float() / (6.0 * 448.0)).to(x.device)
global_scale = (global_amax.float() / (6.0 * fp8_max_for_normalization)).to(x.device)

grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned))

Expand Down
16 changes: 14 additions & 2 deletions modelopt/torch/quantization/backends/nvfp4_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from modelopt.torch.quantization.backends.utils import fp4_compatible
from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear
from modelopt.torch.quantization.qtensor import NVFP4QTensor, QTensorWrapper
from modelopt.torch.quantization.qtensor.nvfp4_tensor import FP8_E4M3_MAX, FP8_E4M3_MAX_46
from modelopt.torch.quantization.utils import reduce_amax


Expand Down Expand Up @@ -76,9 +77,11 @@ def _fp4_linear(
if not cached_input_global_scale:
input_amax = quant_module.input_quantizer.amax or reduce_amax(input_tensor)
assert input_amax != 0
quant_module._input_global_scale = 448.0 * 6.0 / input_amax.float()
quant_module._input_global_scale = FP8_E4M3_MAX * 6.0 / input_amax.float()

weight = quant_module.weight
is_four_over_six = bool(quant_module.weight_quantizer.block_sizes.get("four_over_six", False))
weight_fp8_max = FP8_E4M3_MAX_46 if is_four_over_six else FP8_E4M3_MAX

cached_weight_global_scale = hasattr(quant_module, "_weight_global_scale")
if isinstance(weight, QTensorWrapper): # weight is already compressed.
Expand All @@ -102,7 +105,7 @@ def _fp4_linear(
if not cached_weight_global_scale:
weight_amax = quant_module.weight_quantizer.amax or reduce_amax(weight)
assert weight_amax != 0
quant_module._weight_global_scale = 448.0 * 6.0 / weight_amax.float()
quant_module._weight_global_scale = weight_fp8_max * 6.0 / weight_amax.float()
quant_module._weight_amax = weight_amax

weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize(
Expand Down Expand Up @@ -211,6 +214,15 @@ def _nvfp4_availability_check(module, input, args, kwargs):
if not hasattr(module, "input_quantizer") or not hasattr(module, "weight_quantizer"):
return False

# 4/6 relies on adaptive per-block M=4 vs M=6 selection. TRT on-the-fly quantization
# from FP16/BF16 weights only has a global scale and cannot reproduce that selection.
# Keep backend available for pre-quantized weights (QTensorWrapper) where per-block
# scales are already materialized.
if module.weight_quantizer.block_sizes.get("four_over_six", False) and not isinstance(
module.weight, QTensorWrapper
):
return False

quant_cfg_list: list = mtq.NVFP4_DEFAULT_CFG["quant_cfg"]
# Quantizer configs
input_cfg = mtq.config.find_quant_cfg_entry_by_path(quant_cfg_list, "*input_quantizer").get(
Expand Down
6 changes: 5 additions & 1 deletion modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def validate_block_sizes(cls, v, info: ValidationInfo):
)
for _k, _v in v.items():
if isinstance(_k, str):
assert _k in ["type", "scale_bits", "scale_block_sizes"]
assert _k in ["type", "scale_bits", "scale_block_sizes", "four_over_six"]
else:
assert isinstance(_k, int) and (_v is None or isinstance(_v, int))
return v
Expand Down Expand Up @@ -1328,6 +1328,9 @@ def _load_quantizer_cfg_dict_list(config_path: str) -> list[dict[str, Any]]:
FP8_AFFINE_KV_CFG: dict[str, Any] = _load_quantize_config_dict("configs/ptq/presets/kv/fp8_affine")

NVFP4_DEFAULT_CFG: dict[str, Any] = _load_quantize_config_dict("configs/ptq/presets/model/nvfp4")
NVFP4_FOUR_OVER_SIX_CFG: dict[str, Any] = _load_quantize_config_dict(
"configs/ptq/presets/model/nvfp4_four_over_six"
)
NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG: dict[str, Any] = _load_quantize_config_dict(
"configs/ptq/presets/model/nvfp4_w4a4_weight_mse_fp8_sweep"
)
Expand Down Expand Up @@ -1406,6 +1409,7 @@ def _load_quantizer_cfg_dict_list(config_path: str) -> list[dict[str, Any]]:
"NVFP4_AWQ_FULL_CFG",
"NVFP4_AWQ_LITE_CFG",
"NVFP4_DEFAULT_CFG",
"NVFP4_FOUR_OVER_SIX_CFG",
"NVFP4_FP8_MHA_CONFIG",
"NVFP4_KV_CFG",
"NVFP4_KV_ROTATE_CFG",
Expand Down
13 changes: 10 additions & 3 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,10 +787,13 @@ def _real_quantize(self, inputs):
outputs, _weights_scaling_factor, _weights_scaling_factor_2 = NVFP4QTensor.quantize(
inputs,
self._block_sizes[-1],
weights_scaling_factor_2=self.amax.float() / (448.0 * 6.0)
if self.amax is not None
else None,
weights_scaling_factor_2=(
NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(self)
if self.amax is not None
else None
),
try_tensorrt=True,
four_over_six=bool(self._block_sizes.get("four_over_six", False)),
)
buffer_to_register["_scale"] = _weights_scaling_factor
buffer_to_register["_double_scale"] = _weights_scaling_factor_2
Expand Down Expand Up @@ -1449,11 +1452,15 @@ def _apply(self, fn, recurse=True):
def _fake_quantize(self, inputs):
"""Fake quantization using two-level scaling with _amax and _global_amax."""
if self.amax is not None:
fp8_max_for_normalization = (
256.0 if self.block_sizes.get("four_over_six", False) else 448.0
)
return static_blockwise_fp4_fake_quant(
inputs,
self.amax,
self.global_amax, # Can be None, will be computed internally
True, # quantize_block_scales
fp8_max_for_normalization,
inputs.dtype,
self._pass_through_bwd,
)
Expand Down
Loading
Loading