Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
99a0444
Add Qwen-Image registration to diffusers quantization example
jingyu-ml Jun 11, 2026
e0c7910
Qwen-Image SVDQuant: block-range recipe, AWQ-style diffusers export, …
jingyu-ml Jun 12, 2026
6718e72
Qwen-Image SVDQuant: fix export/recipe defects, add Qwen QKV fusion +…
jingyu-ml Jun 12, 2026
8e3b3ed
Qwen-Image SVDQuant: add export/fusion/promotion tests; drop plan ter…
jingyu-ml Jun 12, 2026
027a5e2
Qwen-Image SVDQuant: offline tiny Qwen fixture + e2e export test
jingyu-ml Jun 12, 2026
a511058
Qwen-Image SVDQuant: offline tokenizer, stronger export test, drop St…
jingyu-ml Jun 12, 2026
a6a3d59
Qwen-Image SVDQuant: fix misplaced parametrize decorator + tighten ex…
jingyu-ml Jun 12, 2026
1cfe0b3
Qwen-Image SVDQuant: fix stale tiny-fixture tokenizer docstring
jingyu-ml Jun 12, 2026
789f4ef
Qwen-Image SVDQuant: add immutability + negative-loading tests; fix s…
jingyu-ml Jun 12, 2026
521cda0
Qwen-Image SVDQuant: drop invalid txt_seq_lens dummy kwarg + signatur…
jingyu-ml Jun 12, 2026
d776ffb
Qwen-Image SVDQuant: remove promoted export buffers from live module …
jingyu-ml Jun 12, 2026
c2250cb
Qwen-Image diffusers PTQ: fix 5 execution-surfaced bugs (fp8/nvfp4/sv…
jingyu-ml Jun 12, 2026
fb23155
Qwen-Image: drop operator-specific quantization harness from the example
jingyu-ml Jun 12, 2026
80e2244
Qwen-Image: consolidate tests into the shared diffusers export suite
jingyu-ml Jun 12, 2026
9b472b2
Qwen-Image: add a fast CPU test for the diffusers SVDQuant export pro…
jingyu-ml Jun 12, 2026
4dff6d5
Merge branch 'main' into feature/qwen-image-svdquant-nvfp4
jingyu-ml Jun 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions examples/diffusers/quantization/models_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from enum import Enum
from typing import Any

import torch
from diffusers import (
DiffusionPipeline,
FluxPipeline,
Expand All @@ -30,11 +31,19 @@
from diffusers import Flux2Pipeline
except ImportError:
Flux2Pipeline = None

# Qwen-Image classes were added in a recent diffusers release; import lazily so
# this example still imports on older diffusers versions.
try:
from diffusers import QwenImagePipeline
except ImportError:
QwenImagePipeline = None
from utils import (
filter_func_default,
filter_func_flux_dev,
filter_func_ltx2_vae,
filter_func_ltx_video,
filter_func_qwen_image,
filter_func_wan_vae,
filter_func_wan_video,
)
Expand All @@ -54,6 +63,7 @@ class ModelType(str, Enum):
LTX2 = "ltx-2"
WAN22_T2V_14b = "wan2.2-t2v-14b"
WAN22_T2V_5b = "wan2.2-t2v-5b"
QWEN_IMAGE = "qwen-image"


_FILTER_FUNC_MAP: dict[ModelType, Callable[[str], bool]] = {
Expand All @@ -63,6 +73,7 @@ class ModelType(str, Enum):
ModelType.LTX2: filter_func_ltx_video,
ModelType.WAN22_T2V_14b: filter_func_wan_video,
ModelType.WAN22_T2V_5b: filter_func_wan_video,
ModelType.QWEN_IMAGE: filter_func_qwen_image,
}

_VAE_FILTER_FUNC_MAP: dict[tuple[ModelType, str], Callable[[str], bool]] = {
Expand Down Expand Up @@ -95,6 +106,7 @@ def get_model_filter_func(
ModelType.LTX2: "Lightricks/LTX-2",
ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
ModelType.QWEN_IMAGE: "Qwen/Qwen-Image",
}

MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = {
Expand All @@ -109,6 +121,7 @@ def get_model_filter_func(
ModelType.LTX2: None,
ModelType.WAN22_T2V_14b: WanPipeline,
ModelType.WAN22_T2V_5b: WanPipeline,
ModelType.QWEN_IMAGE: QwenImagePipeline,
}

# Shared dataset configurations
Expand Down Expand Up @@ -226,6 +239,23 @@ def get_model_filter_func(
),
},
},
ModelType.QWEN_IMAGE: {
"backbone": "transformer",
"dataset": _SD_PROMPTS_DATASET,
"inference_extra_args": {
"height": 1024,
"width": 1024,
},
# Quantize only ``transformer_blocks``; keep the first 2 and last 2 blocks
# (and everything outside ``transformer_blocks``) in original precision.
# Applied before calibration via ``build_block_range_quant_cfg`` so SVDQuant
# never mutates the excluded blocks' weights.
"block_range": {
"exclude_first_n": 2,
"exclude_last_n": 2,
"block_module": "transformer_blocks",
},
},
}


Expand Down Expand Up @@ -272,3 +302,70 @@ def parse_extra_params(
i += 1

return extra_params


def build_block_range_quant_cfg(
backbone: torch.nn.Module,
exclude_first_n: int,
exclude_last_n: int,
block_module: str = "transformer_blocks",
) -> list[dict[str, Any]]:
"""Build ordered ``quant_cfg`` rules for a transformer-block-only recipe.

The rules quantize only the linears under ``block_module`` while keeping the
first ``exclude_first_n`` and last ``exclude_last_n`` blocks -- and everything
outside ``block_module`` -- in original precision.

The rules are meant to be appended to the ``quant_cfg`` list consumed by
``mtq.quantize`` so the selection is applied BEFORE calibration. This is
required for SVDQuant, whose calibration subtracts a low-rank residual from
the weights of every *enabled* linear: disabling the excluded blocks only
after calibration would leave their weights mutated instead of bit-identical
to the original precision.

Rules are applied in order with later rules overriding earlier ones:
1. disable every linear weight/input quantizer,
2. re-enable only those under ``block_module`` (``enable`` is a top-level
QuantizerCfgEntry toggle; a ``None`` cfg keeps the base preset's quant params),
3. disable the first/last ``n`` blocks.

Raises:
ValueError: if the backbone has no ``block_module`` list, or it has fewer
than ``exclude_first_n + exclude_last_n + 2`` blocks (it requires at
least two quantized middle blocks).
"""
blocks = getattr(backbone, block_module, None)
if blocks is None or not hasattr(blocks, "__len__"):
raise ValueError(
f"Backbone {type(backbone).__name__} has no '{block_module}' module list; "
"cannot build the transformer-block-range recipe."
)
num_blocks = len(blocks)
# Require at least two quantized middle blocks so the recipe actually
# quantizes something (excluding first/last alone could otherwise leave 0-1
# quantized blocks). For the default 2+2 recipe this means n >= 6.
min_blocks = exclude_first_n + exclude_last_n + 2
if num_blocks < min_blocks:
raise ValueError(
f"'{block_module}' has only {num_blocks} block(s); excluding the first "
f"{exclude_first_n} and last {exclude_last_n} requires at least {min_blocks} blocks "
f"(at least 2 quantized middle blocks)."
)

excluded = sorted(
set(range(exclude_first_n)) | set(range(num_blocks - exclude_last_n, num_blocks))
)
# `enable` is a top-level QuantizerCfgEntry field (independent of `cfg`); a `None`
# cfg leaves the base preset's quant params untouched, so disabling then
# re-enabling restores the original (FP8/NVFP4/...) attributes. Putting `enable`
# under `cfg` is rejected by the QuantizerAttributeConfig validator.
rules: list[dict[str, Any]] = [
{"quantizer_name": "*weight_quantizer", "enable": False},
{"quantizer_name": "*input_quantizer", "enable": False},
{"quantizer_name": f"*{block_module}.*weight_quantizer", "enable": True},
{"quantizer_name": f"*{block_module}.*input_quantizer", "enable": True},
]
for idx in excluded:
rules.append({"quantizer_name": f"*{block_module}.{idx}.*weight_quantizer", "enable": False})
rules.append({"quantizer_name": f"*{block_module}.{idx}.*input_quantizer", "enable": False})
return rules
12 changes: 12 additions & 0 deletions examples/diffusers/quantization/pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def create_pipeline_from(
"""
pipeline_cls = MODEL_PIPELINE[model_type]
if pipeline_cls is None:
if model_type == ModelType.QWEN_IMAGE:
raise ImportError(
"Qwen-Image requires a diffusers version that provides "
"QwenImagePipeline. Please upgrade diffusers (e.g. "
"`pip install -U diffusers`) to a release that includes Qwen-Image."
)
raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.")
model_id = (
MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path
Expand Down Expand Up @@ -99,6 +105,12 @@ def create_pipeline(self) -> Any:

pipeline_cls = MODEL_PIPELINE[self.config.model_type]
if pipeline_cls is None:
if self.config.model_type == ModelType.QWEN_IMAGE:
raise ImportError(
"Qwen-Image requires a diffusers version that provides "
"QwenImagePipeline. Please upgrade diffusers (e.g. "
"`pip install -U diffusers`) to a release that includes Qwen-Image."
)
raise ValueError(
f"Model type {self.config.model_type.value} does not use diffusers pipelines."
)
Expand Down
72 changes: 70 additions & 2 deletions examples/diffusers/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@
set_quant_config_attr,
)
from diffusers import DiffusionPipeline
from models_utils import MODEL_DEFAULTS, ModelType, get_model_filter_func, parse_extra_params
from onnx_utils.export import generate_fp8_scales, modelopt_export_sd
from models_utils import (
MODEL_DEFAULTS,
ModelType,
build_block_range_quant_cfg,
get_model_filter_func,
parse_extra_params,
)
from pipeline_manager import PipelineManager
from quantize_config import (
CalibrationConfig,
Expand Down Expand Up @@ -163,6 +168,28 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any:
}
)

# Apply the transformer-block-range recipe (e.g. Qwen-Image) BEFORE
# calibration. This restricts quantization to `transformer_blocks` and
# excludes the first/last N blocks. It must run before calibration so that
# SVDQuant does not mutate the weights of the excluded blocks. The recipe
# is format-agnostic (applies to FP8/NVFP4/SVDQuant alike).
block_range = MODEL_DEFAULTS.get(self.model_config.model_type, {}).get("block_range")
if block_range is not None:
recipe_rules = build_block_range_quant_cfg(
backbone,
exclude_first_n=block_range.get("exclude_first_n", 2),
exclude_last_n=block_range.get("exclude_last_n", 2),
block_module=block_range.get("block_module", "transformer_blocks"),
)
self.logger.info(
f"Applying block-range recipe ({len(recipe_rules)} rules) for "
f"{self.model_config.model_type.value}: quantize only "
f"'{block_range.get('block_module', 'transformer_blocks')}' excluding "
f"first {block_range.get('exclude_first_n', 2)} / last "
f"{block_range.get('exclude_last_n', 2)} blocks."
)
quant_cfg_list.extend(recipe_rules)

quant_config = {**base_cfg, "quant_cfg": quant_cfg_list}
set_quant_config_attr(
quant_config,
Expand Down Expand Up @@ -291,6 +318,11 @@ def export_onnx(
if not self.config.onnx_dir:
return

# onnx_graphsurgeon (pulled in by onnx_utils.export) is an optional dependency
# only needed for the ONNX export path; import lazily so the HF-checkpoint
# export runs without it installed.
from onnx_utils.export import generate_fp8_scales, modelopt_export_sd

self.logger.info(f"Starting ONNX export to {self.config.onnx_dir}")

if quant_format == QuantFormat.FP8 and self._has_conv_layers(backbone):
Expand Down Expand Up @@ -379,6 +411,10 @@ def export_hf_ckpt(self, pipe: Any, model_config: ModelConfig | None = None) ->
f"Invalid padding_strategy: {padding!r}. Expected 'row' or 'row_col'."
)
kwargs["padding_strategy"] = padding
# The diffusion transformer is large (~20B params); the unified export's
# layerwise-metadata post-processing does not support sharded safetensors, so
# save each component as a single file (no *.safetensors.index.json).
kwargs.setdefault("max_shard_size", "200GB")
export_hf_checkpoint(pipe, export_dir=self.config.hf_ckpt_dir, **kwargs)
self.logger.info("HuggingFace checkpoint export completed successfully")

Expand Down Expand Up @@ -542,6 +578,15 @@ def create_argument_parser() -> argparse.ArgumentParser:
export_group.add_argument(
"--restore-from", type=str, help="Path to restore from previous checkpoint"
)
export_group.add_argument(
"--sanity-image-path",
type=str,
default=None,
help="If set, generate one image from the in-memory quantized pipeline (after "
"quantization, before the weights are packed for export) and save it here. This is "
"a quick functional sanity check of quantized inference; it does NOT reload the "
"exported checkpoint.",
)
Comment on lines +581 to +589

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

Reject --sanity-image-path for non-image pipelines at argument validation time.

This block assumes every supported model returns result.images[0], but the same CLI also supports video pipelines (LTX_*, WAN*). Today those runs will burn a full inference pass and then fail late on the save step instead of being rejected at the interface boundary.

Suggested guard
         pipeline_manager.print_quant_summary()

+        if args.sanity_image_path and model_type in {
+            ModelType.LTX_VIDEO_DEV,
+            ModelType.LTX2,
+            ModelType.WAN22_T2V_14b,
+            ModelType.WAN22_T2V_5b,
+        }:
+            parser.error("--sanity-image-path is only supported for image pipelines.")
+
         # Optional functional sanity check: generate one image from the in-memory
         # quantized pipeline. This runs BEFORE export (while weights are still
         # fake-quantized and runnable, not yet packed) and does not reload the

Also applies to: 729-750

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/diffusers/quantization/quantize.py` around lines 581 - 589, The CLI
currently accepts --sanity-image-path unconditionally and assumes generated
outputs have images, causing late failures for video/non-image pipelines; update
argument validation in quantize.py to reject --sanity-image-path early when the
selected pipeline type is not an image pipeline: after parsing args (or inside
the existing validation function / main pipeline selection flow), detect the
pipeline kind via the pipeline ID or class name used for inference (the same
symbol(s) that decide which pipeline to instantiate) and raise an error or exit
if --sanity-image-path is set but the pipeline is not one of the known image
pipelines (e.g., StableDiffusion/Any Image* pipelines); apply the same guard for
the second occurrence of this block noted around the other lines so non-image
pipelines fail at argument validation time rather than after a full run.

export_group.add_argument(
"--trt-high-precision-dtype",
type=str,
Expand Down Expand Up @@ -681,6 +726,29 @@ def forward_loop(mod):

pipeline_manager.print_quant_summary()

# Optional functional sanity check: generate one image from the in-memory
# quantized pipeline. This runs BEFORE export (while weights are still
# fake-quantized and runnable, not yet packed) and does not reload the
# exported checkpoint.
if args.sanity_image_path:
try:
logger.info(f"Generating sanity image to {args.sanity_image_path}")
inference_args = MODEL_DEFAULTS.get(model_type, {}).get("inference_extra_args", {})
result = pipe(
prompt="A high-quality photo of a cat wearing sunglasses",
num_inference_steps=calib_config.n_steps,
**inference_args,
)
sanity_path = Path(args.sanity_image_path)
sanity_path.parent.mkdir(parents=True, exist_ok=True)
result.images[0].save(str(sanity_path))
logger.info("Sanity image saved successfully")
except Exception as sanity_error:
# A requested sanity image is a positive success criterion: if it
# cannot be produced, fail loudly rather than reporting success.
logger.error(f"Sanity image generation failed: {sanity_error}", exc_info=True)
raise

for backbone_name, backbone in pipeline_manager.iter_backbones():
export_manager.export_onnx(
pipe,
Expand Down
24 changes: 24 additions & 0 deletions examples/diffusers/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,30 @@ def filter_func_wan_video(name: str) -> bool:
return pattern.match(name) is not None


# Qwen-Image's transformer has 60 ``transformer_blocks``. The recipe quantizes
# only those blocks while keeping the first two and last two -- and everything
# outside ``transformer_blocks`` -- in original precision. The model-agnostic,
# config-driven form of this recipe (deriving the block count from the model)
# lives in quantize.py; this name-only filter covers the plain FP8/NVFP4 path
# for the full 60-block Qwen-Image transformer.
QWEN_IMAGE_NUM_TRANSFORMER_BLOCKS = 60
_QWEN_IMAGE_BLOCK_RE = re.compile(r"(?:^|\.)transformer_blocks\.(\d+)(?:\.|$)")


def filter_func_qwen_image(name: str) -> bool:
"""Filter function specifically for Qwen-Image models.

Returns ``True`` for modules to keep in original precision (quantization
disabled): everything outside ``transformer_blocks``, plus the first two and
last two transformer blocks.
"""
match = _QWEN_IMAGE_BLOCK_RE.search(name)
if match is None:
return True
block_idx = int(match.group(1))
return block_idx < 2 or block_idx >= QWEN_IMAGE_NUM_TRANSFORMER_BLOCKS - 2


def load_calib_prompts(
batch_size,
calib_data_path: str | Path = "Gustavosta/Stable-Diffusion-Prompts",
Expand Down
37 changes: 37 additions & 0 deletions modelopt/torch/export/convert_hf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ def _quant_algo_to_group_config(quant_algo: str, group_size: int | None = None)
return {
"weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": gs},
}
elif quant_algo == "NVFP4_SVD":
gs = group_size or 16
return {
"input_activations": {
"dynamic": False,
"num_bits": 4,
"type": "float",
"group_size": gs,
},
"weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": gs},
"has_zero_point": False,
"pre_quant_scale": True,
}
elif quant_algo in ("NVFP4_AWQ", "W4A8_AWQ"):
gs = group_size or 128
return {
Expand Down Expand Up @@ -196,6 +209,30 @@ def convert_hf_quant_config_format(input_config: dict[str, Any]) -> dict[str, An
"targets": ["Linear"],
}
new_config["config_groups"] = {"group_0": config_group_details}
elif quant_algo_value == "NVFP4_SVD":
# NVFP4 + SVDQuant: NVFP4 weights/activations plus an AWQ-style
# pre_quant_scale and a low-rank residual (svdquant_lora_a/b) stored as
# <module>.pre_quant_scale / <module>.svdquant_lora_{a,b} in the
# safetensors. The config mirrors NVFP4 with a pre_quant_scale flag and
# the LoRA rank so consumers can reconstruct
# ``y = NVFP4_GEMM(x) + (x @ lora_a^T) @ lora_b^T``.
group_size = original_quantization_details.get("group_size", 16)
config_group_details = {
"input_activations": {
"dynamic": False,
"num_bits": 4,
"type": "float",
"group_size": group_size,
},
"weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": group_size},
"has_zero_point": False,
"pre_quant_scale": True,
"targets": ["Linear"],
}
lora_rank = original_quantization_details.get("lora_rank")
if lora_rank is not None:
config_group_details["lora_rank"] = lora_rank
new_config["config_groups"] = {"group_0": config_group_details}
elif quant_algo_value == "MIXED_PRECISION":
quantized_layers = original_quantization_details.get("quantized_layers", {})

Expand Down
Loading
Loading