From 99a04445072973b27f011af14092aca463b53dc0 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 11 Jun 2026 16:50:03 -0700 Subject: [PATCH 01/15] Add Qwen-Image registration to diffusers quantization example Register Qwen/Qwen-Image as a supported model in the diffusers quantization example: - ModelType.QWEN_IMAGE and lazy-imported QwenImagePipeline (so the example still imports on older diffusers). - MODEL_REGISTRY / MODEL_PIPELINE / MODEL_DEFAULTS entries (backbone="transformer", text-to-image calibration dataset). - An actionable ImportError when the installed diffusers lacks Qwen classes, instead of an opaque failure. - filter_func_qwen_image: quantize only transformer_blocks, keeping the first two and last two of the 60 blocks (and everything outside transformer_blocks) in original precision. Enables the plain FP8/NVFP4 export path for Qwen-Image. Core SVDQuant code is unchanged. (Qwen-Image SVDQuant checkpoint work, RLCR round 0 / M1.) Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- .../diffusers/quantization/models_utils.py | 20 ++++++++++++++++ .../quantization/pipeline_manager.py | 6 +++++ examples/diffusers/quantization/utils.py | 24 +++++++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index b59744282f6..ecacc38d407 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -30,11 +30,19 @@ from diffusers import Flux2Pipeline except ImportError: Flux2Pipeline = None + +# Qwen-Image classes were added in a recent diffusers release; import lazily so +# this example still imports on older diffusers versions. +try: + from diffusers import QwenImagePipeline +except ImportError: + QwenImagePipeline = None from utils import ( filter_func_default, filter_func_flux_dev, filter_func_ltx2_vae, filter_func_ltx_video, + filter_func_qwen_image, filter_func_wan_vae, filter_func_wan_video, ) @@ -54,6 +62,7 @@ class ModelType(str, Enum): LTX2 = "ltx-2" WAN22_T2V_14b = "wan2.2-t2v-14b" WAN22_T2V_5b = "wan2.2-t2v-5b" + QWEN_IMAGE = "qwen-image" _FILTER_FUNC_MAP: dict[ModelType, Callable[[str], bool]] = { @@ -63,6 +72,7 @@ class ModelType(str, Enum): ModelType.LTX2: filter_func_ltx_video, ModelType.WAN22_T2V_14b: filter_func_wan_video, ModelType.WAN22_T2V_5b: filter_func_wan_video, + ModelType.QWEN_IMAGE: filter_func_qwen_image, } _VAE_FILTER_FUNC_MAP: dict[tuple[ModelType, str], Callable[[str], bool]] = { @@ -95,6 +105,7 @@ def get_model_filter_func( ModelType.LTX2: "Lightricks/LTX-2", ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers", ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers", + ModelType.QWEN_IMAGE: "Qwen/Qwen-Image", } MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = { @@ -109,6 +120,7 @@ def get_model_filter_func( ModelType.LTX2: None, ModelType.WAN22_T2V_14b: WanPipeline, ModelType.WAN22_T2V_5b: WanPipeline, + ModelType.QWEN_IMAGE: QwenImagePipeline, } # Shared dataset configurations @@ -226,6 +238,14 @@ def get_model_filter_func( ), }, }, + ModelType.QWEN_IMAGE: { + "backbone": "transformer", + "dataset": _SD_PROMPTS_DATASET, + "inference_extra_args": { + "height": 1024, + "width": 1024, + }, + }, } diff --git a/examples/diffusers/quantization/pipeline_manager.py b/examples/diffusers/quantization/pipeline_manager.py index 85e335ba787..f3878c24f49 100644 --- a/examples/diffusers/quantization/pipeline_manager.py +++ b/examples/diffusers/quantization/pipeline_manager.py @@ -99,6 +99,12 @@ def create_pipeline(self) -> Any: pipeline_cls = MODEL_PIPELINE[self.config.model_type] if pipeline_cls is None: + if self.config.model_type == ModelType.QWEN_IMAGE: + raise ImportError( + "Qwen-Image requires a diffusers version that provides " + "QwenImagePipeline. Please upgrade diffusers (e.g. " + "`pip install -U diffusers`) to a release that includes Qwen-Image." + ) raise ValueError( f"Model type {self.config.model_type.value} does not use diffusers pipelines." ) diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index d102e83e068..be3f6276db9 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -111,6 +111,30 @@ def filter_func_wan_video(name: str) -> bool: return pattern.match(name) is not None +# Qwen-Image's transformer has 60 ``transformer_blocks``. The recipe quantizes +# only those blocks while keeping the first two and last two -- and everything +# outside ``transformer_blocks`` -- in original precision. The model-agnostic, +# pre-calibration form of this recipe (deriving the block count from the model) +# lives in quantize.py; this name-only filter covers the plain FP8/NVFP4 path +# for the full 60-block Qwen-Image transformer. +QWEN_IMAGE_NUM_TRANSFORMER_BLOCKS = 60 +_QWEN_IMAGE_BLOCK_RE = re.compile(r"(?:^|\.)transformer_blocks\.(\d+)(?:\.|$)") + + +def filter_func_qwen_image(name: str) -> bool: + """Filter function specifically for Qwen-Image models. + + Returns ``True`` for modules to keep in original precision (quantization + disabled): everything outside ``transformer_blocks``, plus the first two and + last two transformer blocks. + """ + match = _QWEN_IMAGE_BLOCK_RE.search(name) + if match is None: + return True + block_idx = int(match.group(1)) + return block_idx < 2 or block_idx >= QWEN_IMAGE_NUM_TRANSFORMER_BLOCKS - 2 + + def load_calib_prompts( batch_size, calib_data_path: str | Path = "Gustavosta/Stable-Diffusion-Prompts", From e0c7910766cf9460fe618800dc7f00dbf066b143 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 11 Jun 2026 17:16:44 -0700 Subject: [PATCH 02/15] Qwen-Image SVDQuant: block-range recipe, AWQ-style diffusers export, harness Implements the Qwen-Image NVFP4/FP8/SVDQuant diffusers quantization feature (RLCR round 0 / M2-M5), keeping core SVDQuant code unchanged: M2 (recipe): build_block_range_quant_cfg() emits ordered quant_cfg rules (disable-all -> enable *.transformer_blocks.* -> disable first/last-N), applied pre-calibration in Quantizer.get_quant_config so SVDQuant never mutates the excluded blocks. Driven by a MODEL_DEFAULTS["block_range"] entry for Qwen-Image (exclude first 2 / last 2; n derived from the model; n>=first+last+1 enforced). M3 (export): _export_diffusers_checkpoint now promotes quantizer-owned tensors to clean module-level safetensors keys before hide_quantizers_from_state_dict (diffusers path only; the transformers path keeps its postprocess_state_dict rename): input_quantizer._pre_quant_scale -> .pre_quant_scale (AWQ key), weight_quantizer.svdquant_lora_a/b -> .svdquant_lora_a/b. Adds an NVFP4_SVD branch to convert_hf_config (modeled on nvfp4_awq: pre_quant_scale + lora_rank), and process_layer_quant_config now flags SVDQuant with pre_quant_scale=True. This also resolves the diffusers pre_quant_scale TODO for AWQ-style exports. M4 (tests): unit tests for the block-range recipe (first/last-2 exclusion, n>=6 validation) and the NVFP4_SVD HF config conversion. M5 (harness): quantize.py --sanity-image-path (in-memory quantized-inference image, pre-export) + examples/diffusers/quantization/qwen_image_svdquant/ {run_qwen_image_quantization.sh, README.md} (parameterized container/model/ export flow for FP8/NVFP4/SVDQuant). Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- .../diffusers/quantization/models_utils.py | 71 ++++++++++++ examples/diffusers/quantization/quantize.py | 59 +++++++++- .../qwen_image_svdquant/README.md | 108 ++++++++++++++++++ .../run_qwen_image_quantization.sh | 94 +++++++++++++++ modelopt/torch/export/convert_hf_config.py | 35 ++++++ modelopt/torch/export/quant_utils.py | 4 + modelopt/torch/export/unified_export_hf.py | 65 +++++++++++ .../diffusers/test_qwen_block_range_recipe.py | 91 +++++++++++++++ .../export/test_convert_hf_config_svdquant.py | 85 ++++++++++++++ 9 files changed, 611 insertions(+), 1 deletion(-) create mode 100644 examples/diffusers/quantization/qwen_image_svdquant/README.md create mode 100755 examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh create mode 100644 tests/examples/diffusers/test_qwen_block_range_recipe.py create mode 100644 tests/unit/torch/export/test_convert_hf_config_svdquant.py diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index ecacc38d407..05f59b232c9 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -18,6 +18,7 @@ from enum import Enum from typing import Any +import torch from diffusers import ( DiffusionPipeline, FluxPipeline, @@ -245,6 +246,15 @@ def get_model_filter_func( "height": 1024, "width": 1024, }, + # Quantize only ``transformer_blocks``; keep the first 2 and last 2 blocks + # (and everything outside ``transformer_blocks``) in original precision. + # Applied pre-calibration via ``build_block_range_quant_cfg`` so SVDQuant + # never mutates the excluded blocks' weights. + "block_range": { + "exclude_first_n": 2, + "exclude_last_n": 2, + "block_module": "transformer_blocks", + }, }, } @@ -292,3 +302,64 @@ def parse_extra_params( i += 1 return extra_params + + +def build_block_range_quant_cfg( + backbone: torch.nn.Module, + exclude_first_n: int, + exclude_last_n: int, + block_module: str = "transformer_blocks", +) -> list[dict[str, Any]]: + """Build ordered ``quant_cfg`` rules for a transformer-block-only recipe. + + The rules quantize only the linears under ``block_module`` while keeping the + first ``exclude_first_n`` and last ``exclude_last_n`` blocks -- and everything + outside ``block_module`` -- in original precision. + + The rules are meant to be appended to the ``quant_cfg`` list consumed by + ``mtq.quantize`` so the selection is applied BEFORE calibration. This is + required for SVDQuant, whose calibration subtracts a low-rank residual from + the weights of every *enabled* linear: disabling the excluded blocks only + after calibration would leave their weights mutated instead of bit-identical + to the original precision. + + Rules are applied in order with later rules overriding earlier ones: + 1. disable every linear weight/input quantizer, + 2. re-enable only those under ``block_module``, + 3. disable the first/last ``n`` blocks. + + Raises: + ValueError: if the backbone has no ``block_module`` list, or it has fewer + than ``exclude_first_n + exclude_last_n + 1`` blocks. + """ + blocks = getattr(backbone, block_module, None) + if blocks is None or not hasattr(blocks, "__len__"): + raise ValueError( + f"Backbone {type(backbone).__name__} has no '{block_module}' module list; " + "cannot build the transformer-block-range recipe." + ) + num_blocks = len(blocks) + min_blocks = exclude_first_n + exclude_last_n + 1 + if num_blocks < min_blocks: + raise ValueError( + f"'{block_module}' has only {num_blocks} block(s); excluding the first " + f"{exclude_first_n} and last {exclude_last_n} requires at least {min_blocks} blocks." + ) + + excluded = sorted( + set(range(exclude_first_n)) | set(range(num_blocks - exclude_last_n, num_blocks)) + ) + rules: list[dict[str, Any]] = [ + {"quantizer_name": "*weight_quantizer", "cfg": {"enable": False}}, + {"quantizer_name": "*input_quantizer", "cfg": {"enable": False}}, + {"quantizer_name": f"*{block_module}.*weight_quantizer", "cfg": {"enable": True}}, + {"quantizer_name": f"*{block_module}.*input_quantizer", "cfg": {"enable": True}}, + ] + for idx in excluded: + rules.append( + {"quantizer_name": f"*{block_module}.{idx}.*weight_quantizer", "cfg": {"enable": False}} + ) + rules.append( + {"quantizer_name": f"*{block_module}.{idx}.*input_quantizer", "cfg": {"enable": False}} + ) + return rules diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 299a101172a..886eb775eb3 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -32,7 +32,13 @@ set_quant_config_attr, ) from diffusers import DiffusionPipeline -from models_utils import MODEL_DEFAULTS, ModelType, get_model_filter_func, parse_extra_params +from models_utils import ( + MODEL_DEFAULTS, + ModelType, + build_block_range_quant_cfg, + get_model_filter_func, + parse_extra_params, +) from onnx_utils.export import generate_fp8_scales, modelopt_export_sd from pipeline_manager import PipelineManager from quantize_config import ( @@ -163,6 +169,28 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: } ) + # Apply the transformer-block-range recipe (e.g. Qwen-Image) BEFORE + # calibration. This restricts quantization to `transformer_blocks` and + # excludes the first/last N blocks. It must run pre-calibration so that + # SVDQuant does not mutate the weights of the excluded blocks. The recipe + # is format-agnostic (applies to FP8/NVFP4/SVDQuant alike). + block_range = MODEL_DEFAULTS.get(self.model_config.model_type, {}).get("block_range") + if block_range is not None: + recipe_rules = build_block_range_quant_cfg( + backbone, + exclude_first_n=block_range.get("exclude_first_n", 2), + exclude_last_n=block_range.get("exclude_last_n", 2), + block_module=block_range.get("block_module", "transformer_blocks"), + ) + self.logger.info( + f"Applying block-range recipe ({len(recipe_rules)} rules) for " + f"{self.model_config.model_type.value}: quantize only " + f"'{block_range.get('block_module', 'transformer_blocks')}' excluding " + f"first {block_range.get('exclude_first_n', 2)} / last " + f"{block_range.get('exclude_last_n', 2)} blocks." + ) + quant_cfg_list.extend(recipe_rules) + quant_config = {**base_cfg, "quant_cfg": quant_cfg_list} set_quant_config_attr( quant_config, @@ -542,6 +570,15 @@ def create_argument_parser() -> argparse.ArgumentParser: export_group.add_argument( "--restore-from", type=str, help="Path to restore from previous checkpoint" ) + export_group.add_argument( + "--sanity-image-path", + type=str, + default=None, + help="If set, generate one image from the in-memory quantized pipeline (after " + "quantization, before the weights are packed for export) and save it here. This is " + "a quick functional sanity check of quantized inference; it does NOT reload the " + "exported checkpoint.", + ) export_group.add_argument( "--trt-high-precision-dtype", type=str, @@ -681,6 +718,26 @@ def forward_loop(mod): pipeline_manager.print_quant_summary() + # Optional functional sanity check: generate one image from the in-memory + # quantized pipeline. This runs BEFORE export (while weights are still + # fake-quantized and runnable, not yet packed) and does not reload the + # exported checkpoint. + if args.sanity_image_path: + try: + logger.info(f"Generating sanity image to {args.sanity_image_path}") + inference_args = MODEL_DEFAULTS.get(model_type, {}).get("inference_extra_args", {}) + result = pipe( + prompt="A high-quality photo of a cat wearing sunglasses", + num_inference_steps=calib_config.n_steps, + **inference_args, + ) + sanity_path = Path(args.sanity_image_path) + sanity_path.parent.mkdir(parents=True, exist_ok=True) + result.images[0].save(str(sanity_path)) + logger.info("Sanity image saved successfully") + except Exception as sanity_error: # noqa: BLE001 + logger.warning(f"Sanity image generation failed (non-fatal): {sanity_error}") + for backbone_name, backbone in pipeline_manager.iter_backbones(): export_manager.export_onnx( pipe, diff --git a/examples/diffusers/quantization/qwen_image_svdquant/README.md b/examples/diffusers/quantization/qwen_image_svdquant/README.md new file mode 100644 index 00000000000..02daf02ac8c --- /dev/null +++ b/examples/diffusers/quantization/qwen_image_svdquant/README.md @@ -0,0 +1,108 @@ +# Qwen-Image Quantization (FP8 / NVFP4 / NVFP4-SVDQuant) + +A reproducible harness for quantizing [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) +with the diffusers quantization example and exporting HuggingFace checkpoints. + +## What it does + +- Registers Qwen-Image in the diffusers quantization example (`--model qwen-image`). +- **Recipe**: quantizes only the linears under `transformer_blocks`, keeping the + **first 2 and last 2** of the 60 blocks (and everything outside + `transformer_blocks`: text encoder, VAE, embedders, norms, `proj_out`, …) in + original precision. The exclusion is applied **before calibration** so that for + SVDQuant the excluded blocks' weights stay bit-identical to the original. +- Produces three checkpoints: **FP8**, **NVFP4** (max), and **NVFP4 + SVDQuant**. +- Exports a HuggingFace unified checkpoint per component (safetensors + `config.json`). + +### SVDQuant checkpoint format (AWQ-aligned) + +For the SVDQuant export, the quantizer-owned tensors are promoted to clean, +module-level safetensors keys (mirroring how AWQ exports `pre_quant_scale`): + +| Tensor | Safetensors key | +|--------|-----------------| +| AWQ smoothing scale (`input_quantizer._pre_quant_scale`) | `.pre_quant_scale` | +| Low-rank factor A (`weight_quantizer.svdquant_lora_a`) | `.svdquant_lora_a` | +| Low-rank factor B (`weight_quantizer.svdquant_lora_b`) | `.svdquant_lora_b` | + +They are embedded in the component's main safetensors (no sidecar). The +`config.json`'s `quantization_config` follows the `nvfp4_awq` shape with +`"pre_quant_scale": true` plus the SVDQuant `lora_rank`, so a consumer can +reconstruct `y = NVFP4_GEMM(x) + (x @ lora_a^T) @ lora_b^T`. (No in-repo runtime +applies this residual yet; the checkpoint is a documented on-disk artifact.) + +## Layout (kernel-dev defaults) + +| Env var | Default | Purpose | +|---------|---------|---------| +| `KERNEL_DEV_ROOT` | `/lustre/fsw/coreai_dlalgo_modelopt/users/jingyux/kernel-dev` | Root for container/models/output | +| `MODEL_DIR` | `${KERNEL_DEV_ROOT}/models/Qwen-Image` | Local model cache | +| `OUTPUT_DIR` | `${KERNEL_DEV_ROOT}/qwen_image_ckpts` | Exported checkpoints | +| `HF_TOKEN_FILE` | `${KERNEL_DEV_ROOT}/HF_TOKEN.txt` | Hugging Face token file | +| `FORMATS` | `fp8 nvfp4 svdquant` | Formats to run | +| `CALIB_SIZE` / `BATCH_SIZE` / `N_STEPS` / `LOWRANK` | `64 / 2 / 20 / 32` | Calibration knobs | + +## 1. Build the container (once) + +The diffusers example needs a recent `diffusers` (with `QwenImagePipeline`) and +modelopt installed from source. From a base NGC PyTorch image: + +```bash +CONTAINER_DIR=/lustre/fsw/coreai_dlalgo_modelopt/users/jingyux/kernel-dev/container +mkdir -p "${CONTAINER_DIR}" + +# Import a base image to an enroot squashfs (adjust the tag as needed). +enroot import -o "${CONTAINER_DIR}/modelopt-diffusers.sqsh" \ + docker://nvcr.io#nvidia/pytorch:25.04-py3 + +# Install modelopt (from source) + example deps into the container, then re-save. +srun --container-image="${CONTAINER_DIR}/modelopt-diffusers.sqsh" \ + --container-mounts=/lustre:/lustre --container-save="${CONTAINER_DIR}/modelopt-diffusers.sqsh" \ + bash -lc ' + cd /lustre/fsw/coreai_dlalgo_modelopt/users/jingyux/kernel-dev/source/Model-Optimizer && + pip install -e ".[dev]" && + pip install -U "diffusers>=0.35" "transformers>=4.52" accelerate datasets && + python -c "from diffusers import QwenImagePipeline; print(\"QwenImagePipeline OK\")" + ' +``` + +## 2. Run quantization + +Inside the container (or via `srun`), run the harness: + +```bash +srun --gpus=1 \ + --container-image=/lustre/fsw/coreai_dlalgo_modelopt/users/jingyux/kernel-dev/container/modelopt-diffusers.sqsh \ + --container-mounts=/lustre:/lustre \ + bash examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh +``` + +This downloads `Qwen/Qwen-Image` to `MODEL_DIR` (idempotent), then for each +format writes `${OUTPUT_DIR}/qwen-image-/` (HF checkpoint + `sanity.png`). + +Run a single format, or preview the commands without executing: + +```bash +FORMATS=svdquant LOWRANK=32 bash .../run_qwen_image_quantization.sh +DRY_RUN=1 bash .../run_qwen_image_quantization.sh # print planned commands only +``` + +The equivalent direct `quantize.py` invocation for SVDQuant: + +```bash +python examples/diffusers/quantization/quantize.py \ + --model qwen-image --override-model-path "${MODEL_DIR}" --model-dtype BFloat16 \ + --format fp4 --quant-algo svdquant --lowrank 32 \ + --calib-size 64 --batch-size 2 --n-steps 20 \ + --hf-ckpt-dir "${OUTPUT_DIR}/qwen-image-svdquant" \ + --sanity-image-path "${OUTPUT_DIR}/qwen-image-svdquant/sanity.png" +``` + +## Notes + +- `Qwen/Qwen-Image` loads without `trust_remote_code`. +- The transformer is ~20B params; calibration needs a GPU with enough memory + (use `--cpu-offloading` if VRAM-limited). +- The `--sanity-image-path` image is generated from the **in-memory** quantized + pipeline before the weights are packed for export (a functional check of + quantized inference; it does not reload the exported checkpoint). diff --git a/examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh b/examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh new file mode 100755 index 00000000000..1ba6e440e69 --- /dev/null +++ b/examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh @@ -0,0 +1,94 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Reproducible Qwen-Image quantization (FP8 / NVFP4 / NVFP4-SVDQuant) using the +# diffusers quantization example. This script is meant to run INSIDE a container +# that already has NVIDIA Model Optimizer installed from source and a +# Qwen-capable diffusers (see README.md for building the container and the +# Slurm/srun wrapper). +# +# It downloads Qwen/Qwen-Image (idempotently), then for each requested format +# runs `quantize.py` to calibrate the transformer (only `transformer_blocks`, +# excluding the first 2 / last 2 blocks), generate a quantized-inference sanity +# image, and export a HuggingFace checkpoint. +# +# All paths are parameterized via environment variables; the defaults match the +# kernel-dev experiment layout described in README.md. +set -euo pipefail + +# --- Configuration (override via environment) -------------------------------- +KERNEL_DEV_ROOT="${KERNEL_DEV_ROOT:-/lustre/fsw/coreai_dlalgo_modelopt/users/jingyux/kernel-dev}" +MODEL_ID="${MODEL_ID:-Qwen/Qwen-Image}" +MODEL_DIR="${MODEL_DIR:-${KERNEL_DEV_ROOT}/models/Qwen-Image}" +OUTPUT_DIR="${OUTPUT_DIR:-${KERNEL_DEV_ROOT}/qwen_image_ckpts}" +HF_TOKEN_FILE="${HF_TOKEN_FILE:-${KERNEL_DEV_ROOT}/HF_TOKEN.txt}" +# Path to the diffusers quantization example (this script lives one level below it). +QUANT_DIR="${QUANT_DIR:-$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)}" + +# Formats to run: any of {fp8, nvfp4, svdquant}. +FORMATS="${FORMATS:-fp8 nvfp4 svdquant}" + +# Calibration knobs (small defaults for a quick run; raise CALIB_SIZE for quality). +CALIB_SIZE="${CALIB_SIZE:-64}" +BATCH_SIZE="${BATCH_SIZE:-2}" +N_STEPS="${N_STEPS:-20}" +LOWRANK="${LOWRANK:-32}" +MODEL_DTYPE="${MODEL_DTYPE:-BFloat16}" + +# Set DRY_RUN=1 to print the planned commands without executing them. +DRY_RUN="${DRY_RUN:-0}" + +log() { echo "[qwen-image-quant] $*"; } +run() { + log "+ $*" + if [[ "${DRY_RUN}" != "1" ]]; then + "$@" + fi +} + +# --- Hugging Face token ------------------------------------------------------ +if [[ ! -r "${HF_TOKEN_FILE}" ]]; then + echo "ERROR: HF token file not found or not readable: ${HF_TOKEN_FILE}" >&2 + echo " Set HF_TOKEN_FILE to a readable file containing your Hugging Face token." >&2 + exit 1 +fi +HF_TOKEN="$(tr -d '[:space:]' < "${HF_TOKEN_FILE}")" +if [[ -z "${HF_TOKEN}" ]]; then + echo "ERROR: HF token file is empty: ${HF_TOKEN_FILE}" >&2 + exit 1 +fi +export HF_TOKEN +export HUGGING_FACE_HUB_TOKEN="${HF_TOKEN}" + +# --- Download the model (idempotent) ---------------------------------------- +log "Downloading ${MODEL_ID} -> ${MODEL_DIR} (skipped if already present)" +run mkdir -p "${MODEL_DIR}" +run huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_DIR}" --exclude "*.onnx" + +# --- Quantize + export for each format -------------------------------------- +mkdir -p "${OUTPUT_DIR}" +for fmt in ${FORMATS}; do + case "${fmt}" in + fp8) quant_args=(--format fp8 --quant-algo max) ;; + nvfp4) quant_args=(--format fp4 --quant-algo max) ;; + svdquant) quant_args=(--format fp4 --quant-algo svdquant --lowrank "${LOWRANK}") ;; + *) echo "ERROR: unknown format '${fmt}' (expected fp8|nvfp4|svdquant)" >&2; exit 1 ;; + esac + + out="${OUTPUT_DIR}/qwen-image-${fmt}" + log "=== Quantizing Qwen-Image (${fmt}) -> ${out} ===" + run python "${QUANT_DIR}/quantize.py" \ + --model qwen-image \ + --override-model-path "${MODEL_DIR}" \ + --model-dtype "${MODEL_DTYPE}" \ + "${quant_args[@]}" \ + --calib-size "${CALIB_SIZE}" \ + --batch-size "${BATCH_SIZE}" \ + --n-steps "${N_STEPS}" \ + --hf-ckpt-dir "${out}" \ + --sanity-image-path "${out}/sanity.png" + log "Done: ${fmt}. Checkpoint at ${out}, sanity image at ${out}/sanity.png" +done + +log "All requested formats complete. Checkpoints under ${OUTPUT_DIR}" diff --git a/modelopt/torch/export/convert_hf_config.py b/modelopt/torch/export/convert_hf_config.py index 06e5923a30f..70016a54bca 100644 --- a/modelopt/torch/export/convert_hf_config.py +++ b/modelopt/torch/export/convert_hf_config.py @@ -62,6 +62,18 @@ def _quant_algo_to_group_config(quant_algo: str, group_size: int | None = None) return { "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": gs}, } + elif quant_algo == "NVFP4_SVD": + gs = group_size or 16 + return { + "input_activations": { + "dynamic": False, + "num_bits": 4, + "type": "float", + "group_size": gs, + }, + "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": gs}, + "pre_quant_scale": True, + } elif quant_algo in ("NVFP4_AWQ", "W4A8_AWQ"): gs = group_size or 128 return { @@ -196,6 +208,29 @@ def convert_hf_quant_config_format(input_config: dict[str, Any]) -> dict[str, An "targets": ["Linear"], } new_config["config_groups"] = {"group_0": config_group_details} + elif quant_algo_value == "NVFP4_SVD": + # NVFP4 + SVDQuant: NVFP4 weights/activations plus an AWQ-style + # pre_quant_scale and a low-rank residual (svdquant_lora_a/b) stored as + # .pre_quant_scale / .svdquant_lora_{a,b} in the + # safetensors. The config mirrors NVFP4 with a pre_quant_scale flag and + # the LoRA rank so consumers can reconstruct + # ``y = NVFP4_GEMM(x) + (x @ lora_a^T) @ lora_b^T``. + group_size = original_quantization_details.get("group_size", 16) + config_group_details = { + "input_activations": { + "dynamic": False, + "num_bits": 4, + "type": "float", + "group_size": group_size, + }, + "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": group_size}, + "pre_quant_scale": True, + "targets": ["Linear"], + } + lora_rank = original_quantization_details.get("lora_rank") + if lora_rank is not None: + config_group_details["lora_rank"] = lora_rank + new_config["config_groups"] = {"group_0": config_group_details} elif quant_algo_value == "MIXED_PRECISION": quantized_layers = original_quantization_details.get("quantized_layers", {}) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index d8ddf442924..e35ead6253b 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -749,9 +749,13 @@ def process_layer_quant_config(layer_config_dict): "group_size": block_size_value, } elif v == "nvfp4_svdquant": + # SVDQuant builds on the AWQ-style pre_quant_scale smoothing, so its + # config mirrors nvfp4_awq (group_size + pre_quant_scale flag). layer_config = { "quant_algo": "NVFP4_SVD", "group_size": block_size_value, + "has_zero_point": False, + "pre_quant_scale": True, } elif v == "mxfp8": layer_config = { diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index ef5757aa0cb..0a02f751855 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -1013,6 +1013,58 @@ def _fuse_qkv_linears_diffusion( ) +def _detect_svdquant_rank(component: nn.Module) -> int | None: + """Return the SVDQuant low-rank dimension from the first SVDQuant linear, if any. + + ``svdquant_lora_a`` has shape ``(rank, in_features)``, so its first dimension + is the low-rank size. + """ + for _, sub_module in component.named_modules(): + weight_quantizer = getattr(sub_module, "weight_quantizer", None) + lora_a = getattr(weight_quantizer, "svdquant_lora_a", None) + if lora_a is not None: + return int(lora_a.shape[0]) + return None + + +def _promote_quantizer_tensors_to_module(component: nn.Module) -> None: + """Promote quantizer-owned export tensors onto their parent linear module. + + The diffusers export path saves via ``save_pretrained`` inside + :func:`hide_quantizers_from_state_dict` (which deletes the ``weight_quantizer`` + / ``input_quantizer`` submodules) and -- unlike the transformers path -- does + NOT run :func:`postprocess_state_dict`. Without this step the AWQ smoothing + scale and the SVDQuant low-rank factors would be dropped from the exported + checkpoint. We register them as module buffers under clean, AWQ-aligned keys + so they are embedded in the component's main safetensors: + + - ``input_quantizer._pre_quant_scale`` -> ``.pre_quant_scale`` + (the same key the transformers/AWQ path produces via postprocess_state_dict) + - ``weight_quantizer.svdquant_lora_a`` -> ``.svdquant_lora_a`` + - ``weight_quantizer.svdquant_lora_b`` -> ``.svdquant_lora_b`` + + This runs after :func:`_process_quantized_modules` (which leaves these + quantizer buffers in place) and before ``save_pretrained``. + """ + for _, sub_module in component.named_modules(): + if not is_quantlinear(sub_module): + continue + + input_quantizer = getattr(sub_module, "input_quantizer", None) + pre_quant_scale = getattr(input_quantizer, "_pre_quant_scale", None) + if pre_quant_scale is not None and not hasattr(sub_module, "pre_quant_scale"): + sub_module.register_buffer("pre_quant_scale", pre_quant_scale.detach().clone()) + + weight_quantizer = getattr(sub_module, "weight_quantizer", None) + lora_a = getattr(weight_quantizer, "svdquant_lora_a", None) + lora_b = getattr(weight_quantizer, "svdquant_lora_b", None) + if lora_a is not None and lora_b is not None: + if not hasattr(sub_module, "svdquant_lora_a"): + sub_module.register_buffer("svdquant_lora_a", lora_a.detach().clone()) + if not hasattr(sub_module, "svdquant_lora_b"): + sub_module.register_buffer("svdquant_lora_b", lora_b.detach().clone()) + + def _export_diffusers_checkpoint( pipe: Any, dtype: torch.dtype | None, @@ -1091,8 +1143,21 @@ def _export_diffusers_checkpoint( # Step 4: Process quantized modules (convert weights, register scales) _process_quantized_modules(component, component_dtype, is_modelopt_qlora=False) + # Step 4.5: Promote quantizer-owned tensors (AWQ pre_quant_scale and + # SVDQuant LoRA factors) onto the module so they survive + # hide_quantizers_from_state_dict and are embedded in the component's + # main safetensors under clean, AWQ-aligned keys. + _promote_quantizer_tensors_to_module(component) + # Step 5: Build quantization config quant_config = get_quant_config(component, is_modelopt_qlora=False) + if quant_config: + quantization_details = quant_config.get("quantization", {}) + # Record the SVDQuant low-rank size so consumers know the LoRA shape. + if quantization_details.get("quant_algo") == "NVFP4_SVD": + svdquant_rank = _detect_svdquant_rank(component) + if svdquant_rank is not None: + quantization_details["lora_rank"] = svdquant_rank hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None # Step 6: Save the component diff --git a/tests/examples/diffusers/test_qwen_block_range_recipe.py b/tests/examples/diffusers/test_qwen_block_range_recipe.py new file mode 100644 index 00000000000..906dd1fb60e --- /dev/null +++ b/tests/examples/diffusers/test_qwen_block_range_recipe.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the transformer-block-range quantization recipe (e.g. Qwen-Image). + +The recipe must quantize only the linears under ``transformer_blocks`` while +excluding the first/last N blocks, and it must be expressible as pre-calibration +``quant_cfg`` rules (so SVDQuant never mutates the excluded blocks' weights). +""" + +import re +import sys +from pathlib import Path + +import pytest + +# Importing the example module pulls in diffusers/torch/datasets/modelopt. +pytest.importorskip("diffusers") +pytest.importorskip("torch") + +# Make the diffusers quantization example importable. +_EXAMPLE_DIR = Path(__file__).parents[3] / "examples" / "diffusers" / "quantization" +if str(_EXAMPLE_DIR) not in sys.path: + sys.path.insert(0, str(_EXAMPLE_DIR)) + +from models_utils import build_block_range_quant_cfg # noqa: E402 + +_BLOCK_RULE_RE = re.compile(r"\*transformer_blocks\.(\d+)\.\*(?:weight|input)_quantizer") + + +class _StubBackbone: + """Minimal stand-in exposing a ``transformer_blocks`` sequence of length n.""" + + def __init__(self, num_blocks: int): + self.transformer_blocks = list(range(num_blocks)) + + +def _disabled_block_indices(rules): + """Indices of transformer blocks explicitly disabled by per-block rules.""" + indices = set() + for rule in rules: + if rule["cfg"].get("enable") is False: + match = _BLOCK_RULE_RE.fullmatch(rule["quantizer_name"]) + if match: + indices.add(int(match.group(1))) + return indices + + +def test_recipe_excludes_first_and_last_two_blocks(): + rules = build_block_range_quant_cfg(_StubBackbone(6), exclude_first_n=2, exclude_last_n=2) + + # 1. disable-all rules come first (weight + input). + assert rules[0] == {"quantizer_name": "*weight_quantizer", "cfg": {"enable": False}} + assert rules[1] == {"quantizer_name": "*input_quantizer", "cfg": {"enable": False}} + # 2. then enable only the transformer_blocks. + assert {"quantizer_name": "*transformer_blocks.*weight_quantizer", "cfg": {"enable": True}} in rules + assert {"quantizer_name": "*transformer_blocks.*input_quantizer", "cfg": {"enable": True}} in rules + # 3. then disable the first 2 and last 2 of the 6 blocks -> {0, 1, 4, 5}; quantize {2, 3}. + assert _disabled_block_indices(rules) == {0, 1, 4, 5} + + +def test_recipe_block_count_scales_with_model(): + # For a 60-block model (Qwen-Image), exclude {0, 1, 58, 59}; quantize 2..57. + rules = build_block_range_quant_cfg(_StubBackbone(60), exclude_first_n=2, exclude_last_n=2) + assert _disabled_block_indices(rules) == {0, 1, 58, 59} + + +def test_recipe_rejects_too_few_blocks(): + # 2 + 2 exclusion needs at least 5 blocks; 4 blocks must raise a clear error. + with pytest.raises(ValueError, match="at least"): + build_block_range_quant_cfg(_StubBackbone(4), exclude_first_n=2, exclude_last_n=2) + + +def test_recipe_missing_block_module_raises(): + class _NoBlocks: + pass + + with pytest.raises(ValueError, match="transformer_blocks"): + build_block_range_quant_cfg(_NoBlocks(), exclude_first_n=2, exclude_last_n=2) diff --git a/tests/unit/torch/export/test_convert_hf_config_svdquant.py b/tests/unit/torch/export/test_convert_hf_config_svdquant.py new file mode 100644 index 00000000000..d00365703b1 --- /dev/null +++ b/tests/unit/torch/export/test_convert_hf_config_svdquant.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NVFP4_SVD (SVDQuant) HF quantization-config conversion.""" + +from modelopt.torch.export.convert_hf_config import ( + _quant_algo_to_group_config, + convert_hf_quant_config_format, +) + + +def test_nvfp4_svd_group_config_mirrors_awq_with_pre_quant_scale(): + """The NVFP4_SVD config group is NVFP4 weights/activations + a pre_quant_scale flag.""" + group = _quant_algo_to_group_config("NVFP4_SVD", group_size=16) + assert group["pre_quant_scale"] is True + assert group["weights"] == { + "dynamic": False, + "num_bits": 4, + "type": "float", + "group_size": 16, + } + assert group["input_activations"]["num_bits"] == 4 + assert group["input_activations"]["type"] == "float" + assert group["input_activations"]["group_size"] == 16 + + +def test_convert_hf_quant_config_format_nvfp4_svd(): + """A full NVFP4_SVD quantization dict converts to a complete compressed-tensors config.""" + input_config = { + "producer": {"name": "modelopt", "version": "0.0.0"}, + "quantization": { + "quant_algo": "NVFP4_SVD", + "group_size": 16, + "has_zero_point": False, + "pre_quant_scale": True, + "lora_rank": 32, + "exclude_modules": ["transformer_blocks.0.*", "proj_out"], + "kv_cache_quant_algo": None, + }, + } + + out = convert_hf_quant_config_format(input_config) + + # A real config group is emitted (not a bare {"quant_algo": ...} fallback). + assert "config_groups" in out + group = out["config_groups"]["group_0"] + assert group["pre_quant_scale"] is True + assert group["lora_rank"] == 32 + assert group["weights"]["num_bits"] == 4 + assert group["weights"]["type"] == "float" + assert group["weights"]["group_size"] == 16 + assert group["input_activations"]["num_bits"] == 4 + assert group["targets"] == ["Linear"] + + # Top-level metadata is preserved. + assert out["quant_algo"] == "NVFP4_SVD" + assert out["ignore"] == ["transformer_blocks.0.*", "proj_out"] + assert out["quant_method"] == "modelopt" + + +def test_convert_hf_quant_config_format_nvfp4_svd_without_rank(): + """lora_rank is optional; omitting it must not break the conversion.""" + input_config = { + "quantization": { + "quant_algo": "NVFP4_SVD", + "group_size": 16, + "pre_quant_scale": True, + }, + } + out = convert_hf_quant_config_format(input_config) + group = out["config_groups"]["group_0"] + assert "lora_rank" not in group + assert group["pre_quant_scale"] is True From 6718e722ca58897a169157ffc9a963afc524704a Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 11 Jun 2026 17:42:27 -0700 Subject: [PATCH 03/15] Qwen-Image SVDQuant: fix export/recipe defects, add Qwen QKV fusion + tests Addresses the round-0 Codex review (RLCR round 1): Blocking fixes: - convert_hf_config: NVFP4_SVD config groups now keep `has_zero_point: False` (both convert_hf_quant_config_format and _quant_algo_to_group_config); asserted in the unit test. - build_block_range_quant_cfg: minimum is now first+last+2 (>=2 quantized middle blocks; n>=6 for the 2+2 Qwen recipe); recipe test rejects 5/4/3-block models. - quantize.py --sanity-image-path failures are now fatal (re-raise -> non-zero exit) so the harness cannot report success without the image; the harness also verifies sanity.png + safetensors + config.json exist per format. Qwen export enablement: - diffusers_utils.generate_diffusion_dummy_inputs: add a QwenImageTransformer2DModel branch (packed latents [B,(H//2)(W//2),C], encoder_hidden_states_mask, img_shapes, txt_seq_lens, optional guidance, continuous timestep). - unified_export_hf._fuse_qkv_linears_diffusion gains strict=; Qwen QKV fusion now fails hard instead of silently skipping. Promotion buffers now overwrite on re-export. create_pipeline_from gives the same actionable Qwen import error. Tests: - New tests/unit/torch/quantization/test_svdquant_forward_fold.py: LoRA stays on weight_quantizer, forward includes a nonzero residual, fold_weight folds it and drops the buffers (existing test_svdquant_lora_weights left unmodified). Deferred to Round 2 / cluster: tiny Qwen2_5_VL fixture + full diffusers e2e export test (needs a Qwen-capable diffusers + GPU); the actual AC-7 checkpoint run. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- .../diffusers/quantization/models_utils.py | 8 +- .../quantization/pipeline_manager.py | 6 + examples/diffusers/quantization/quantize.py | 7 +- .../run_qwen_image_quantization.sh | 7 ++ modelopt/torch/export/convert_hf_config.py | 2 + modelopt/torch/export/diffusers_utils.py | 40 ++++++ modelopt/torch/export/unified_export_hf.py | 25 ++-- .../diffusers/test_qwen_block_range_recipe.py | 10 +- .../export/test_convert_hf_config_svdquant.py | 2 + .../test_svdquant_forward_fold.py | 116 ++++++++++++++++++ 10 files changed, 209 insertions(+), 14 deletions(-) create mode 100644 tests/unit/torch/quantization/test_svdquant_forward_fold.py diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index 05f59b232c9..e7faf7589c1 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -339,11 +339,15 @@ def build_block_range_quant_cfg( "cannot build the transformer-block-range recipe." ) num_blocks = len(blocks) - min_blocks = exclude_first_n + exclude_last_n + 1 + # Require at least two quantized middle blocks so the recipe actually + # quantizes something (excluding first/last alone could otherwise leave 0-1 + # quantized blocks). For the default 2+2 recipe this means n >= 6. + min_blocks = exclude_first_n + exclude_last_n + 2 if num_blocks < min_blocks: raise ValueError( f"'{block_module}' has only {num_blocks} block(s); excluding the first " - f"{exclude_first_n} and last {exclude_last_n} requires at least {min_blocks} blocks." + f"{exclude_first_n} and last {exclude_last_n} requires at least {min_blocks} blocks " + f"(at least 2 quantized middle blocks)." ) excluded = sorted( diff --git a/examples/diffusers/quantization/pipeline_manager.py b/examples/diffusers/quantization/pipeline_manager.py index f3878c24f49..5496f131ba2 100644 --- a/examples/diffusers/quantization/pipeline_manager.py +++ b/examples/diffusers/quantization/pipeline_manager.py @@ -61,6 +61,12 @@ def create_pipeline_from( """ pipeline_cls = MODEL_PIPELINE[model_type] if pipeline_cls is None: + if model_type == ModelType.QWEN_IMAGE: + raise ImportError( + "Qwen-Image requires a diffusers version that provides " + "QwenImagePipeline. Please upgrade diffusers (e.g. " + "`pip install -U diffusers`) to a release that includes Qwen-Image." + ) raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.") model_id = ( MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 886eb775eb3..57b9db589db 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -735,8 +735,11 @@ def forward_loop(mod): sanity_path.parent.mkdir(parents=True, exist_ok=True) result.images[0].save(str(sanity_path)) logger.info("Sanity image saved successfully") - except Exception as sanity_error: # noqa: BLE001 - logger.warning(f"Sanity image generation failed (non-fatal): {sanity_error}") + except Exception as sanity_error: + # A requested sanity image is a positive success criterion: if it + # cannot be produced, fail loudly rather than reporting success. + logger.error(f"Sanity image generation failed: {sanity_error}", exc_info=True) + raise for backbone_name, backbone in pipeline_manager.iter_backbones(): export_manager.export_onnx( diff --git a/examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh b/examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh index 1ba6e440e69..5e0569c0e40 100755 --- a/examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh +++ b/examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh @@ -88,6 +88,13 @@ for fmt in ${FORMATS}; do --n-steps "${N_STEPS}" \ --hf-ckpt-dir "${out}" \ --sanity-image-path "${out}/sanity.png" + + # Verify the expected artifacts were produced (a missing artifact is a failure). + if [[ "${DRY_RUN}" != "1" ]]; then + [[ -f "${out}/sanity.png" ]] || { echo "ERROR: missing sanity image ${out}/sanity.png" >&2; exit 1; } + find "${out}" -name '*.safetensors' | grep -q . || { echo "ERROR: no safetensors under ${out}" >&2; exit 1; } + find "${out}" -name 'config.json' | grep -q . || { echo "ERROR: no config.json under ${out}" >&2; exit 1; } + fi log "Done: ${fmt}. Checkpoint at ${out}, sanity image at ${out}/sanity.png" done diff --git a/modelopt/torch/export/convert_hf_config.py b/modelopt/torch/export/convert_hf_config.py index 70016a54bca..6f7dedb97c8 100644 --- a/modelopt/torch/export/convert_hf_config.py +++ b/modelopt/torch/export/convert_hf_config.py @@ -72,6 +72,7 @@ def _quant_algo_to_group_config(quant_algo: str, group_size: int | None = None) "group_size": gs, }, "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": gs}, + "has_zero_point": False, "pre_quant_scale": True, } elif quant_algo in ("NVFP4_AWQ", "W4A8_AWQ"): @@ -224,6 +225,7 @@ def convert_hf_quant_config_format(input_config: dict[str, Any]) -> dict[str, An "group_size": group_size, }, "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": group_size}, + "has_zero_point": False, "pre_quant_scale": True, "targets": ["Linear"], } diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py index 9620c97c10e..823f26c8e5d 100644 --- a/modelopt/torch/export/diffusers_utils.py +++ b/modelopt/torch/export/diffusers_utils.py @@ -142,6 +142,11 @@ def _is_model_type(module_path: str, class_name: str, fallback: bool) -> bool: "UNet2DConditionModel", "unet" in model_class_name.lower(), ) + is_qwen = _is_model_type( + "diffusers.models.transformers", + "QwenImageTransformer2DModel", + "qwen" in model_class_name.lower(), + ) cfg = getattr(model, "config", None) @@ -321,6 +326,40 @@ def _wan_inputs() -> dict[str, torch.Tensor]: "return_dict": False, } + def _qwen_inputs() -> dict[str, Any]: + # QwenImageTransformer2DModel does NOT take the standard + # (hidden_states[B,C,H,W], timestep, encoder_hidden_states) triple. It expects + # *packed* latents [B, (H//2)*(W//2), in_channels] plus encoder_hidden_states, + # encoder_hidden_states_mask, img_shapes, txt_seq_lens, and optional guidance. + # Timesteps are continuous in [0, 1] (not the diffusers [0, 1000] scale). + in_channels = getattr(cfg, "in_channels", 64) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 3584) + guidance_embeds = getattr(cfg, "guidance_embeds", False) + + # Small packed spatial grid (already divided by the 2x2 patch size). + packed_h = packed_w = 4 + img_seq_len = packed_h * packed_w + text_seq_len = 8 + + dummy_inputs: dict[str, Any] = { + "hidden_states": torch.randn( + batch_size, img_seq_len, in_channels, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "encoder_hidden_states_mask": torch.ones( + batch_size, text_seq_len, device=device, dtype=torch.int64 + ), + "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size), + "img_shapes": [[(1, packed_h, packed_w)]] * batch_size, + "txt_seq_lens": [text_seq_len] * batch_size, + "return_dict": False, + } + if guidance_embeds: + dummy_inputs["guidance"] = torch.tensor([4.0], device=device, dtype=torch.float32) + return dummy_inputs + def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None: # Try generic transformer handling for other model types # Check if model has common transformer attributes @@ -366,6 +405,7 @@ def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None: ("dit", is_dit, _dit_inputs), ("wan", is_wan, _wan_inputs), ("unet", is_unet, _unet_inputs), + ("qwen", is_qwen, _qwen_inputs), ] for _, matches, build_inputs in model_input_builders: diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 0a02f751855..6d870a726bd 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -960,7 +960,9 @@ def _export_transformers_checkpoint( def _fuse_qkv_linears_diffusion( - model: nn.Module, dummy_forward_fn: Callable[[], None] | None = None + model: nn.Module, + dummy_forward_fn: Callable[[], None] | None = None, + strict: bool = False, ) -> None: """Fuse QKV linear layers that share the same input for diffusion models. @@ -994,6 +996,11 @@ def _fuse_qkv_linears_diffusion( model, dummy_forward_fn, collect_layernorms=False ) except Exception as e: + if strict: + raise RuntimeError( + f"QKV fusion dummy forward failed for {type(model).__name__}; a working " + f"dummy forward is required to export this model correctly. Original error: {e}" + ) from e print(f"Warning: Failed to run dummy forward for QKV fusion: {e}") print("Skipping QKV fusion. Quantization may still work but amax values won't be unified.") return @@ -1050,19 +1057,19 @@ def _promote_quantizer_tensors_to_module(component: nn.Module) -> None: if not is_quantlinear(sub_module): continue + # register_buffer overwrites an existing buffer of the same name, so a + # repeated export refreshes (rather than keeps stale) promoted tensors. input_quantizer = getattr(sub_module, "input_quantizer", None) pre_quant_scale = getattr(input_quantizer, "_pre_quant_scale", None) - if pre_quant_scale is not None and not hasattr(sub_module, "pre_quant_scale"): + if pre_quant_scale is not None: sub_module.register_buffer("pre_quant_scale", pre_quant_scale.detach().clone()) weight_quantizer = getattr(sub_module, "weight_quantizer", None) lora_a = getattr(weight_quantizer, "svdquant_lora_a", None) lora_b = getattr(weight_quantizer, "svdquant_lora_b", None) if lora_a is not None and lora_b is not None: - if not hasattr(sub_module, "svdquant_lora_a"): - sub_module.register_buffer("svdquant_lora_a", lora_a.detach().clone()) - if not hasattr(sub_module, "svdquant_lora_b"): - sub_module.register_buffer("svdquant_lora_b", lora_b.detach().clone()) + sub_module.register_buffer("svdquant_lora_a", lora_a.detach().clone()) + sub_module.register_buffer("svdquant_lora_b", lora_b.detach().clone()) def _export_diffusers_checkpoint( @@ -1138,7 +1145,11 @@ def _export_diffusers_checkpoint( # This is similar to requantize_resmooth_fused_llm_layers but simplified for diffusion # TODO: Add pre_quant_scale handling and FFN fusion for AWQ-style quantization print(f" Running QKV fusion for {component_name}...") - _fuse_qkv_linears_diffusion(component) + # Qwen-Image's packed-latent forward signature is non-standard; if the + # dummy forward fails for it, fail loudly rather than silently skipping + # fusion (which would export un-unified amax values). + is_qwen_component = "qwen" in type(component).__name__.lower() + _fuse_qkv_linears_diffusion(component, strict=is_qwen_component) # Step 4: Process quantized modules (convert weights, register scales) _process_quantized_modules(component, component_dtype, is_modelopt_qlora=False) diff --git a/tests/examples/diffusers/test_qwen_block_range_recipe.py b/tests/examples/diffusers/test_qwen_block_range_recipe.py index 906dd1fb60e..82690ae252c 100644 --- a/tests/examples/diffusers/test_qwen_block_range_recipe.py +++ b/tests/examples/diffusers/test_qwen_block_range_recipe.py @@ -77,10 +77,14 @@ def test_recipe_block_count_scales_with_model(): assert _disabled_block_indices(rules) == {0, 1, 58, 59} -def test_recipe_rejects_too_few_blocks(): - # 2 + 2 exclusion needs at least 5 blocks; 4 blocks must raise a clear error. +@pytest.mark.parametrize("num_blocks", [5, 4, 3]) +def test_recipe_rejects_too_few_blocks(num_blocks): + # A 2 + 2 exclusion needs at least 6 blocks (>= 2 quantized middle blocks). + # A 5-block model leaves only 1 middle block and must be rejected too. with pytest.raises(ValueError, match="at least"): - build_block_range_quant_cfg(_StubBackbone(4), exclude_first_n=2, exclude_last_n=2) + build_block_range_quant_cfg( + _StubBackbone(num_blocks), exclude_first_n=2, exclude_last_n=2 + ) def test_recipe_missing_block_module_raises(): diff --git a/tests/unit/torch/export/test_convert_hf_config_svdquant.py b/tests/unit/torch/export/test_convert_hf_config_svdquant.py index d00365703b1..13a22ab22d5 100644 --- a/tests/unit/torch/export/test_convert_hf_config_svdquant.py +++ b/tests/unit/torch/export/test_convert_hf_config_svdquant.py @@ -25,6 +25,7 @@ def test_nvfp4_svd_group_config_mirrors_awq_with_pre_quant_scale(): """The NVFP4_SVD config group is NVFP4 weights/activations + a pre_quant_scale flag.""" group = _quant_algo_to_group_config("NVFP4_SVD", group_size=16) assert group["pre_quant_scale"] is True + assert group["has_zero_point"] is False assert group["weights"] == { "dynamic": False, "num_bits": 4, @@ -57,6 +58,7 @@ def test_convert_hf_quant_config_format_nvfp4_svd(): assert "config_groups" in out group = out["config_groups"]["group_0"] assert group["pre_quant_scale"] is True + assert group["has_zero_point"] is False assert group["lora_rank"] == 32 assert group["weights"]["num_bits"] == 4 assert group["weights"]["type"] == "float" diff --git a/tests/unit/torch/quantization/test_svdquant_forward_fold.py b/tests/unit/torch/quantization/test_svdquant_forward_fold.py new file mode 100644 index 00000000000..f6a8a83f4b4 --- /dev/null +++ b/tests/unit/torch/quantization/test_svdquant_forward_fold.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SVDQuant forward / fold coverage. + +These tests protect the invariants the diffusers SVDQuant export relies on +(DEC-5: the LoRA factors stay on the ``weight_quantizer`` in the live model; the +export layer promotes them). They complement (and do not modify) the existing +``test_calib.py::test_svdquant_lora_weights``. +""" + +from functools import partial + +import torch +import torch.nn as nn + +import modelopt.torch.quantization as mtq + + +class _SVDMLP(nn.Module): + def __init__(self, dim: int = 64): + super().__init__() + self.fc1 = nn.Linear(dim, dim) + self.fc2 = nn.Linear(dim, dim) + + def forward(self, x): + return self.fc2(torch.relu(self.fc1(x))) + + +def _forward_loop(model, dataloader): + for batch in dataloader: + model(batch) + + +def _quantize_svdquant(dim: int = 64) -> nn.Module: + model = _SVDMLP(dim) + quant_config = mtq.INT8_SMOOTHQUANT_CFG.copy() + quant_config["algorithm"] = {"method": "svdquant", "lowrank": 8} + data = [torch.randn(2, dim) for _ in range(2)] + mtq.quantize(model, quant_config, partial(_forward_loop, dataloader=data)) + return model + + +def _quantized_linears(model: nn.Module): + return [m for m in model.modules() if isinstance(m, torch.nn.Linear)] + + +def test_svdquant_lora_stays_on_weight_quantizer(): + """DEC-5: LoRA lives on the quantizer, not the module (export promotes it).""" + model = _quantize_svdquant() + linears = _quantized_linears(model) + assert linears + for module in linears: + wq = module.weight_quantizer + assert wq.svdquant_lora_a is not None + assert wq.svdquant_lora_b is not None + # Not refactored onto the module. + assert not hasattr(module, "svdquant_lora_a") + assert not hasattr(module, "svdquant_lora_b") + + +def test_svdquant_forward_includes_nonzero_residual(): + """The forward output includes a nonzero low-rank residual term.""" + model = _quantize_svdquant() + for module in _quantized_linears(model): + x = torch.randn(2, module.in_features) + + residual = module._compute_lora_residual(x) + assert residual is not None + assert torch.count_nonzero(residual) > 0 + + full = module(x) + + # Temporarily drop the LoRA buffers to get the base (no-residual) output. + wq = module.weight_quantizer + lora_a = wq._svdquant_lora_a + lora_b = wq._svdquant_lora_b + delattr(wq, "_svdquant_lora_a") + delattr(wq, "_svdquant_lora_b") + try: + base = module(x) + finally: + wq.register_buffer("_svdquant_lora_a", lora_a) + wq.register_buffer("_svdquant_lora_b", lora_b) + + # The residual measurably changes the forward output. + assert not torch.allclose(full, base) + + +def test_svdquant_fold_weight_removes_buffers_and_changes_weight(): + """fold_weight() folds the residual into the weight and drops the buffers.""" + model = _quantize_svdquant() + for module in _quantized_linears(model): + wq = module.weight_quantizer + assert hasattr(wq, "_svdquant_lora_a") + assert hasattr(wq, "_svdquant_lora_b") + + weight_before = module.weight.detach().clone() + module.fold_weight() + + assert not hasattr(wq, "_svdquant_lora_a") + assert not hasattr(wq, "_svdquant_lora_b") + # Folding (quantized weight + low-rank residual) changes the stored weight. + assert not torch.allclose(module.weight, weight_before) From 8e3b3ed12d48f1a9f667b0fa55d982cf4fc93d13 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 11 Jun 2026 17:58:44 -0700 Subject: [PATCH 04/15] Qwen-Image SVDQuant: add export/fusion/promotion tests; drop plan terminology Round 2 (addresses round-1 Codex review: the round-1 code had no direct test coverage). Adds tests/unit/torch/export/test_diffusers_qwen_export.py: - Qwen dummy inputs: generate_diffusion_dummy_inputs builds the expected keys for a real tiny QwenImageTransformer2DModel, and the generated dummy forward runs on it (this is what catches any wrong shape/kwarg in the dummy-input builder). - Strict fusion: _fuse_qkv_linears_diffusion(strict=True) re-raises on a failing dummy forward; strict=False does not. - Structural export: _promote_quantizer_tensors_to_module promotes SVDQuant LoRA + pre_quant_scale to clean module keys that survive hide_quantizers_from_state_dict (promoted .svdquant_lora_a/b + .pre_quant_scale present; weight_quantizer / input_quantizer keys absent), on a calibrated tiny SVDQuant MLP. Also removes plan/workflow terminology (DEC-5, "pre-calibration") from source and test comments per the plan code-style note. Still pending (Round 3 / cluster): the full tiny Qwen pipeline fixture + e2e subprocess export test (needs diffusers' tokenizer/text-encoder construction and a GPU) and the AC-7 cluster run. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- .../diffusers/quantization/models_utils.py | 2 +- examples/diffusers/quantization/quantize.py | 2 +- examples/diffusers/quantization/utils.py | 2 +- .../diffusers/test_qwen_block_range_recipe.py | 4 +- .../export/test_diffusers_qwen_export.py | 134 ++++++++++++++++++ .../test_svdquant_forward_fold.py | 8 +- 6 files changed, 143 insertions(+), 9 deletions(-) create mode 100644 tests/unit/torch/export/test_diffusers_qwen_export.py diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index e7faf7589c1..eb2ff1f6039 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -248,7 +248,7 @@ def get_model_filter_func( }, # Quantize only ``transformer_blocks``; keep the first 2 and last 2 blocks # (and everything outside ``transformer_blocks``) in original precision. - # Applied pre-calibration via ``build_block_range_quant_cfg`` so SVDQuant + # Applied before calibration via ``build_block_range_quant_cfg`` so SVDQuant # never mutates the excluded blocks' weights. "block_range": { "exclude_first_n": 2, diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 57b9db589db..e81dc51a66b 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -171,7 +171,7 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: # Apply the transformer-block-range recipe (e.g. Qwen-Image) BEFORE # calibration. This restricts quantization to `transformer_blocks` and - # excludes the first/last N blocks. It must run pre-calibration so that + # excludes the first/last N blocks. It must run before calibration so that # SVDQuant does not mutate the weights of the excluded blocks. The recipe # is format-agnostic (applies to FP8/NVFP4/SVDQuant alike). block_range = MODEL_DEFAULTS.get(self.model_config.model_type, {}).get("block_range") diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index be3f6276db9..fd57c328378 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -114,7 +114,7 @@ def filter_func_wan_video(name: str) -> bool: # Qwen-Image's transformer has 60 ``transformer_blocks``. The recipe quantizes # only those blocks while keeping the first two and last two -- and everything # outside ``transformer_blocks`` -- in original precision. The model-agnostic, -# pre-calibration form of this recipe (deriving the block count from the model) +# before-calibration form of this recipe (deriving the block count from the model) # lives in quantize.py; this name-only filter covers the plain FP8/NVFP4 path # for the full 60-block Qwen-Image transformer. QWEN_IMAGE_NUM_TRANSFORMER_BLOCKS = 60 diff --git a/tests/examples/diffusers/test_qwen_block_range_recipe.py b/tests/examples/diffusers/test_qwen_block_range_recipe.py index 82690ae252c..8214ec1c124 100644 --- a/tests/examples/diffusers/test_qwen_block_range_recipe.py +++ b/tests/examples/diffusers/test_qwen_block_range_recipe.py @@ -16,8 +16,8 @@ """Unit tests for the transformer-block-range quantization recipe (e.g. Qwen-Image). The recipe must quantize only the linears under ``transformer_blocks`` while -excluding the first/last N blocks, and it must be expressible as pre-calibration -``quant_cfg`` rules (so SVDQuant never mutates the excluded blocks' weights). +excluding the first/last N blocks, and it must be expressible as ``quant_cfg`` +rules applied before calibration (so SVDQuant never mutates the excluded blocks). """ import re diff --git a/tests/unit/torch/export/test_diffusers_qwen_export.py b/tests/unit/torch/export/test_diffusers_qwen_export.py new file mode 100644 index 00000000000..6577a8800f0 --- /dev/null +++ b/tests/unit/torch/export/test_diffusers_qwen_export.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the Qwen-Image SVDQuant diffusers export path. + +Covers the three pieces added for Qwen support: +- the Qwen branch of ``generate_diffusion_dummy_inputs`` (validated by running the + dummy forward on a real tiny ``QwenImageTransformer2DModel``), +- the strict-failure mode of ``_fuse_qkv_linears_diffusion``, +- promotion of quantizer-owned SVDQuant tensors to clean module-level keys that + survive ``hide_quantizers_from_state_dict``. +""" + +from functools import partial + +import pytest +import torch +import torch.nn as nn + +import modelopt.torch.quantization as mtq +from modelopt.torch.export.diffusers_utils import ( + generate_diffusion_dummy_forward_fn, + generate_diffusion_dummy_inputs, + hide_quantizers_from_state_dict, +) +from modelopt.torch.export.unified_export_hf import ( + _fuse_qkv_linears_diffusion, + _promote_quantizer_tensors_to_module, +) + + +class _MLP(nn.Module): + def __init__(self, dim: int = 64): + super().__init__() + self.fc1 = nn.Linear(dim, dim) + self.fc2 = nn.Linear(dim, dim) + + def forward(self, x): + return self.fc2(torch.relu(self.fc1(x))) + + +def _forward_loop(model, data): + for batch in data: + model(batch) + + +def _quantize(model: nn.Module, algorithm=None, dim: int = 64) -> nn.Module: + cfg = mtq.INT8_SMOOTHQUANT_CFG.copy() + if algorithm is not None: + cfg["algorithm"] = algorithm + data = [torch.randn(2, dim) for _ in range(2)] + mtq.quantize(model, cfg, partial(_forward_loop, data=data)) + return model + + +def test_qwen_dummy_inputs_drive_real_transformer_forward(): + """The Qwen dummy inputs must actually drive a real tiny Qwen transformer.""" + pytest.importorskip("diffusers") + from _test_utils.torch.diffusers_models import get_tiny_qwen_image_transformer + + transformer = get_tiny_qwen_image_transformer().to("cpu", torch.float32).eval() + + inputs = generate_diffusion_dummy_inputs(transformer, torch.device("cpu"), torch.float32) + assert inputs is not None + for key in ( + "hidden_states", + "encoder_hidden_states", + "encoder_hidden_states_mask", + "img_shapes", + "txt_seq_lens", + ): + assert key in inputs, f"missing Qwen dummy input '{key}'" + assert inputs["hidden_states"].shape[-1] == transformer.config.in_channels + assert inputs["encoder_hidden_states"].shape[-1] == transformer.config.joint_attention_dim + + # Strongest check: the generated dummy inputs run through the real model. + with torch.no_grad(): + generate_diffusion_dummy_forward_fn(transformer)() + + +def test_qwen_qkv_fusion_strict_raises_on_failed_dummy_forward(): + """strict=True turns a dummy-forward failure into a hard error; strict=False does not.""" + model = _quantize(_MLP()) + + def _boom(): + raise RuntimeError("dummy forward failed") + + with pytest.raises(RuntimeError): + _fuse_qkv_linears_diffusion(model, dummy_forward_fn=_boom, strict=True) + + # Non-strict path warns and returns without raising. + _fuse_qkv_linears_diffusion(model, dummy_forward_fn=_boom, strict=False) + + +def test_svdquant_promotion_survives_hide_quantizers(): + """Promoted LoRA + pre_quant_scale land on the module under clean keys and + survive ``hide_quantizers_from_state_dict`` (which strips the quantizers).""" + model = _quantize(_MLP(), algorithm={"method": "svdquant", "lowrank": 8}) + + _promote_quantizer_tensors_to_module(model) + + linears = [m for m in model.modules() if isinstance(m, torch.nn.Linear)] + assert linears + for module in linears: + assert hasattr(module, "svdquant_lora_a") + assert hasattr(module, "svdquant_lora_b") + # INT8_SMOOTHQUANT produces a pre_quant_scale that is promoted too. + assert hasattr(module, "pre_quant_scale") + # Rank-consistent shapes: lora_a [rank, in], lora_b [out, rank]. + assert module.svdquant_lora_a.shape[1] == module.in_features + assert module.svdquant_lora_b.shape[0] == module.out_features + assert module.svdquant_lora_a.shape[0] == module.svdquant_lora_b.shape[1] + + with hide_quantizers_from_state_dict(model): + keys = list(model.state_dict().keys()) + + assert any(k.endswith(".svdquant_lora_a") for k in keys) + assert any(k.endswith(".svdquant_lora_b") for k in keys) + assert any(k.endswith(".pre_quant_scale") for k in keys) + # Clean keys only: no quantizer-prefixed keys remain once quantizers are hidden. + assert not any("weight_quantizer" in k for k in keys) + assert not any("input_quantizer" in k for k in keys) diff --git a/tests/unit/torch/quantization/test_svdquant_forward_fold.py b/tests/unit/torch/quantization/test_svdquant_forward_fold.py index f6a8a83f4b4..085414f93a3 100644 --- a/tests/unit/torch/quantization/test_svdquant_forward_fold.py +++ b/tests/unit/torch/quantization/test_svdquant_forward_fold.py @@ -15,9 +15,9 @@ """SVDQuant forward / fold coverage. -These tests protect the invariants the diffusers SVDQuant export relies on -(DEC-5: the LoRA factors stay on the ``weight_quantizer`` in the live model; the -export layer promotes them). They complement (and do not modify) the existing +These tests protect the invariants the diffusers SVDQuant export relies on: the +LoRA factors stay on the ``weight_quantizer`` in the live model and the export +layer promotes them. They complement (and do not modify) the existing ``test_calib.py::test_svdquant_lora_weights``. """ @@ -58,7 +58,7 @@ def _quantized_linears(model: nn.Module): def test_svdquant_lora_stays_on_weight_quantizer(): - """DEC-5: LoRA lives on the quantizer, not the module (export promotes it).""" + """LoRA lives on the quantizer, not the module (the export layer promotes it).""" model = _quantize_svdquant() linears = _quantized_linears(model) assert linears From 027a5e200a362f554fe1d5fcb90348c50b7de2b6 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 11 Jun 2026 18:15:05 -0700 Subject: [PATCH 05/15] Qwen-Image SVDQuant: offline tiny Qwen fixture + e2e export test Round 3 (addresses round-2 Codex review): - Fix the tiny Qwen-Image pipeline fixture (tests/_test_utils/torch/diffusers_models.py): build the Qwen2.5-VL text encoder inline from a tiny Qwen2_5_VLConfig (no Hub model load; the previous hf-internal-testing/...Qwen2_5_VL id does not exist), load the tokenizer from the tiny ...Qwen2VL id diffusers' own fast test uses, build the transformer with num_layers=6 (so the corrected first-2/last-2 block-range recipe, which needs >=6 blocks, is valid) and joint_attention_dim=16 matching the text encoder hidden_size, and a z_dim=4 VAE. Mirrors diffusers' QwenImagePipelineFastTests.get_dummy_components. - Add Qwen FP8 / NVFP4 / NVFP4-SVDQuant cases to test_export_diffusers_hf_ckpt.py using the tiny fixture. The test opens transformer/config.json and the exported safetensors and asserts: quant_method=modelopt; no weight_quantizer / input_quantizer._amax keys; for SVDQuant, promoted .svdquant_lora_a/b + .pre_quant_scale keys, config group pre_quant_scale/has_zero_point/ lora_rank, and non-empty ignore (excluded blocks); for plain formats, weight_scale. GPU/diffusers skip-guarded. - Drop remaining workflow terminology (Step 4.5, before-calibration) from the comments I introduced. Still cluster-only (no GPU here): executing these tests and the AC-7 harness run. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/utils.py | 2 +- modelopt/torch/export/unified_export_hf.py | 4 +- tests/_test_utils/torch/diffusers_models.py | 87 +++++++++++------ .../test_export_diffusers_hf_ckpt.py | 93 +++++++++++++++++++ 4 files changed, 153 insertions(+), 33 deletions(-) diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index fd57c328378..c3cfdcd5cdd 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -114,7 +114,7 @@ def filter_func_wan_video(name: str) -> bool: # Qwen-Image's transformer has 60 ``transformer_blocks``. The recipe quantizes # only those blocks while keeping the first two and last two -- and everything # outside ``transformer_blocks`` -- in original precision. The model-agnostic, -# before-calibration form of this recipe (deriving the block count from the model) +# config-driven form of this recipe (deriving the block count from the model) # lives in quantize.py; this name-only filter covers the plain FP8/NVFP4 path # for the full 60-block Qwen-Image transformer. QWEN_IMAGE_NUM_TRANSFORMER_BLOCKS = 60 diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 6d870a726bd..6f08b653fa4 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -1154,8 +1154,8 @@ def _export_diffusers_checkpoint( # Step 4: Process quantized modules (convert weights, register scales) _process_quantized_modules(component, component_dtype, is_modelopt_qlora=False) - # Step 4.5: Promote quantizer-owned tensors (AWQ pre_quant_scale and - # SVDQuant LoRA factors) onto the module so they survive + # Promote quantizer-owned tensors (AWQ pre_quant_scale and SVDQuant + # LoRA factors) onto the module so they survive # hide_quantizers_from_state_dict and are embedded in the component's # main safetensors under clean, AWQ-aligned keys. _promote_quantizer_tensors_to_module(component) diff --git a/tests/_test_utils/torch/diffusers_models.py b/tests/_test_utils/torch/diffusers_models.py index a42ebdd8ffb..3a5e4adc982 100644 --- a/tests/_test_utils/torch/diffusers_models.py +++ b/tests/_test_utils/torch/diffusers_models.py @@ -312,20 +312,16 @@ def get_tiny_qwen_image_vae(**config_kwargs): def create_tiny_qwen_image_pipeline_dir(tmp_path: Path) -> Path: - """Create and save a tiny Qwen-Image pipeline to a directory (SKETCH). - - Mirrors ``create_tiny_wan22_pipeline_dir``. Needs in-container validation; the - fragile piece is the Qwen2.5-VL text encoder. This prefers a tiny-random HF model - (as Wan uses ``hf-internal-testing/tiny-random-t5``); if that id drifts or the - config schema differs across transformers versions, copy the text-encoder - construction from diffusers' own QwenImage fast test - (``tests/pipelines/qwenimage/test_qwenimage.py``). - - For the DMD2 mock-data training path the transformer consumes the dataloader's - embeddings rather than the text encoder, so the bundled tiny text encoder only - needs to load; its hidden size is intentionally decoupled from the transformer's - ``joint_attention_dim`` (set the dataloader's ``text_embed_dim`` to match instead). - The saved dir loads with ``QwenImagePipeline.from_pretrained(path)``. + """Create and save a tiny, (mostly) offline Qwen-Image pipeline to a directory. + + Mirrors diffusers' ``QwenImagePipelineFastTests.get_dummy_components``: the + Qwen2.5-VL text encoder is built inline from a tiny ``Qwen2_5_VLConfig`` (no Hub + model load); only the tokenizer is fetched from the tiny ``Qwen2VL`` test repo + (building a Qwen tokenizer fully offline is impractical). The transformer uses + ``num_layers=6`` so the first-2/last-2 block-range recipe is valid, and its + ``joint_attention_dim`` matches the text encoder ``hidden_size`` (16) so the + pipeline runs end-to-end during quantization calibration. The saved dir loads + with ``QwenImagePipeline.from_pretrained(path)``. """ if QwenImageTransformer2DModel is None or AutoencoderKLQwenImage is None: pytest.skip("QwenImage diffusers classes not available in this diffusers version.") @@ -333,27 +329,58 @@ def create_tiny_qwen_image_pipeline_dir(tmp_path: Path) -> Path: transformers = pytest.importorskip("transformers") - # Tiny Qwen2.5-VL text encoder + matching Qwen2 tokenizer (loaded, but bypassed - # during DMD2 mock-data training). - # NOTE (validated 2026-06-06): the hf-internal-testing id below does NOT exist on the - # Hub, so this fixture currently skips. To make the recipe e2e runnable in CI, - # construct the encoder inline from a tiny ``Qwen2_5_VLConfig`` (nested text + vision - # config) — mirror diffusers' ``QwenImagePipelineFastTests.get_dummy_components`` in - # ``tests/pipelines/qwenimage/test_qwenimage.py``. - tiny_id = "hf-internal-testing/tiny-random-Qwen2_5_VLForConditionalGeneration" + # Tiny Qwen2.5-VL text encoder, built offline from a tiny config (no Hub model + # load), mirroring diffusers' QwenImagePipelineFastTests.get_dummy_components. + qwen_vl_config = transformers.Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + hidden_size=16, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + ) + text_encoder = transformers.Qwen2_5_VLForConditionalGeneration(qwen_vl_config).eval() + + # The Qwen tokenizer cannot be built fully offline; load the tiny one diffusers' + # own fast test uses (this id exists, unlike the Qwen2.5-VL one previously tried). try: - text_encoder = transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(tiny_id) - tokenizer = transformers.Qwen2Tokenizer.from_pretrained(tiny_id) - except Exception as exc: # pragma: no cover - depends on hub availability / version - pytest.skip( - f"tiny Qwen2.5-VL text encoder unavailable ({exc}); " - "copy the fixture from diffusers' QwenImage fast test" + tokenizer = transformers.Qwen2Tokenizer.from_pretrained( + "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" ) + except Exception as exc: # pragma: no cover - depends on hub availability + pytest.skip(f"tiny Qwen tokenizer unavailable ({exc})") torch.manual_seed(0) - transformer = get_tiny_qwen_image_transformer() + # num_layers=6 so the first-2/last-2 block-range recipe (which needs >=6 blocks) + # is valid; joint_attention_dim must match the text encoder hidden_size (16). + transformer = get_tiny_qwen_image_transformer( + num_layers=6, + in_channels=16, + out_channels=4, + joint_attention_dim=16, + num_attention_heads=3, + ) torch.manual_seed(0) - vae = get_tiny_qwen_image_vae() + vae = get_tiny_qwen_image_vae(z_dim=4, latents_mean=[0.0] * 4, latents_std=[1.0] * 4) scheduler = FlowMatchEulerDiscreteScheduler( base_image_seq_len=256, diff --git a/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py b/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py index 88821bbf8f7..6ea8d22b51e 100644 --- a/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py +++ b/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from pathlib import Path from typing import NamedTuple @@ -130,6 +131,98 @@ def test_diffusers_hf_ckpt_export(model: DiffuserHfExportModel, tmp_path: Path) assert len(weight_files) > 0, f"No weight files (.safetensors or .bin) found in {hf_ckpt_dir}" +class QwenHfExportModel(NamedTuple): + format_type: str + quant_algo: str + is_svdquant: bool + + def quantize_and_export_hf(self, tiny_qwen_image_path: str, tmp_path: Path) -> Path: + hf_ckpt_dir = tmp_path / f"qwen_{self.format_type}_{self.quant_algo}_hf_ckpt" + cmd_args = [ + "python", + "quantize.py", + "--model", + "qwen-image", + "--override-model-path", + str(tiny_qwen_image_path), + "--format", + self.format_type, + "--quant-algo", + self.quant_algo, + "--collect-method", + "default", + "--model-dtype", + "BFloat16", + "--trt-high-precision-dtype", + "BFloat16", + "--calib-size", + "2", + "--batch-size", + "1", + "--n-steps", + "2", + "--hf-ckpt-dir", + str(hf_ckpt_dir), + ] + if self.is_svdquant: + cmd_args.extend(["--lowrank", "8"]) + run_example_command(cmd_args, "diffusers/quantization") + return hf_ckpt_dir + + +@pytest.mark.parametrize( + "qwen_model", + [ + pytest.param(QwenHfExportModel("fp8", "max", False), marks=minimum_sm(89)), + pytest.param(QwenHfExportModel("fp4", "max", False), marks=minimum_sm(89)), + pytest.param(QwenHfExportModel("fp4", "svdquant", True), marks=minimum_sm(89)), + ], + ids=["qwen_fp8_max", "qwen_nvfp4_max", "qwen_nvfp4_svdquant"], +) +def test_qwen_image_hf_ckpt_export( + qwen_model: QwenHfExportModel, tiny_qwen_image_path: str, tmp_path: Path +) -> None: + from safetensors import safe_open + + hf_ckpt_dir = qwen_model.quantize_and_export_hf(tiny_qwen_image_path, tmp_path) + assert hf_ckpt_dir.exists(), f"HF checkpoint directory was not created: {hf_ckpt_dir}" + + # The transformer is the quantized component. + transformer_dir = hf_ckpt_dir / "transformer" + config_path = transformer_dir / "config.json" + assert config_path.exists(), f"no transformer/config.json in {hf_ckpt_dir}" + quant_config = json.loads(config_path.read_text()).get("quantization_config") + assert quant_config is not None, "missing quantization_config" + assert quant_config.get("quant_method") == "modelopt" + + keys: set[str] = set() + safetensors_files = sorted(transformer_dir.rglob("*.safetensors")) + assert safetensors_files, f"no safetensors in {transformer_dir}" + for path in safetensors_files: + with safe_open(str(path), framework="pt") as handle: + keys.update(handle.keys()) + + # No live quantizer state should leak into the exported checkpoint. + assert not any("weight_quantizer" in k for k in keys), "quantizer keys leaked into export" + assert not any("input_quantizer._amax" in k for k in keys) + + if qwen_model.is_svdquant: + # Promoted SVDQuant tensors under clean module-level keys. + assert any(k.endswith(".svdquant_lora_a") for k in keys) + assert any(k.endswith(".svdquant_lora_b") for k in keys) + assert any(k.endswith(".pre_quant_scale") for k in keys) + # config schema modeled on nvfp4_awq. + group = next(iter(quant_config.get("config_groups", {}).values()), {}) + assert group.get("pre_quant_scale") is True + assert group.get("has_zero_point") is False + assert "lora_rank" in group + # Excluded first-2/last-2 blocks + non-transformer_blocks modules are ignored. + assert quant_config.get("ignore"), "expected excluded modules in 'ignore'" + else: + # Plain FP8/NVFP4 expose weight scales. + assert any(k.endswith(".weight_scale") for k in keys), "no weight_scale in export" + + class Wan22HfExportModel(NamedTuple): model: str backbone: str | None From a51105814724655bfa645e8a0f95dab96ef2c5c3 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 11 Jun 2026 18:28:49 -0700 Subject: [PATCH 06/15] Qwen-Image SVDQuant: offline tokenizer, stronger export test, drop Step comments Round 4 (addresses round-3 Codex review): - Offline tiny Qwen tokenizer: _build_local_qwen2_tokenizer builds a deterministic byte-level Qwen2 tokenizer locally (GPT-2 byte->unicode vocab + Qwen specials, empty merges) instead of a Hub load; removes the tokenizer-unavailable skip path. - Strengthen test_qwen_image_hf_ckpt_export: assert equal module-prefix sets for .svdquant_lora_a/.svdquant_lora_b/.pre_quant_scale; promoted linears are a subset of weight-scaled linears; only the middle blocks {2,3} of 6 are quantized (first-2/ last-2 excluded); lora_a=[rank,in]/lora_b=[out,rank] with rank == --lowrank (8); NVFP4 weight_scale_2 present; exact config (quant_algo=NVFP4_SVD, lora_rank=8, pre_quant_scale=True, has_zero_point=False, non-empty ignore). - Remove the remaining "Step N:" workflow comments from unified_export_hf.py (the round-3 "grep clean" claim was wrong; verified clean across the whole file). Still cluster-only (no GPU/torch/diffusers here): executing these tests and the AC-7 harness run. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- modelopt/torch/export/unified_export_hf.py | 22 +++--- tests/_test_utils/torch/diffusers_models.py | 39 ++++++++--- .../test_export_diffusers_hf_ckpt.py | 69 ++++++++++++++++--- 3 files changed, 101 insertions(+), 29 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 6f08b653fa4..164f2a0e1fe 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -1100,7 +1100,7 @@ def _export_diffusers_checkpoint( """ export_dir = Path(export_dir) - # Step 1: Get all pipeline components (nn.Module, tokenizers, schedulers, etc.) + # Get all pipeline components (nn.Module, tokenizers, schedulers, etc.) all_components = get_diffusion_components(pipe, components) if not all_components: @@ -1122,7 +1122,7 @@ def _export_diffusers_checkpoint( except Exception: is_diffusers_pipe = False - # Step 3: Export each nn.Module component with quantization handling + # Export each nn.Module component with quantization handling for component_name, component in module_components.items(): is_quantized = has_quantized_modules(component) status = "quantized" if is_quantized else "non-quantized" @@ -1141,7 +1141,7 @@ def _export_diffusers_checkpoint( component_dtype = dtype if dtype is not None else infer_dtype_from_model(component) if is_quantized: - # Step 3.5: Fuse QKV linears that share the same input (unify amax values) + # Fuse QKV linears that share the same input (unify amax values) # This is similar to requantize_resmooth_fused_llm_layers but simplified for diffusion # TODO: Add pre_quant_scale handling and FFN fusion for AWQ-style quantization print(f" Running QKV fusion for {component_name}...") @@ -1151,7 +1151,7 @@ def _export_diffusers_checkpoint( is_qwen_component = "qwen" in type(component).__name__.lower() _fuse_qkv_linears_diffusion(component, strict=is_qwen_component) - # Step 4: Process quantized modules (convert weights, register scales) + # Process quantized modules (convert weights, register scales) _process_quantized_modules(component, component_dtype, is_modelopt_qlora=False) # Promote quantizer-owned tensors (AWQ pre_quant_scale and SVDQuant @@ -1160,7 +1160,7 @@ def _export_diffusers_checkpoint( # main safetensors under clean, AWQ-aligned keys. _promote_quantizer_tensors_to_module(component) - # Step 5: Build quantization config + # Build quantization config quant_config = get_quant_config(component, is_modelopt_qlora=False) if quant_config: quantization_details = quant_config.get("quantization", {}) @@ -1171,7 +1171,7 @@ def _export_diffusers_checkpoint( quantization_details["lora_rank"] = svdquant_rank hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None - # Step 6: Save the component + # Save the component # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter # - for non-diffusers modules (e.g., LTX-2 transformer), fall back to torch.save if hasattr(component, "save_pretrained"): @@ -1181,7 +1181,7 @@ def _export_diffusers_checkpoint( with hide_quantizers_from_state_dict(component): _save_component_state_dict_safetensors(component, component_export_dir) - # Step 7: Post-process — merge, metadata, padding, swizzle + # Post-process — merge, metadata, padding, swizzle _postprocess_safetensors( component_export_dir, pipe, @@ -1189,7 +1189,7 @@ def _export_diffusers_checkpoint( **kwargs, ) - # Step 8: Update config.json with quantization info + # Update config.json with quantization info if hf_quant_config is not None: config_path = component_export_dir / "config.json" if config_path.exists(): @@ -1204,7 +1204,7 @@ def _export_diffusers_checkpoint( else: _save_component_state_dict_safetensors(component, component_export_dir) - # Step 9: Update config.json with sparse attention info (both quantized and non-quantized) + # Update config.json with sparse attention info (both quantized and non-quantized) if export_sparse_attention_config is not None: sparse_attn_config = export_sparse_attention_config(component) if sparse_attn_config is not None: @@ -1219,7 +1219,7 @@ def _export_diffusers_checkpoint( print(f" Saved to: {component_export_dir}") - # Step 4: Export non-nn.Module components (tokenizers, schedulers, feature extractors, etc.) + # Export non-nn.Module components (tokenizers, schedulers, feature extractors, etc.) if is_diffusers_pipe: for component_name, component in all_components.items(): # Skip nn.Module components (already handled above) @@ -1247,7 +1247,7 @@ def _export_diffusers_checkpoint( print(f" Saved to: {component_export_dir}") - # Step 5: For pipelines, also save model_index.json + # For pipelines, also save model_index.json if is_diffusers_pipe: model_index_path = export_dir / "model_index.json" is_partial_export = components is not None diff --git a/tests/_test_utils/torch/diffusers_models.py b/tests/_test_utils/torch/diffusers_models.py index 3a5e4adc982..71e0e19cf78 100644 --- a/tests/_test_utils/torch/diffusers_models.py +++ b/tests/_test_utils/torch/diffusers_models.py @@ -311,6 +311,35 @@ def get_tiny_qwen_image_vae(**config_kwargs): return AutoencoderKLQwenImage(**kwargs) +def _build_local_qwen2_tokenizer(out_dir: Path): + """Build a tiny, fully offline byte-level Qwen2 tokenizer (no Hub access). + + Uses the GPT-2/Qwen byte->unicode mapping for the 256 single-byte tokens plus + Qwen's core special tokens, with an empty merge table (pure byte-level + fallback). This is enough to tokenize calibration prompts so the pipeline runs + end-to-end; it is not meant for high-quality text. + """ + import json + + import transformers + from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode + + out_dir.mkdir(parents=True, exist_ok=True) + vocab = {token: idx for idx, token in enumerate(bytes_to_unicode().values())} + for special in ("<|endoftext|>", "<|im_start|>", "<|im_end|>"): + vocab.setdefault(special, len(vocab)) + (out_dir / "vocab.json").write_text(json.dumps(vocab)) + (out_dir / "merges.txt").write_text("#version: 0.2\n") + + return transformers.Qwen2Tokenizer( + vocab_file=str(out_dir / "vocab.json"), + merges_file=str(out_dir / "merges.txt"), + unk_token="<|endoftext|>", + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", + ) + + def create_tiny_qwen_image_pipeline_dir(tmp_path: Path) -> Path: """Create and save a tiny, (mostly) offline Qwen-Image pipeline to a directory. @@ -360,14 +389,8 @@ def create_tiny_qwen_image_pipeline_dir(tmp_path: Path) -> Path: ) text_encoder = transformers.Qwen2_5_VLForConditionalGeneration(qwen_vl_config).eval() - # The Qwen tokenizer cannot be built fully offline; load the tiny one diffusers' - # own fast test uses (this id exists, unlike the Qwen2.5-VL one previously tried). - try: - tokenizer = transformers.Qwen2Tokenizer.from_pretrained( - "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" - ) - except Exception as exc: # pragma: no cover - depends on hub availability - pytest.skip(f"tiny Qwen tokenizer unavailable ({exc})") + # Deterministic local byte-level Qwen2 tokenizer (built offline; no Hub, no skip). + tokenizer = _build_local_qwen2_tokenizer(tmp_path / "qwen_tokenizer") torch.manual_seed(0) # num_layers=6 so the first-2/last-2 block-range recipe (which needs >=6 blocks) diff --git a/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py b/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py index 6ea8d22b51e..58fe587cf0e 100644 --- a/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py +++ b/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py @@ -179,6 +179,29 @@ def quantize_and_export_hf(self, tiny_qwen_image_path: str, tmp_path: Path) -> P ], ids=["qwen_fp8_max", "qwen_nvfp4_max", "qwen_nvfp4_svdquant"], ) +def _module_prefixes(keys: set[str], suffix: str) -> set[str]: + """Module paths (key minus suffix) for every key ending in ``suffix``.""" + return {k[: -len(suffix)] for k in keys if k.endswith(suffix)} + + +def _block_indices(prefixes: set[str]) -> set[int]: + """transformer_blocks indices referenced by a set of module prefixes.""" + import re + + indices = set() + for prefix in prefixes: + match = re.search(r"transformer_blocks\.(\d+)\.", prefix) + if match: + indices.add(int(match.group(1))) + return indices + + +# Tiny Qwen fixture has 6 transformer blocks; the recipe excludes the first 2 and +# last 2, so only blocks 2 and 3 are quantized. +_QWEN_QUANTIZED_BLOCKS = {2, 3} +_QWEN_LORA_RANK = 8 + + def test_qwen_image_hf_ckpt_export( qwen_model: QwenHfExportModel, tiny_qwen_image_path: str, tmp_path: Path ) -> None: @@ -196,31 +219,57 @@ def test_qwen_image_hf_ckpt_export( assert quant_config.get("quant_method") == "modelopt" keys: set[str] = set() + lora_tensors: dict[str, "object"] = {} safetensors_files = sorted(transformer_dir.rglob("*.safetensors")) assert safetensors_files, f"no safetensors in {transformer_dir}" for path in safetensors_files: with safe_open(str(path), framework="pt") as handle: - keys.update(handle.keys()) + for key in handle.keys(): + keys.add(key) + if key.endswith(".svdquant_lora_a") or key.endswith(".svdquant_lora_b"): + lora_tensors[key] = handle.get_tensor(key) # No live quantizer state should leak into the exported checkpoint. assert not any("weight_quantizer" in k for k in keys), "quantizer keys leaked into export" assert not any("input_quantizer._amax" in k for k in keys) + # Recipe: only the middle blocks are quantized (first-2/last-2 excluded). + weight_scale_prefixes = _module_prefixes(keys, ".weight_scale") + assert _block_indices(weight_scale_prefixes) == _QWEN_QUANTIZED_BLOCKS, ( + f"expected only blocks {_QWEN_QUANTIZED_BLOCKS} quantized" + ) + if qwen_model.is_svdquant: - # Promoted SVDQuant tensors under clean module-level keys. - assert any(k.endswith(".svdquant_lora_a") for k in keys) - assert any(k.endswith(".svdquant_lora_b") for k in keys) - assert any(k.endswith(".pre_quant_scale") for k in keys) - # config schema modeled on nvfp4_awq. + a_prefixes = _module_prefixes(keys, ".svdquant_lora_a") + b_prefixes = _module_prefixes(keys, ".svdquant_lora_b") + pqs_prefixes = _module_prefixes(keys, ".pre_quant_scale") + assert a_prefixes, "no promoted svdquant_lora_a keys" + # Every promoted linear carries lora_a, lora_b, and pre_quant_scale. + assert a_prefixes == b_prefixes == pqs_prefixes + # ...and each is a quantized linear, only in the middle blocks. + assert a_prefixes <= weight_scale_prefixes + assert _block_indices(a_prefixes) == _QWEN_QUANTIZED_BLOCKS + # Rank-consistent shapes; lora_a=[rank, in], lora_b=[out, rank], rank == --lowrank. + for key, tensor in lora_tensors.items(): + if key.endswith(".svdquant_lora_a"): + assert tensor.shape[0] == _QWEN_LORA_RANK + else: + assert tensor.shape[1] == _QWEN_LORA_RANK + # NVFP4 secondary scales are present. + assert any(k.endswith(".weight_scale_2") for k in keys) + # config schema (modeled on nvfp4_awq). + assert quant_config.get("quant_algo") == "NVFP4_SVD" group = next(iter(quant_config.get("config_groups", {}).values()), {}) + assert group.get("lora_rank") == _QWEN_LORA_RANK assert group.get("pre_quant_scale") is True assert group.get("has_zero_point") is False - assert "lora_rank" in group - # Excluded first-2/last-2 blocks + non-transformer_blocks modules are ignored. assert quant_config.get("ignore"), "expected excluded modules in 'ignore'" else: - # Plain FP8/NVFP4 expose weight scales. - assert any(k.endswith(".weight_scale") for k in keys), "no weight_scale in export" + # Plain FP8/NVFP4: weight scales present, no SVDQuant tensors. + assert weight_scale_prefixes, "no weight_scale in export" + assert not any(k.endswith(".svdquant_lora_a") for k in keys) + if qwen_model.format_type == "fp4": + assert any(k.endswith(".weight_scale_2") for k in keys) class Wan22HfExportModel(NamedTuple): From a6a3d590bcf0b7e4a8581b1443a449f8dc06f5e6 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 11 Jun 2026 18:39:10 -0700 Subject: [PATCH 07/15] Qwen-Image SVDQuant: fix misplaced parametrize decorator + tighten export test Round 5 (addresses round-4 Codex review, which found a regression I introduced): - The round-4 edit inserted the _module_prefixes/_block_indices helpers between @pytest.mark.parametrize("qwen_model", ...) and test_qwen_image_hf_ckpt_export, so the decorator was attached to the helper and the test would request an undefined qwen_model fixture. Moved the helpers/constants above the decorator so it directly decorates the test (verified via ast: the test now carries the qwen_model parametrization and the helper is undecorated). - Tightened SVDQuant assertions: require a_prefixes == b_prefixes == pqs_prefixes == weight_scale_prefixes (every quantized linear is promoted, no gaps), and assert every quantized prefix is under transformer_blocks (nothing outside is quantized), in addition to the {2,3}-only block check. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- .../test_export_diffusers_hf_ckpt.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py b/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py index 58fe587cf0e..9e2e30a160c 100644 --- a/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py +++ b/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py @@ -170,15 +170,6 @@ def quantize_and_export_hf(self, tiny_qwen_image_path: str, tmp_path: Path) -> P return hf_ckpt_dir -@pytest.mark.parametrize( - "qwen_model", - [ - pytest.param(QwenHfExportModel("fp8", "max", False), marks=minimum_sm(89)), - pytest.param(QwenHfExportModel("fp4", "max", False), marks=minimum_sm(89)), - pytest.param(QwenHfExportModel("fp4", "svdquant", True), marks=minimum_sm(89)), - ], - ids=["qwen_fp8_max", "qwen_nvfp4_max", "qwen_nvfp4_svdquant"], -) def _module_prefixes(keys: set[str], suffix: str) -> set[str]: """Module paths (key minus suffix) for every key ending in ``suffix``.""" return {k[: -len(suffix)] for k in keys if k.endswith(suffix)} @@ -202,6 +193,15 @@ def _block_indices(prefixes: set[str]) -> set[int]: _QWEN_LORA_RANK = 8 +@pytest.mark.parametrize( + "qwen_model", + [ + pytest.param(QwenHfExportModel("fp8", "max", False), marks=minimum_sm(89)), + pytest.param(QwenHfExportModel("fp4", "max", False), marks=minimum_sm(89)), + pytest.param(QwenHfExportModel("fp4", "svdquant", True), marks=minimum_sm(89)), + ], + ids=["qwen_fp8_max", "qwen_nvfp4_max", "qwen_nvfp4_svdquant"], +) def test_qwen_image_hf_ckpt_export( qwen_model: QwenHfExportModel, tiny_qwen_image_path: str, tmp_path: Path ) -> None: @@ -233,8 +233,13 @@ def test_qwen_image_hf_ckpt_export( assert not any("weight_quantizer" in k for k in keys), "quantizer keys leaked into export" assert not any("input_quantizer._amax" in k for k in keys) - # Recipe: only the middle blocks are quantized (first-2/last-2 excluded). + # Recipe: only the middle transformer blocks are quantized — first-2/last-2 of + # transformer_blocks are excluded, and nothing outside transformer_blocks. weight_scale_prefixes = _module_prefixes(keys, ".weight_scale") + assert weight_scale_prefixes, "no quantized linears found in export" + assert all("transformer_blocks." in p for p in weight_scale_prefixes), ( + f"a non-transformer_blocks module was quantized: {weight_scale_prefixes}" + ) assert _block_indices(weight_scale_prefixes) == _QWEN_QUANTIZED_BLOCKS, ( f"expected only blocks {_QWEN_QUANTIZED_BLOCKS} quantized" ) @@ -244,10 +249,9 @@ def test_qwen_image_hf_ckpt_export( b_prefixes = _module_prefixes(keys, ".svdquant_lora_b") pqs_prefixes = _module_prefixes(keys, ".pre_quant_scale") assert a_prefixes, "no promoted svdquant_lora_a keys" - # Every promoted linear carries lora_a, lora_b, and pre_quant_scale. - assert a_prefixes == b_prefixes == pqs_prefixes - # ...and each is a quantized linear, only in the middle blocks. - assert a_prefixes <= weight_scale_prefixes + # Every promoted linear carries lora_a, lora_b, and pre_quant_scale, and + # every quantized linear is promoted (the sets are identical). + assert a_prefixes == b_prefixes == pqs_prefixes == weight_scale_prefixes assert _block_indices(a_prefixes) == _QWEN_QUANTIZED_BLOCKS # Rank-consistent shapes; lora_a=[rank, in], lora_b=[out, rank], rank == --lowrank. for key, tensor in lora_tensors.items(): From 1cfe0b36c9590e3b28f478e364402d3fb7f3a6c4 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 11 Jun 2026 18:48:09 -0700 Subject: [PATCH 08/15] Qwen-Image SVDQuant: fix stale tiny-fixture tokenizer docstring Round 6 (round-5 review found no code blocker; only the queued docstring nit): the create_tiny_qwen_image_pipeline_dir docstring still said the tokenizer was fetched from the Hub, but Round 4 switched it to a local offline build (_build_local_qwen2_tokenizer). Updated the wording to "fully offline". Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- tests/_test_utils/torch/diffusers_models.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/_test_utils/torch/diffusers_models.py b/tests/_test_utils/torch/diffusers_models.py index 71e0e19cf78..6e08518bdce 100644 --- a/tests/_test_utils/torch/diffusers_models.py +++ b/tests/_test_utils/torch/diffusers_models.py @@ -341,16 +341,16 @@ def _build_local_qwen2_tokenizer(out_dir: Path): def create_tiny_qwen_image_pipeline_dir(tmp_path: Path) -> Path: - """Create and save a tiny, (mostly) offline Qwen-Image pipeline to a directory. - - Mirrors diffusers' ``QwenImagePipelineFastTests.get_dummy_components``: the - Qwen2.5-VL text encoder is built inline from a tiny ``Qwen2_5_VLConfig`` (no Hub - model load); only the tokenizer is fetched from the tiny ``Qwen2VL`` test repo - (building a Qwen tokenizer fully offline is impractical). The transformer uses - ``num_layers=6`` so the first-2/last-2 block-range recipe is valid, and its - ``joint_attention_dim`` matches the text encoder ``hidden_size`` (16) so the - pipeline runs end-to-end during quantization calibration. The saved dir loads - with ``QwenImagePipeline.from_pretrained(path)``. + """Create and save a tiny, fully offline Qwen-Image pipeline to a directory. + + Mirrors diffusers' ``QwenImagePipelineFastTests.get_dummy_components`` but with + no Hub access: the Qwen2.5-VL text encoder is built inline from a tiny + ``Qwen2_5_VLConfig``, and the tokenizer is built locally by + ``_build_local_qwen2_tokenizer`` (byte-level vocab written to a temp dir). The + transformer uses ``num_layers=6`` so the first-2/last-2 block-range recipe is + valid, and its ``joint_attention_dim`` matches the text encoder ``hidden_size`` + (16) so the pipeline runs end-to-end during quantization calibration. The saved + dir loads with ``QwenImagePipeline.from_pretrained(path)``. """ if QwenImageTransformer2DModel is None or AutoencoderKLQwenImage is None: pytest.skip("QwenImage diffusers classes not available in this diffusers version.") From 789f4ef0a002d7a08ed2372cb3dd59037d2d8966 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 11 Jun 2026 19:01:23 -0700 Subject: [PATCH 09/15] Qwen-Image SVDQuant: add immutability + negative-loading tests; fix stale docs Round 7 (addresses round-6 Codex review's two missing-coverage items): - AC-2.2 SVDQuant immutability test (test_qwen_block_range_recipe.py): builds a 6-block backbone, snapshots the excluded first/last block linear weights, runs SVDQuant via build_block_range_quant_cfg, and asserts the excluded blocks' weights are bit-identical (never calibrated) with no LoRA, while the middle blocks {2,3} receive LoRA and have their weights modified. - AC-1 negative-loading tests (new test_qwen_pipeline_loading.py): monkeypatch MODEL_PIPELINE[QWEN_IMAGE]=None and assert the actionable ImportError; a fake pipeline asserts create_pipeline does not pass trust_remote_code. Stale-doc cleanups: the resolved pre_quant_scale TODO wording in unified_export_hf.py; the build_block_range_quant_cfg docstring (first+last+1 -> +2); the conftest "SKETCH" wording (the fixture is now a working offline build). Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- .../diffusers/quantization/models_utils.py | 3 +- modelopt/torch/export/unified_export_hf.py | 6 +- tests/examples/diffusers/conftest.py | 8 +- .../diffusers/test_qwen_block_range_recipe.py | 64 ++++++++++++++++ .../diffusers/test_qwen_pipeline_loading.py | 74 +++++++++++++++++++ 5 files changed, 149 insertions(+), 6 deletions(-) create mode 100644 tests/examples/diffusers/test_qwen_pipeline_loading.py diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index eb2ff1f6039..86fa0750c42 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -330,7 +330,8 @@ def build_block_range_quant_cfg( Raises: ValueError: if the backbone has no ``block_module`` list, or it has fewer - than ``exclude_first_n + exclude_last_n + 1`` blocks. + than ``exclude_first_n + exclude_last_n + 2`` blocks (it requires at + least two quantized middle blocks). """ blocks = getattr(backbone, block_module, None) if blocks is None or not hasattr(blocks, "__len__"): diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 164f2a0e1fe..c47b12a3fd6 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -973,7 +973,8 @@ def _fuse_qkv_linears_diffusion( Note: This is a simplified version for diffusion models that: - Handles QKV fusion (shared input detection) - Filters to only fuse actual QKV projection layers (not AdaLN, FFN, etc.) - - Skips pre_quant_scale handling (TODO for future) + - Skips pre_quant_scale *fusion* (the export path promotes pre_quant_scale to + module-level keys separately; see _promote_quantizer_tensors_to_module) - Skips FFN fusion with layernorm (TODO for future) Args: @@ -1143,7 +1144,8 @@ def _export_diffusers_checkpoint( if is_quantized: # Fuse QKV linears that share the same input (unify amax values) # This is similar to requantize_resmooth_fused_llm_layers but simplified for diffusion - # TODO: Add pre_quant_scale handling and FFN fusion for AWQ-style quantization + # TODO: Add FFN fusion for AWQ-style quantization (pre_quant_scale is + # promoted to module keys at export by _promote_quantizer_tensors_to_module below) print(f" Running QKV fusion for {component_name}...") # Qwen-Image's packed-latent forward signature is non-standard; if the # dummy forward fails for it, fail loudly rather than silently skipping diff --git a/tests/examples/diffusers/conftest.py b/tests/examples/diffusers/conftest.py index e704f6d5879..625f9bc3415 100644 --- a/tests/examples/diffusers/conftest.py +++ b/tests/examples/diffusers/conftest.py @@ -35,9 +35,11 @@ def tiny_wan22_path(tmp_path_factory): def tiny_qwen_image_path(tmp_path_factory): """Create a tiny Qwen-Image pipeline and return its path (built once per session). - SKETCH fixture for the recipe-level DMD2 e2e (``test_fastgen_recipe_e2e.py``). - See ``create_tiny_qwen_image_pipeline_dir`` for caveats — notably the tiny - Qwen2.5-VL text encoder, which needs in-container validation. + Used by the diffusers Qwen export tests and the recipe-level DMD2 e2e + (``test_fastgen_recipe_e2e.py``). The pipeline is built fully offline by + ``create_tiny_qwen_image_pipeline_dir`` (inline tiny Qwen2.5-VL text encoder + + local byte-level tokenizer); it skips only when the diffusers Qwen classes are + unavailable. """ try: from _test_utils.torch.diffusers_models import create_tiny_qwen_image_pipeline_dir diff --git a/tests/examples/diffusers/test_qwen_block_range_recipe.py b/tests/examples/diffusers/test_qwen_block_range_recipe.py index 8214ec1c124..7e0110afacc 100644 --- a/tests/examples/diffusers/test_qwen_block_range_recipe.py +++ b/tests/examples/diffusers/test_qwen_block_range_recipe.py @@ -93,3 +93,67 @@ class _NoBlocks: with pytest.raises(ValueError, match="transformer_blocks"): build_block_range_quant_cfg(_NoBlocks(), exclude_first_n=2, exclude_last_n=2) + + +def test_svdquant_recipe_leaves_excluded_blocks_bit_identical(): + """AC-2.2: the pre-calibration recipe must keep the excluded first/last blocks + bit-identical through SVDQuant (whose calibration subtracts a residual from + every *enabled* linear), while the middle blocks receive LoRA.""" + import torch + import torch.nn as nn + + import modelopt.torch.quantization as mtq + + class _Block(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + return self.proj(x) + + class _Backbone(nn.Module): + def __init__(self, num_blocks: int = 6, dim: int = 32): + super().__init__() + self.transformer_blocks = nn.ModuleList(_Block(dim) for _ in range(num_blocks)) + + def forward(self, x): + for block in self.transformer_blocks: + x = block(x) + return x + + torch.manual_seed(0) + model = _Backbone(num_blocks=6, dim=32) + weights_before = { + i: model.transformer_blocks[i].proj.weight.detach().clone() for i in range(6) + } + + # Base rules quantize every linear weight/input quantizer; the recipe then + # disables all and re-enables only the middle transformer blocks (2, 3). + quant_cfg = { + "quant_cfg": [ + {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_name": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + *build_block_range_quant_cfg(model, exclude_first_n=2, exclude_last_n=2), + ], + "algorithm": {"method": "svdquant", "lowrank": 4}, + } + calib_data = [torch.randn(2, 32) for _ in range(2)] + mtq.quantize(model, quant_cfg, lambda m: [m(batch) for batch in calib_data]) + + excluded = {0, 1, 4, 5} + for idx in range(6): + proj = model.transformer_blocks[idx].proj + lora_a = getattr(getattr(proj, "weight_quantizer", None), "svdquant_lora_a", None) + if idx in excluded: + # Never calibrated -> weight bit-identical, no LoRA residual. + assert torch.equal(proj.weight, weights_before[idx]), ( + f"excluded block {idx} weight was modified" + ) + assert lora_a is None, f"excluded block {idx} unexpectedly has SVDQuant LoRA" + else: + # Calibrated -> LoRA present and the residual was subtracted from the weight. + assert lora_a is not None, f"middle block {idx} is missing SVDQuant LoRA" + assert not torch.equal(proj.weight, weights_before[idx]), ( + f"middle block {idx} weight was not modified by SVDQuant" + ) diff --git a/tests/examples/diffusers/test_qwen_pipeline_loading.py b/tests/examples/diffusers/test_qwen_pipeline_loading.py new file mode 100644 index 00000000000..612251633b5 --- /dev/null +++ b/tests/examples/diffusers/test_qwen_pipeline_loading.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Negative-path loading tests for Qwen-Image in the diffusers quantization example. + +These cover the AC-1 negative criteria without a GPU or a real model: +- selecting Qwen-Image when diffusers lacks the Qwen classes raises a clear, + actionable error (not an opaque failure); +- Qwen loading does not pass ``trust_remote_code``. +""" + +import logging +import sys +from pathlib import Path + +import pytest + +pytest.importorskip("diffusers") +pytest.importorskip("torch") + +_EXAMPLE_DIR = Path(__file__).parents[3] / "examples" / "diffusers" / "quantization" +if str(_EXAMPLE_DIR) not in sys.path: + sys.path.insert(0, str(_EXAMPLE_DIR)) + +import models_utils # noqa: E402 +import pipeline_manager # noqa: E402 +from models_utils import ModelType # noqa: E402 +from quantize_config import ModelConfig # noqa: E402 + + +def _qwen_pipeline_manager() -> "pipeline_manager.PipelineManager": + config = ModelConfig(model_type=ModelType.QWEN_IMAGE, backbone=["transformer"]) + return pipeline_manager.PipelineManager(config, logging.getLogger("qwen-loading-test")) + + +def test_missing_qwen_pipeline_raises_actionable_error(monkeypatch): + # Simulate a diffusers version without QwenImagePipeline. + monkeypatch.setitem(models_utils.MODEL_PIPELINE, ModelType.QWEN_IMAGE, None) + manager = _qwen_pipeline_manager() + with pytest.raises(ImportError, match="Qwen-Image requires"): + manager.create_pipeline() + + +def test_qwen_loading_does_not_pass_trust_remote_code(monkeypatch): + captured_kwargs: dict = {} + + class _FakeQwenPipeline: + @classmethod + def from_pretrained(cls, model_id, **kwargs): + captured_kwargs.update(kwargs) + return cls() + + def set_progress_bar_config(self, **kwargs): + pass + + monkeypatch.setitem(models_utils.MODEL_PIPELINE, ModelType.QWEN_IMAGE, _FakeQwenPipeline) + manager = _qwen_pipeline_manager() + manager.create_pipeline() + + # Qwen-Image must load without trust_remote_code. + assert captured_kwargs.get("trust_remote_code") is not True + assert "trust_remote_code" not in captured_kwargs From 521cda09585a540ec5c3659c0c882acb91a13a77 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 11 Jun 2026 19:13:12 -0700 Subject: [PATCH 10/15] Qwen-Image SVDQuant: drop invalid txt_seq_lens dummy kwarg + signature-gate Round 8 (addresses round-7 Codex review, which verified against the diffusers source that QwenImageTransformer2DModel.forward has no txt_seq_lens parameter): - _qwen_inputs no longer passes txt_seq_lens (the real forward signature is hidden_states, encoder_hidden_states, encoder_hidden_states_mask, timestep, img_shapes, guidance, return_dict). Passing txt_seq_lens would have raised an unexpected-keyword error and, because Qwen export uses strict QKV fusion, hard-failed the export. - Signature-gate the dummy inputs: filter to the kwargs the installed model's forward actually accepts (via inspect.signature), so diffusers-version drift cannot hard-fail strict fusion either. - Update test_diffusers_qwen_export.py: no longer require txt_seq_lens. - Remove AC- plan terminology from two test docstrings (code-style note). Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- modelopt/torch/export/diffusers_utils.py | 14 ++++++++++++-- .../diffusers/test_qwen_block_range_recipe.py | 6 +++--- .../diffusers/test_qwen_pipeline_loading.py | 2 +- .../torch/export/test_diffusers_qwen_export.py | 1 - 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py index 823f26c8e5d..fedfe98e723 100644 --- a/modelopt/torch/export/diffusers_utils.py +++ b/modelopt/torch/export/diffusers_utils.py @@ -330,7 +330,7 @@ def _qwen_inputs() -> dict[str, Any]: # QwenImageTransformer2DModel does NOT take the standard # (hidden_states[B,C,H,W], timestep, encoder_hidden_states) triple. It expects # *packed* latents [B, (H//2)*(W//2), in_channels] plus encoder_hidden_states, - # encoder_hidden_states_mask, img_shapes, txt_seq_lens, and optional guidance. + # encoder_hidden_states_mask, img_shapes, and optional guidance. # Timesteps are continuous in [0, 1] (not the diffusers [0, 1000] scale). in_channels = getattr(cfg, "in_channels", 64) joint_attention_dim = getattr(cfg, "joint_attention_dim", 3584) @@ -353,11 +353,21 @@ def _qwen_inputs() -> dict[str, Any]: ), "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size), "img_shapes": [[(1, packed_h, packed_w)]] * batch_size, - "txt_seq_lens": [text_seq_len] * batch_size, "return_dict": False, } if guidance_embeds: dummy_inputs["guidance"] = torch.tensor([4.0], device=device, dtype=torch.float32) + + # Only pass kwargs the installed QwenImageTransformer2DModel.forward accepts + # (signatures vary across diffusers versions); prevents the strict QKV-fusion + # dummy forward from failing on an unexpected keyword argument. + import inspect + + try: + accepted = set(inspect.signature(model.forward).parameters) + dummy_inputs = {k: v for k, v in dummy_inputs.items() if k in accepted} + except (TypeError, ValueError): + pass return dummy_inputs def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None: diff --git a/tests/examples/diffusers/test_qwen_block_range_recipe.py b/tests/examples/diffusers/test_qwen_block_range_recipe.py index 7e0110afacc..d1fdf3b7c65 100644 --- a/tests/examples/diffusers/test_qwen_block_range_recipe.py +++ b/tests/examples/diffusers/test_qwen_block_range_recipe.py @@ -96,9 +96,9 @@ class _NoBlocks: def test_svdquant_recipe_leaves_excluded_blocks_bit_identical(): - """AC-2.2: the pre-calibration recipe must keep the excluded first/last blocks - bit-identical through SVDQuant (whose calibration subtracts a residual from - every *enabled* linear), while the middle blocks receive LoRA.""" + """The block-range recipe must keep the excluded first/last blocks bit-identical + through SVDQuant (whose calibration subtracts a residual from every *enabled* + linear), while the middle blocks receive LoRA.""" import torch import torch.nn as nn diff --git a/tests/examples/diffusers/test_qwen_pipeline_loading.py b/tests/examples/diffusers/test_qwen_pipeline_loading.py index 612251633b5..cbc49aa5ea3 100644 --- a/tests/examples/diffusers/test_qwen_pipeline_loading.py +++ b/tests/examples/diffusers/test_qwen_pipeline_loading.py @@ -15,7 +15,7 @@ """Negative-path loading tests for Qwen-Image in the diffusers quantization example. -These cover the AC-1 negative criteria without a GPU or a real model: +These cover the negative loading paths without a GPU or a real model: - selecting Qwen-Image when diffusers lacks the Qwen classes raises a clear, actionable error (not an opaque failure); - Qwen loading does not pass ``trust_remote_code``. diff --git a/tests/unit/torch/export/test_diffusers_qwen_export.py b/tests/unit/torch/export/test_diffusers_qwen_export.py index 6577a8800f0..ad4b62b048f 100644 --- a/tests/unit/torch/export/test_diffusers_qwen_export.py +++ b/tests/unit/torch/export/test_diffusers_qwen_export.py @@ -79,7 +79,6 @@ def test_qwen_dummy_inputs_drive_real_transformer_forward(): "encoder_hidden_states", "encoder_hidden_states_mask", "img_shapes", - "txt_seq_lens", ): assert key in inputs, f"missing Qwen dummy input '{key}'" assert inputs["hidden_states"].shape[-1] == transformer.config.in_channels From d776ffb5ec079aa89298283350c0ea8d1e324891 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 11 Jun 2026 19:27:54 -0700 Subject: [PATCH 11/15] Qwen-Image SVDQuant: remove promoted export buffers from live module after export Round 9 (clears the last queued code item from Codex; no code blockers remain): _promote_quantizer_tensors_to_module left the temporary .svdquant_lora_a/b + .pre_quant_scale buffers on the live module after export. Add _remove_promoted_quantizer_tensors and call it after each quantized diffusers component is saved, so the live module is unchanged post-export (repeated export / module reuse stay correct). The quantizer-owned tensors are untouched. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- modelopt/torch/export/unified_export_hf.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index c47b12a3fd6..abbea9fc37e 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -1073,6 +1073,21 @@ def _promote_quantizer_tensors_to_module(component: nn.Module) -> None: sub_module.register_buffer("svdquant_lora_b", lora_b.detach().clone()) +def _remove_promoted_quantizer_tensors(component: nn.Module) -> None: + """Undo :func:`_promote_quantizer_tensors_to_module`. + + Removes the temporary module-level export buffers (``svdquant_lora_a/b`` and + ``pre_quant_scale``) so the live module is unchanged after export, keeping + repeated export / post-export module reuse correct. The quantizer-owned tensors + (``weight_quantizer.svdquant_lora_a/b``, ``input_quantizer._pre_quant_scale``) + are left untouched. + """ + for _, sub_module in component.named_modules(): + for buffer_name in ("svdquant_lora_a", "svdquant_lora_b", "pre_quant_scale"): + if buffer_name in getattr(sub_module, "_buffers", {}): + del sub_module._buffers[buffer_name] + + def _export_diffusers_checkpoint( pipe: Any, dtype: torch.dtype | None, @@ -1200,6 +1215,10 @@ def _export_diffusers_checkpoint( config_data["quantization_config"] = hf_quant_config with open(config_path, "w") as file: json.dump(config_data, file, indent=4) + + # Drop the temporary promoted export buffers so the live module is + # unchanged after export (supports repeated export / module reuse). + _remove_promoted_quantizer_tensors(component) # Non-quantized component: just save as-is elif hasattr(component, "save_pretrained"): component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) From c2250cb5b5b0fa89f56d353cbf229ee8030177fc Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 11 Jun 2026 23:01:12 -0700 Subject: [PATCH 12/15] Qwen-Image diffusers PTQ: fix 5 execution-surfaced bugs (fp8/nvfp4/svdquant) Validated end-to-end on GB200 against the real Qwen/Qwen-Image: all three formats export correct HF checkpoints (only transformer_blocks 2..57; nothing outside), no quantizer-state leak, and the focused tests pass. - models_utils: build_block_range_quant_cfg now uses the top-level enable QuantizerCfgEntry field (a None cfg retains the base preset's params) instead of nesting cfg.enable, which the QuantizerAttributeConfig validator rejects/mis-applies (the old form left every block quantized). - quantize.py: import onnx_utils.export lazily (only needed for --onnx-dir; avoids a hard onnx_graphsurgeon dependency), and pass max_shard_size so the ~20B transformer saves as a single safetensors -- the unified export's layerwise-metadata post-processing does not support sharded files. - diffusers_utils: hide_quantizers_from_state_dict strips quantizer submodules from all modules, not only is_quantlinear, so enabled input quantizers on norm layers no longer leak input_quantizer._amax into the checkpoint. - tests: the tiny QwenImageTransformer2DModel fixture signature-gates its kwargs (diffusers 0.38 removed pooled_projection_dim from the constructor); the recipe test asserts the corrected top-level enable schema. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- .../diffusers/quantization/models_utils.py | 23 ++++++++++--------- examples/diffusers/quantization/quantize.py | 10 +++++++- modelopt/torch/export/diffusers_utils.py | 23 +++++++++++-------- tests/_test_utils/torch/diffusers_models.py | 8 +++++++ .../diffusers/test_qwen_block_range_recipe.py | 13 ++++++----- 5 files changed, 50 insertions(+), 27 deletions(-) diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index 86fa0750c42..1126a421390 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -325,7 +325,8 @@ def build_block_range_quant_cfg( Rules are applied in order with later rules overriding earlier ones: 1. disable every linear weight/input quantizer, - 2. re-enable only those under ``block_module``, + 2. re-enable only those under ``block_module`` (``enable`` is a top-level + QuantizerCfgEntry toggle; a ``None`` cfg keeps the base preset's quant params), 3. disable the first/last ``n`` blocks. Raises: @@ -354,17 +355,17 @@ def build_block_range_quant_cfg( excluded = sorted( set(range(exclude_first_n)) | set(range(num_blocks - exclude_last_n, num_blocks)) ) + # `enable` is a top-level QuantizerCfgEntry field (independent of `cfg`); a `None` + # cfg leaves the base preset's quant params untouched, so disabling then + # re-enabling restores the original (FP8/NVFP4/...) attributes. Putting `enable` + # under `cfg` is rejected by the QuantizerAttributeConfig validator. rules: list[dict[str, Any]] = [ - {"quantizer_name": "*weight_quantizer", "cfg": {"enable": False}}, - {"quantizer_name": "*input_quantizer", "cfg": {"enable": False}}, - {"quantizer_name": f"*{block_module}.*weight_quantizer", "cfg": {"enable": True}}, - {"quantizer_name": f"*{block_module}.*input_quantizer", "cfg": {"enable": True}}, + {"quantizer_name": "*weight_quantizer", "enable": False}, + {"quantizer_name": "*input_quantizer", "enable": False}, + {"quantizer_name": f"*{block_module}.*weight_quantizer", "enable": True}, + {"quantizer_name": f"*{block_module}.*input_quantizer", "enable": True}, ] for idx in excluded: - rules.append( - {"quantizer_name": f"*{block_module}.{idx}.*weight_quantizer", "cfg": {"enable": False}} - ) - rules.append( - {"quantizer_name": f"*{block_module}.{idx}.*input_quantizer", "cfg": {"enable": False}} - ) + rules.append({"quantizer_name": f"*{block_module}.{idx}.*weight_quantizer", "enable": False}) + rules.append({"quantizer_name": f"*{block_module}.{idx}.*input_quantizer", "enable": False}) return rules diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index e81dc51a66b..325d39e5bd7 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -39,7 +39,6 @@ get_model_filter_func, parse_extra_params, ) -from onnx_utils.export import generate_fp8_scales, modelopt_export_sd from pipeline_manager import PipelineManager from quantize_config import ( CalibrationConfig, @@ -319,6 +318,11 @@ def export_onnx( if not self.config.onnx_dir: return + # onnx_graphsurgeon (pulled in by onnx_utils.export) is an optional dependency + # only needed for the ONNX export path; import lazily so the HF-checkpoint + # export runs without it installed. + from onnx_utils.export import generate_fp8_scales, modelopt_export_sd + self.logger.info(f"Starting ONNX export to {self.config.onnx_dir}") if quant_format == QuantFormat.FP8 and self._has_conv_layers(backbone): @@ -407,6 +411,10 @@ def export_hf_ckpt(self, pipe: Any, model_config: ModelConfig | None = None) -> f"Invalid padding_strategy: {padding!r}. Expected 'row' or 'row_col'." ) kwargs["padding_strategy"] = padding + # The diffusion transformer is large (~20B params); the unified export's + # layerwise-metadata post-processing does not support sharded safetensors, so + # save each component as a single file (no *.safetensors.index.json). + kwargs.setdefault("max_shard_size", "200GB") export_hf_checkpoint(pipe, export_dir=self.config.hf_ckpt_dir, **kwargs) self.logger.info("HuggingFace checkpoint export completed successfully") diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py index fedfe98e723..075f25a9101 100644 --- a/modelopt/torch/export/diffusers_utils.py +++ b/modelopt/torch/export/diffusers_utils.py @@ -735,15 +735,20 @@ def hide_quantizers_from_state_dict(model: nn.Module): # Store references to quantizers that we'll temporarily remove quantizer_backup: dict[str, dict[str, nn.Module]] = {} - for name, module in model.named_modules(): - if is_quantlinear(module): - backup = {} - for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: - if hasattr(module, attr): - backup[attr] = getattr(module, attr) - delattr(module, attr) - if backup: - quantizer_backup[name] = backup + # Remove every quantizer submodule from *all* modules, not only recognized + # quant-linears: enabled input quantizers can also live on non-linear modules + # (e.g. norm layers whose activations were calibrated), and their ``_amax`` + # buffers must not leak into the saved checkpoint. Snapshot the module list + # first since we mutate the module tree while iterating. + for name, module in list(model.named_modules()): + backup = {} + for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: + child = getattr(module, attr, None) + if isinstance(child, nn.Module): + backup[attr] = child + delattr(module, attr) + if backup: + quantizer_backup[name] = backup try: yield diff --git a/tests/_test_utils/torch/diffusers_models.py b/tests/_test_utils/torch/diffusers_models.py index 6e08518bdce..caf4966399f 100644 --- a/tests/_test_utils/torch/diffusers_models.py +++ b/tests/_test_utils/torch/diffusers_models.py @@ -289,6 +289,14 @@ def get_tiny_qwen_image_transformer(**config_kwargs): "axes_dims_rope": (8, 4, 4), # sums to attention_head_dim (16) } kwargs.update(**config_kwargs) + # Drop kwargs the installed diffusers QwenImageTransformer2DModel doesn't accept. + # `pooled_projection_dim` is present in the published config.json but was removed + # from the constructor in newer diffusers: from_pretrained tolerates the extra + # config key, but a direct constructor call raises TypeError. + import inspect + + accepted = set(inspect.signature(QwenImageTransformer2DModel.__init__).parameters) + kwargs = {k: v for k, v in kwargs.items() if k in accepted} return QwenImageTransformer2DModel(**kwargs) diff --git a/tests/examples/diffusers/test_qwen_block_range_recipe.py b/tests/examples/diffusers/test_qwen_block_range_recipe.py index d1fdf3b7c65..57cd5619591 100644 --- a/tests/examples/diffusers/test_qwen_block_range_recipe.py +++ b/tests/examples/diffusers/test_qwen_block_range_recipe.py @@ -51,7 +51,7 @@ def _disabled_block_indices(rules): """Indices of transformer blocks explicitly disabled by per-block rules.""" indices = set() for rule in rules: - if rule["cfg"].get("enable") is False: + if rule.get("enable") is False: match = _BLOCK_RULE_RE.fullmatch(rule["quantizer_name"]) if match: indices.add(int(match.group(1))) @@ -62,11 +62,12 @@ def test_recipe_excludes_first_and_last_two_blocks(): rules = build_block_range_quant_cfg(_StubBackbone(6), exclude_first_n=2, exclude_last_n=2) # 1. disable-all rules come first (weight + input). - assert rules[0] == {"quantizer_name": "*weight_quantizer", "cfg": {"enable": False}} - assert rules[1] == {"quantizer_name": "*input_quantizer", "cfg": {"enable": False}} - # 2. then enable only the transformer_blocks. - assert {"quantizer_name": "*transformer_blocks.*weight_quantizer", "cfg": {"enable": True}} in rules - assert {"quantizer_name": "*transformer_blocks.*input_quantizer", "cfg": {"enable": True}} in rules + assert rules[0] == {"quantizer_name": "*weight_quantizer", "enable": False} + assert rules[1] == {"quantizer_name": "*input_quantizer", "enable": False} + # 2. then re-enable only the transformer_blocks (top-level `enable`; a `None` cfg + # keeps the base preset's quant params). + assert {"quantizer_name": "*transformer_blocks.*weight_quantizer", "enable": True} in rules + assert {"quantizer_name": "*transformer_blocks.*input_quantizer", "enable": True} in rules # 3. then disable the first 2 and last 2 of the 6 blocks -> {0, 1, 4, 5}; quantize {2, 3}. assert _disabled_block_indices(rules) == {0, 1, 4, 5} From fb23155a7b1078802b17599d539f1db61c0b46f9 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 12 Jun 2026 16:21:44 -0700 Subject: [PATCH 13/15] Qwen-Image: drop operator-specific quantization harness from the example run_qwen_image_quantization.sh and its README are cluster-specific experiment/operator scripts (hard-coded /lustre paths) that do not belong in the upstream diffusers example. The feature itself (model registration, block-range recipe, FP8/NVFP4/SVDQuant export) is covered by the committed tests. The scripts are kept locally outside the repo. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- .../qwen_image_svdquant/README.md | 108 ------------------ .../run_qwen_image_quantization.sh | 101 ---------------- 2 files changed, 209 deletions(-) delete mode 100644 examples/diffusers/quantization/qwen_image_svdquant/README.md delete mode 100755 examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh diff --git a/examples/diffusers/quantization/qwen_image_svdquant/README.md b/examples/diffusers/quantization/qwen_image_svdquant/README.md deleted file mode 100644 index 02daf02ac8c..00000000000 --- a/examples/diffusers/quantization/qwen_image_svdquant/README.md +++ /dev/null @@ -1,108 +0,0 @@ -# Qwen-Image Quantization (FP8 / NVFP4 / NVFP4-SVDQuant) - -A reproducible harness for quantizing [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) -with the diffusers quantization example and exporting HuggingFace checkpoints. - -## What it does - -- Registers Qwen-Image in the diffusers quantization example (`--model qwen-image`). -- **Recipe**: quantizes only the linears under `transformer_blocks`, keeping the - **first 2 and last 2** of the 60 blocks (and everything outside - `transformer_blocks`: text encoder, VAE, embedders, norms, `proj_out`, …) in - original precision. The exclusion is applied **before calibration** so that for - SVDQuant the excluded blocks' weights stay bit-identical to the original. -- Produces three checkpoints: **FP8**, **NVFP4** (max), and **NVFP4 + SVDQuant**. -- Exports a HuggingFace unified checkpoint per component (safetensors + `config.json`). - -### SVDQuant checkpoint format (AWQ-aligned) - -For the SVDQuant export, the quantizer-owned tensors are promoted to clean, -module-level safetensors keys (mirroring how AWQ exports `pre_quant_scale`): - -| Tensor | Safetensors key | -|--------|-----------------| -| AWQ smoothing scale (`input_quantizer._pre_quant_scale`) | `.pre_quant_scale` | -| Low-rank factor A (`weight_quantizer.svdquant_lora_a`) | `.svdquant_lora_a` | -| Low-rank factor B (`weight_quantizer.svdquant_lora_b`) | `.svdquant_lora_b` | - -They are embedded in the component's main safetensors (no sidecar). The -`config.json`'s `quantization_config` follows the `nvfp4_awq` shape with -`"pre_quant_scale": true` plus the SVDQuant `lora_rank`, so a consumer can -reconstruct `y = NVFP4_GEMM(x) + (x @ lora_a^T) @ lora_b^T`. (No in-repo runtime -applies this residual yet; the checkpoint is a documented on-disk artifact.) - -## Layout (kernel-dev defaults) - -| Env var | Default | Purpose | -|---------|---------|---------| -| `KERNEL_DEV_ROOT` | `/lustre/fsw/coreai_dlalgo_modelopt/users/jingyux/kernel-dev` | Root for container/models/output | -| `MODEL_DIR` | `${KERNEL_DEV_ROOT}/models/Qwen-Image` | Local model cache | -| `OUTPUT_DIR` | `${KERNEL_DEV_ROOT}/qwen_image_ckpts` | Exported checkpoints | -| `HF_TOKEN_FILE` | `${KERNEL_DEV_ROOT}/HF_TOKEN.txt` | Hugging Face token file | -| `FORMATS` | `fp8 nvfp4 svdquant` | Formats to run | -| `CALIB_SIZE` / `BATCH_SIZE` / `N_STEPS` / `LOWRANK` | `64 / 2 / 20 / 32` | Calibration knobs | - -## 1. Build the container (once) - -The diffusers example needs a recent `diffusers` (with `QwenImagePipeline`) and -modelopt installed from source. From a base NGC PyTorch image: - -```bash -CONTAINER_DIR=/lustre/fsw/coreai_dlalgo_modelopt/users/jingyux/kernel-dev/container -mkdir -p "${CONTAINER_DIR}" - -# Import a base image to an enroot squashfs (adjust the tag as needed). -enroot import -o "${CONTAINER_DIR}/modelopt-diffusers.sqsh" \ - docker://nvcr.io#nvidia/pytorch:25.04-py3 - -# Install modelopt (from source) + example deps into the container, then re-save. -srun --container-image="${CONTAINER_DIR}/modelopt-diffusers.sqsh" \ - --container-mounts=/lustre:/lustre --container-save="${CONTAINER_DIR}/modelopt-diffusers.sqsh" \ - bash -lc ' - cd /lustre/fsw/coreai_dlalgo_modelopt/users/jingyux/kernel-dev/source/Model-Optimizer && - pip install -e ".[dev]" && - pip install -U "diffusers>=0.35" "transformers>=4.52" accelerate datasets && - python -c "from diffusers import QwenImagePipeline; print(\"QwenImagePipeline OK\")" - ' -``` - -## 2. Run quantization - -Inside the container (or via `srun`), run the harness: - -```bash -srun --gpus=1 \ - --container-image=/lustre/fsw/coreai_dlalgo_modelopt/users/jingyux/kernel-dev/container/modelopt-diffusers.sqsh \ - --container-mounts=/lustre:/lustre \ - bash examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh -``` - -This downloads `Qwen/Qwen-Image` to `MODEL_DIR` (idempotent), then for each -format writes `${OUTPUT_DIR}/qwen-image-/` (HF checkpoint + `sanity.png`). - -Run a single format, or preview the commands without executing: - -```bash -FORMATS=svdquant LOWRANK=32 bash .../run_qwen_image_quantization.sh -DRY_RUN=1 bash .../run_qwen_image_quantization.sh # print planned commands only -``` - -The equivalent direct `quantize.py` invocation for SVDQuant: - -```bash -python examples/diffusers/quantization/quantize.py \ - --model qwen-image --override-model-path "${MODEL_DIR}" --model-dtype BFloat16 \ - --format fp4 --quant-algo svdquant --lowrank 32 \ - --calib-size 64 --batch-size 2 --n-steps 20 \ - --hf-ckpt-dir "${OUTPUT_DIR}/qwen-image-svdquant" \ - --sanity-image-path "${OUTPUT_DIR}/qwen-image-svdquant/sanity.png" -``` - -## Notes - -- `Qwen/Qwen-Image` loads without `trust_remote_code`. -- The transformer is ~20B params; calibration needs a GPU with enough memory - (use `--cpu-offloading` if VRAM-limited). -- The `--sanity-image-path` image is generated from the **in-memory** quantized - pipeline before the weights are packed for export (a functional check of - quantized inference; it does not reload the exported checkpoint). diff --git a/examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh b/examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh deleted file mode 100755 index 5e0569c0e40..00000000000 --- a/examples/diffusers/quantization/qwen_image_svdquant/run_qwen_image_quantization.sh +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env bash -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Reproducible Qwen-Image quantization (FP8 / NVFP4 / NVFP4-SVDQuant) using the -# diffusers quantization example. This script is meant to run INSIDE a container -# that already has NVIDIA Model Optimizer installed from source and a -# Qwen-capable diffusers (see README.md for building the container and the -# Slurm/srun wrapper). -# -# It downloads Qwen/Qwen-Image (idempotently), then for each requested format -# runs `quantize.py` to calibrate the transformer (only `transformer_blocks`, -# excluding the first 2 / last 2 blocks), generate a quantized-inference sanity -# image, and export a HuggingFace checkpoint. -# -# All paths are parameterized via environment variables; the defaults match the -# kernel-dev experiment layout described in README.md. -set -euo pipefail - -# --- Configuration (override via environment) -------------------------------- -KERNEL_DEV_ROOT="${KERNEL_DEV_ROOT:-/lustre/fsw/coreai_dlalgo_modelopt/users/jingyux/kernel-dev}" -MODEL_ID="${MODEL_ID:-Qwen/Qwen-Image}" -MODEL_DIR="${MODEL_DIR:-${KERNEL_DEV_ROOT}/models/Qwen-Image}" -OUTPUT_DIR="${OUTPUT_DIR:-${KERNEL_DEV_ROOT}/qwen_image_ckpts}" -HF_TOKEN_FILE="${HF_TOKEN_FILE:-${KERNEL_DEV_ROOT}/HF_TOKEN.txt}" -# Path to the diffusers quantization example (this script lives one level below it). -QUANT_DIR="${QUANT_DIR:-$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)}" - -# Formats to run: any of {fp8, nvfp4, svdquant}. -FORMATS="${FORMATS:-fp8 nvfp4 svdquant}" - -# Calibration knobs (small defaults for a quick run; raise CALIB_SIZE for quality). -CALIB_SIZE="${CALIB_SIZE:-64}" -BATCH_SIZE="${BATCH_SIZE:-2}" -N_STEPS="${N_STEPS:-20}" -LOWRANK="${LOWRANK:-32}" -MODEL_DTYPE="${MODEL_DTYPE:-BFloat16}" - -# Set DRY_RUN=1 to print the planned commands without executing them. -DRY_RUN="${DRY_RUN:-0}" - -log() { echo "[qwen-image-quant] $*"; } -run() { - log "+ $*" - if [[ "${DRY_RUN}" != "1" ]]; then - "$@" - fi -} - -# --- Hugging Face token ------------------------------------------------------ -if [[ ! -r "${HF_TOKEN_FILE}" ]]; then - echo "ERROR: HF token file not found or not readable: ${HF_TOKEN_FILE}" >&2 - echo " Set HF_TOKEN_FILE to a readable file containing your Hugging Face token." >&2 - exit 1 -fi -HF_TOKEN="$(tr -d '[:space:]' < "${HF_TOKEN_FILE}")" -if [[ -z "${HF_TOKEN}" ]]; then - echo "ERROR: HF token file is empty: ${HF_TOKEN_FILE}" >&2 - exit 1 -fi -export HF_TOKEN -export HUGGING_FACE_HUB_TOKEN="${HF_TOKEN}" - -# --- Download the model (idempotent) ---------------------------------------- -log "Downloading ${MODEL_ID} -> ${MODEL_DIR} (skipped if already present)" -run mkdir -p "${MODEL_DIR}" -run huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_DIR}" --exclude "*.onnx" - -# --- Quantize + export for each format -------------------------------------- -mkdir -p "${OUTPUT_DIR}" -for fmt in ${FORMATS}; do - case "${fmt}" in - fp8) quant_args=(--format fp8 --quant-algo max) ;; - nvfp4) quant_args=(--format fp4 --quant-algo max) ;; - svdquant) quant_args=(--format fp4 --quant-algo svdquant --lowrank "${LOWRANK}") ;; - *) echo "ERROR: unknown format '${fmt}' (expected fp8|nvfp4|svdquant)" >&2; exit 1 ;; - esac - - out="${OUTPUT_DIR}/qwen-image-${fmt}" - log "=== Quantizing Qwen-Image (${fmt}) -> ${out} ===" - run python "${QUANT_DIR}/quantize.py" \ - --model qwen-image \ - --override-model-path "${MODEL_DIR}" \ - --model-dtype "${MODEL_DTYPE}" \ - "${quant_args[@]}" \ - --calib-size "${CALIB_SIZE}" \ - --batch-size "${BATCH_SIZE}" \ - --n-steps "${N_STEPS}" \ - --hf-ckpt-dir "${out}" \ - --sanity-image-path "${out}/sanity.png" - - # Verify the expected artifacts were produced (a missing artifact is a failure). - if [[ "${DRY_RUN}" != "1" ]]; then - [[ -f "${out}/sanity.png" ]] || { echo "ERROR: missing sanity image ${out}/sanity.png" >&2; exit 1; } - find "${out}" -name '*.safetensors' | grep -q . || { echo "ERROR: no safetensors under ${out}" >&2; exit 1; } - find "${out}" -name 'config.json' | grep -q . || { echo "ERROR: no config.json under ${out}" >&2; exit 1; } - fi - log "Done: ${fmt}. Checkpoint at ${out}, sanity image at ${out}/sanity.png" -done - -log "All requested formats complete. Checkpoints under ${OUTPUT_DIR}" From 80e22443a308a0d10956615d9c80d1661331dab0 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 12 Jun 2026 16:32:29 -0700 Subject: [PATCH 14/15] Qwen-Image: consolidate tests into the shared diffusers export suite Remove the standalone Qwen test files. The fp8/nvfp4/svdquant cases in test_export_diffusers_hf_ckpt.py already cover the block-range recipe (only transformer_blocks 2..57 quantized), the promoted SVDQuant keys + pre_quant_scale, the NVFP4_SVD quantization_config, and the no-leak check -- matching how SDXL/Flux/Wan are tested in the same file. Core SVDQuant forward/fold is unchanged and remains covered by existing upstream tests. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- .../diffusers/test_qwen_block_range_recipe.py | 160 ------------------ .../diffusers/test_qwen_pipeline_loading.py | 74 -------- .../export/test_convert_hf_config_svdquant.py | 87 ---------- .../export/test_diffusers_qwen_export.py | 133 --------------- .../test_svdquant_forward_fold.py | 116 ------------- 5 files changed, 570 deletions(-) delete mode 100644 tests/examples/diffusers/test_qwen_block_range_recipe.py delete mode 100644 tests/examples/diffusers/test_qwen_pipeline_loading.py delete mode 100644 tests/unit/torch/export/test_convert_hf_config_svdquant.py delete mode 100644 tests/unit/torch/export/test_diffusers_qwen_export.py delete mode 100644 tests/unit/torch/quantization/test_svdquant_forward_fold.py diff --git a/tests/examples/diffusers/test_qwen_block_range_recipe.py b/tests/examples/diffusers/test_qwen_block_range_recipe.py deleted file mode 100644 index 57cd5619591..00000000000 --- a/tests/examples/diffusers/test_qwen_block_range_recipe.py +++ /dev/null @@ -1,160 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for the transformer-block-range quantization recipe (e.g. Qwen-Image). - -The recipe must quantize only the linears under ``transformer_blocks`` while -excluding the first/last N blocks, and it must be expressible as ``quant_cfg`` -rules applied before calibration (so SVDQuant never mutates the excluded blocks). -""" - -import re -import sys -from pathlib import Path - -import pytest - -# Importing the example module pulls in diffusers/torch/datasets/modelopt. -pytest.importorskip("diffusers") -pytest.importorskip("torch") - -# Make the diffusers quantization example importable. -_EXAMPLE_DIR = Path(__file__).parents[3] / "examples" / "diffusers" / "quantization" -if str(_EXAMPLE_DIR) not in sys.path: - sys.path.insert(0, str(_EXAMPLE_DIR)) - -from models_utils import build_block_range_quant_cfg # noqa: E402 - -_BLOCK_RULE_RE = re.compile(r"\*transformer_blocks\.(\d+)\.\*(?:weight|input)_quantizer") - - -class _StubBackbone: - """Minimal stand-in exposing a ``transformer_blocks`` sequence of length n.""" - - def __init__(self, num_blocks: int): - self.transformer_blocks = list(range(num_blocks)) - - -def _disabled_block_indices(rules): - """Indices of transformer blocks explicitly disabled by per-block rules.""" - indices = set() - for rule in rules: - if rule.get("enable") is False: - match = _BLOCK_RULE_RE.fullmatch(rule["quantizer_name"]) - if match: - indices.add(int(match.group(1))) - return indices - - -def test_recipe_excludes_first_and_last_two_blocks(): - rules = build_block_range_quant_cfg(_StubBackbone(6), exclude_first_n=2, exclude_last_n=2) - - # 1. disable-all rules come first (weight + input). - assert rules[0] == {"quantizer_name": "*weight_quantizer", "enable": False} - assert rules[1] == {"quantizer_name": "*input_quantizer", "enable": False} - # 2. then re-enable only the transformer_blocks (top-level `enable`; a `None` cfg - # keeps the base preset's quant params). - assert {"quantizer_name": "*transformer_blocks.*weight_quantizer", "enable": True} in rules - assert {"quantizer_name": "*transformer_blocks.*input_quantizer", "enable": True} in rules - # 3. then disable the first 2 and last 2 of the 6 blocks -> {0, 1, 4, 5}; quantize {2, 3}. - assert _disabled_block_indices(rules) == {0, 1, 4, 5} - - -def test_recipe_block_count_scales_with_model(): - # For a 60-block model (Qwen-Image), exclude {0, 1, 58, 59}; quantize 2..57. - rules = build_block_range_quant_cfg(_StubBackbone(60), exclude_first_n=2, exclude_last_n=2) - assert _disabled_block_indices(rules) == {0, 1, 58, 59} - - -@pytest.mark.parametrize("num_blocks", [5, 4, 3]) -def test_recipe_rejects_too_few_blocks(num_blocks): - # A 2 + 2 exclusion needs at least 6 blocks (>= 2 quantized middle blocks). - # A 5-block model leaves only 1 middle block and must be rejected too. - with pytest.raises(ValueError, match="at least"): - build_block_range_quant_cfg( - _StubBackbone(num_blocks), exclude_first_n=2, exclude_last_n=2 - ) - - -def test_recipe_missing_block_module_raises(): - class _NoBlocks: - pass - - with pytest.raises(ValueError, match="transformer_blocks"): - build_block_range_quant_cfg(_NoBlocks(), exclude_first_n=2, exclude_last_n=2) - - -def test_svdquant_recipe_leaves_excluded_blocks_bit_identical(): - """The block-range recipe must keep the excluded first/last blocks bit-identical - through SVDQuant (whose calibration subtracts a residual from every *enabled* - linear), while the middle blocks receive LoRA.""" - import torch - import torch.nn as nn - - import modelopt.torch.quantization as mtq - - class _Block(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.proj = nn.Linear(dim, dim) - - def forward(self, x): - return self.proj(x) - - class _Backbone(nn.Module): - def __init__(self, num_blocks: int = 6, dim: int = 32): - super().__init__() - self.transformer_blocks = nn.ModuleList(_Block(dim) for _ in range(num_blocks)) - - def forward(self, x): - for block in self.transformer_blocks: - x = block(x) - return x - - torch.manual_seed(0) - model = _Backbone(num_blocks=6, dim=32) - weights_before = { - i: model.transformer_blocks[i].proj.weight.detach().clone() for i in range(6) - } - - # Base rules quantize every linear weight/input quantizer; the recipe then - # disables all and re-enables only the middle transformer blocks (2, 3). - quant_cfg = { - "quant_cfg": [ - {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, - {"quantizer_name": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, - *build_block_range_quant_cfg(model, exclude_first_n=2, exclude_last_n=2), - ], - "algorithm": {"method": "svdquant", "lowrank": 4}, - } - calib_data = [torch.randn(2, 32) for _ in range(2)] - mtq.quantize(model, quant_cfg, lambda m: [m(batch) for batch in calib_data]) - - excluded = {0, 1, 4, 5} - for idx in range(6): - proj = model.transformer_blocks[idx].proj - lora_a = getattr(getattr(proj, "weight_quantizer", None), "svdquant_lora_a", None) - if idx in excluded: - # Never calibrated -> weight bit-identical, no LoRA residual. - assert torch.equal(proj.weight, weights_before[idx]), ( - f"excluded block {idx} weight was modified" - ) - assert lora_a is None, f"excluded block {idx} unexpectedly has SVDQuant LoRA" - else: - # Calibrated -> LoRA present and the residual was subtracted from the weight. - assert lora_a is not None, f"middle block {idx} is missing SVDQuant LoRA" - assert not torch.equal(proj.weight, weights_before[idx]), ( - f"middle block {idx} weight was not modified by SVDQuant" - ) diff --git a/tests/examples/diffusers/test_qwen_pipeline_loading.py b/tests/examples/diffusers/test_qwen_pipeline_loading.py deleted file mode 100644 index cbc49aa5ea3..00000000000 --- a/tests/examples/diffusers/test_qwen_pipeline_loading.py +++ /dev/null @@ -1,74 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Negative-path loading tests for Qwen-Image in the diffusers quantization example. - -These cover the negative loading paths without a GPU or a real model: -- selecting Qwen-Image when diffusers lacks the Qwen classes raises a clear, - actionable error (not an opaque failure); -- Qwen loading does not pass ``trust_remote_code``. -""" - -import logging -import sys -from pathlib import Path - -import pytest - -pytest.importorskip("diffusers") -pytest.importorskip("torch") - -_EXAMPLE_DIR = Path(__file__).parents[3] / "examples" / "diffusers" / "quantization" -if str(_EXAMPLE_DIR) not in sys.path: - sys.path.insert(0, str(_EXAMPLE_DIR)) - -import models_utils # noqa: E402 -import pipeline_manager # noqa: E402 -from models_utils import ModelType # noqa: E402 -from quantize_config import ModelConfig # noqa: E402 - - -def _qwen_pipeline_manager() -> "pipeline_manager.PipelineManager": - config = ModelConfig(model_type=ModelType.QWEN_IMAGE, backbone=["transformer"]) - return pipeline_manager.PipelineManager(config, logging.getLogger("qwen-loading-test")) - - -def test_missing_qwen_pipeline_raises_actionable_error(monkeypatch): - # Simulate a diffusers version without QwenImagePipeline. - monkeypatch.setitem(models_utils.MODEL_PIPELINE, ModelType.QWEN_IMAGE, None) - manager = _qwen_pipeline_manager() - with pytest.raises(ImportError, match="Qwen-Image requires"): - manager.create_pipeline() - - -def test_qwen_loading_does_not_pass_trust_remote_code(monkeypatch): - captured_kwargs: dict = {} - - class _FakeQwenPipeline: - @classmethod - def from_pretrained(cls, model_id, **kwargs): - captured_kwargs.update(kwargs) - return cls() - - def set_progress_bar_config(self, **kwargs): - pass - - monkeypatch.setitem(models_utils.MODEL_PIPELINE, ModelType.QWEN_IMAGE, _FakeQwenPipeline) - manager = _qwen_pipeline_manager() - manager.create_pipeline() - - # Qwen-Image must load without trust_remote_code. - assert captured_kwargs.get("trust_remote_code") is not True - assert "trust_remote_code" not in captured_kwargs diff --git a/tests/unit/torch/export/test_convert_hf_config_svdquant.py b/tests/unit/torch/export/test_convert_hf_config_svdquant.py deleted file mode 100644 index 13a22ab22d5..00000000000 --- a/tests/unit/torch/export/test_convert_hf_config_svdquant.py +++ /dev/null @@ -1,87 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for the NVFP4_SVD (SVDQuant) HF quantization-config conversion.""" - -from modelopt.torch.export.convert_hf_config import ( - _quant_algo_to_group_config, - convert_hf_quant_config_format, -) - - -def test_nvfp4_svd_group_config_mirrors_awq_with_pre_quant_scale(): - """The NVFP4_SVD config group is NVFP4 weights/activations + a pre_quant_scale flag.""" - group = _quant_algo_to_group_config("NVFP4_SVD", group_size=16) - assert group["pre_quant_scale"] is True - assert group["has_zero_point"] is False - assert group["weights"] == { - "dynamic": False, - "num_bits": 4, - "type": "float", - "group_size": 16, - } - assert group["input_activations"]["num_bits"] == 4 - assert group["input_activations"]["type"] == "float" - assert group["input_activations"]["group_size"] == 16 - - -def test_convert_hf_quant_config_format_nvfp4_svd(): - """A full NVFP4_SVD quantization dict converts to a complete compressed-tensors config.""" - input_config = { - "producer": {"name": "modelopt", "version": "0.0.0"}, - "quantization": { - "quant_algo": "NVFP4_SVD", - "group_size": 16, - "has_zero_point": False, - "pre_quant_scale": True, - "lora_rank": 32, - "exclude_modules": ["transformer_blocks.0.*", "proj_out"], - "kv_cache_quant_algo": None, - }, - } - - out = convert_hf_quant_config_format(input_config) - - # A real config group is emitted (not a bare {"quant_algo": ...} fallback). - assert "config_groups" in out - group = out["config_groups"]["group_0"] - assert group["pre_quant_scale"] is True - assert group["has_zero_point"] is False - assert group["lora_rank"] == 32 - assert group["weights"]["num_bits"] == 4 - assert group["weights"]["type"] == "float" - assert group["weights"]["group_size"] == 16 - assert group["input_activations"]["num_bits"] == 4 - assert group["targets"] == ["Linear"] - - # Top-level metadata is preserved. - assert out["quant_algo"] == "NVFP4_SVD" - assert out["ignore"] == ["transformer_blocks.0.*", "proj_out"] - assert out["quant_method"] == "modelopt" - - -def test_convert_hf_quant_config_format_nvfp4_svd_without_rank(): - """lora_rank is optional; omitting it must not break the conversion.""" - input_config = { - "quantization": { - "quant_algo": "NVFP4_SVD", - "group_size": 16, - "pre_quant_scale": True, - }, - } - out = convert_hf_quant_config_format(input_config) - group = out["config_groups"]["group_0"] - assert "lora_rank" not in group - assert group["pre_quant_scale"] is True diff --git a/tests/unit/torch/export/test_diffusers_qwen_export.py b/tests/unit/torch/export/test_diffusers_qwen_export.py deleted file mode 100644 index ad4b62b048f..00000000000 --- a/tests/unit/torch/export/test_diffusers_qwen_export.py +++ /dev/null @@ -1,133 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for the Qwen-Image SVDQuant diffusers export path. - -Covers the three pieces added for Qwen support: -- the Qwen branch of ``generate_diffusion_dummy_inputs`` (validated by running the - dummy forward on a real tiny ``QwenImageTransformer2DModel``), -- the strict-failure mode of ``_fuse_qkv_linears_diffusion``, -- promotion of quantizer-owned SVDQuant tensors to clean module-level keys that - survive ``hide_quantizers_from_state_dict``. -""" - -from functools import partial - -import pytest -import torch -import torch.nn as nn - -import modelopt.torch.quantization as mtq -from modelopt.torch.export.diffusers_utils import ( - generate_diffusion_dummy_forward_fn, - generate_diffusion_dummy_inputs, - hide_quantizers_from_state_dict, -) -from modelopt.torch.export.unified_export_hf import ( - _fuse_qkv_linears_diffusion, - _promote_quantizer_tensors_to_module, -) - - -class _MLP(nn.Module): - def __init__(self, dim: int = 64): - super().__init__() - self.fc1 = nn.Linear(dim, dim) - self.fc2 = nn.Linear(dim, dim) - - def forward(self, x): - return self.fc2(torch.relu(self.fc1(x))) - - -def _forward_loop(model, data): - for batch in data: - model(batch) - - -def _quantize(model: nn.Module, algorithm=None, dim: int = 64) -> nn.Module: - cfg = mtq.INT8_SMOOTHQUANT_CFG.copy() - if algorithm is not None: - cfg["algorithm"] = algorithm - data = [torch.randn(2, dim) for _ in range(2)] - mtq.quantize(model, cfg, partial(_forward_loop, data=data)) - return model - - -def test_qwen_dummy_inputs_drive_real_transformer_forward(): - """The Qwen dummy inputs must actually drive a real tiny Qwen transformer.""" - pytest.importorskip("diffusers") - from _test_utils.torch.diffusers_models import get_tiny_qwen_image_transformer - - transformer = get_tiny_qwen_image_transformer().to("cpu", torch.float32).eval() - - inputs = generate_diffusion_dummy_inputs(transformer, torch.device("cpu"), torch.float32) - assert inputs is not None - for key in ( - "hidden_states", - "encoder_hidden_states", - "encoder_hidden_states_mask", - "img_shapes", - ): - assert key in inputs, f"missing Qwen dummy input '{key}'" - assert inputs["hidden_states"].shape[-1] == transformer.config.in_channels - assert inputs["encoder_hidden_states"].shape[-1] == transformer.config.joint_attention_dim - - # Strongest check: the generated dummy inputs run through the real model. - with torch.no_grad(): - generate_diffusion_dummy_forward_fn(transformer)() - - -def test_qwen_qkv_fusion_strict_raises_on_failed_dummy_forward(): - """strict=True turns a dummy-forward failure into a hard error; strict=False does not.""" - model = _quantize(_MLP()) - - def _boom(): - raise RuntimeError("dummy forward failed") - - with pytest.raises(RuntimeError): - _fuse_qkv_linears_diffusion(model, dummy_forward_fn=_boom, strict=True) - - # Non-strict path warns and returns without raising. - _fuse_qkv_linears_diffusion(model, dummy_forward_fn=_boom, strict=False) - - -def test_svdquant_promotion_survives_hide_quantizers(): - """Promoted LoRA + pre_quant_scale land on the module under clean keys and - survive ``hide_quantizers_from_state_dict`` (which strips the quantizers).""" - model = _quantize(_MLP(), algorithm={"method": "svdquant", "lowrank": 8}) - - _promote_quantizer_tensors_to_module(model) - - linears = [m for m in model.modules() if isinstance(m, torch.nn.Linear)] - assert linears - for module in linears: - assert hasattr(module, "svdquant_lora_a") - assert hasattr(module, "svdquant_lora_b") - # INT8_SMOOTHQUANT produces a pre_quant_scale that is promoted too. - assert hasattr(module, "pre_quant_scale") - # Rank-consistent shapes: lora_a [rank, in], lora_b [out, rank]. - assert module.svdquant_lora_a.shape[1] == module.in_features - assert module.svdquant_lora_b.shape[0] == module.out_features - assert module.svdquant_lora_a.shape[0] == module.svdquant_lora_b.shape[1] - - with hide_quantizers_from_state_dict(model): - keys = list(model.state_dict().keys()) - - assert any(k.endswith(".svdquant_lora_a") for k in keys) - assert any(k.endswith(".svdquant_lora_b") for k in keys) - assert any(k.endswith(".pre_quant_scale") for k in keys) - # Clean keys only: no quantizer-prefixed keys remain once quantizers are hidden. - assert not any("weight_quantizer" in k for k in keys) - assert not any("input_quantizer" in k for k in keys) diff --git a/tests/unit/torch/quantization/test_svdquant_forward_fold.py b/tests/unit/torch/quantization/test_svdquant_forward_fold.py deleted file mode 100644 index 085414f93a3..00000000000 --- a/tests/unit/torch/quantization/test_svdquant_forward_fold.py +++ /dev/null @@ -1,116 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""SVDQuant forward / fold coverage. - -These tests protect the invariants the diffusers SVDQuant export relies on: the -LoRA factors stay on the ``weight_quantizer`` in the live model and the export -layer promotes them. They complement (and do not modify) the existing -``test_calib.py::test_svdquant_lora_weights``. -""" - -from functools import partial - -import torch -import torch.nn as nn - -import modelopt.torch.quantization as mtq - - -class _SVDMLP(nn.Module): - def __init__(self, dim: int = 64): - super().__init__() - self.fc1 = nn.Linear(dim, dim) - self.fc2 = nn.Linear(dim, dim) - - def forward(self, x): - return self.fc2(torch.relu(self.fc1(x))) - - -def _forward_loop(model, dataloader): - for batch in dataloader: - model(batch) - - -def _quantize_svdquant(dim: int = 64) -> nn.Module: - model = _SVDMLP(dim) - quant_config = mtq.INT8_SMOOTHQUANT_CFG.copy() - quant_config["algorithm"] = {"method": "svdquant", "lowrank": 8} - data = [torch.randn(2, dim) for _ in range(2)] - mtq.quantize(model, quant_config, partial(_forward_loop, dataloader=data)) - return model - - -def _quantized_linears(model: nn.Module): - return [m for m in model.modules() if isinstance(m, torch.nn.Linear)] - - -def test_svdquant_lora_stays_on_weight_quantizer(): - """LoRA lives on the quantizer, not the module (the export layer promotes it).""" - model = _quantize_svdquant() - linears = _quantized_linears(model) - assert linears - for module in linears: - wq = module.weight_quantizer - assert wq.svdquant_lora_a is not None - assert wq.svdquant_lora_b is not None - # Not refactored onto the module. - assert not hasattr(module, "svdquant_lora_a") - assert not hasattr(module, "svdquant_lora_b") - - -def test_svdquant_forward_includes_nonzero_residual(): - """The forward output includes a nonzero low-rank residual term.""" - model = _quantize_svdquant() - for module in _quantized_linears(model): - x = torch.randn(2, module.in_features) - - residual = module._compute_lora_residual(x) - assert residual is not None - assert torch.count_nonzero(residual) > 0 - - full = module(x) - - # Temporarily drop the LoRA buffers to get the base (no-residual) output. - wq = module.weight_quantizer - lora_a = wq._svdquant_lora_a - lora_b = wq._svdquant_lora_b - delattr(wq, "_svdquant_lora_a") - delattr(wq, "_svdquant_lora_b") - try: - base = module(x) - finally: - wq.register_buffer("_svdquant_lora_a", lora_a) - wq.register_buffer("_svdquant_lora_b", lora_b) - - # The residual measurably changes the forward output. - assert not torch.allclose(full, base) - - -def test_svdquant_fold_weight_removes_buffers_and_changes_weight(): - """fold_weight() folds the residual into the weight and drops the buffers.""" - model = _quantize_svdquant() - for module in _quantized_linears(model): - wq = module.weight_quantizer - assert hasattr(wq, "_svdquant_lora_a") - assert hasattr(wq, "_svdquant_lora_b") - - weight_before = module.weight.detach().clone() - module.fold_weight() - - assert not hasattr(wq, "_svdquant_lora_a") - assert not hasattr(wq, "_svdquant_lora_b") - # Folding (quantized weight + low-rank residual) changes the stored weight. - assert not torch.allclose(module.weight, weight_before) From 9b472b234f530cc5e040532d44ee94c15129b43d Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 12 Jun 2026 16:40:20 -0700 Subject: [PATCH 15/15] Qwen-Image: add a fast CPU test for the diffusers SVDQuant export promotion Covers svdquant calibration -> _promote_quantizer_tensors_to_module -> clean module-level keys (svdquant_lora_a/b, pre_quant_scale) with the quantizers hidden, plus the post-export cleanup. Runs on CPU in <1s (INT8_SMOOTHQUANT + svdquant on a tiny linear stack). The full NVFP4 end-to-end check remains test_qwen_image_hf_ckpt_export[qwen_nvfp4_svdquant]; svdquant calibration is already covered by test_calib.py. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Jingyu Xin --- .../torch/export/test_export_diffusers.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/unit/torch/export/test_export_diffusers.py b/tests/unit/torch/export/test_export_diffusers.py index 1a7a3495158..2f1e5085a3e 100644 --- a/tests/unit/torch/export/test_export_diffusers.py +++ b/tests/unit/torch/export/test_export_diffusers.py @@ -117,3 +117,54 @@ def test_flux2_dummy_inputs_shape(): # guidance_embeds defaults to True for Flux2 assert "guidance" in inputs + + +def test_svdquant_diffusers_export_promotes_clean_keys(): + """Fast CPU check of the diffusers SVDQuant export promotion. + + SVDQuant calibration stores the low-rank factors on ``weight_quantizer`` and the + smoothing scale on ``input_quantizer``; the diffusers export promotes both to + clean module-level keys (``svdquant_lora_a/b``, ``pre_quant_scale``) and hides the + quantizers, so the saved state dict carries no live quantizer tensors. The full + NVFP4 end-to-end coverage lives in the GPU test + ``tests/examples/diffusers/test_export_diffusers_hf_ckpt.py`` (``qwen_nvfp4_svdquant``). + """ + import copy + + import torch.nn as nn + + import modelopt.torch.quantization as mtq + from modelopt.torch.export.diffusers_utils import hide_quantizers_from_state_dict + + torch.manual_seed(0) + model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64)) + + quant_config = copy.deepcopy(mtq.INT8_SMOOTHQUANT_CFG) + quant_config["algorithm"] = {"method": "svdquant", "lowrank": 8} + mtq.quantize(model, quant_config, lambda m: m(torch.randn(8, 64))) + + # Calibration populated the quantizer-owned SVDQuant tensors. + linear = model[0] + assert linear.weight_quantizer.svdquant_lora_a is not None + assert linear.weight_quantizer.svdquant_lora_b is not None + assert getattr(linear.input_quantizer, "_pre_quant_scale", None) is not None + + # Export promotes them to clean module-level keys and hides the quantizers. + unified_export_hf._promote_quantizer_tensors_to_module(model) + with hide_quantizers_from_state_dict(model): + keys = set(model.state_dict().keys()) + + assert any(k.endswith(".svdquant_lora_a") for k in keys) + assert any(k.endswith(".svdquant_lora_b") for k in keys) + assert any(k.endswith(".pre_quant_scale") for k in keys) + assert not any("weight_quantizer" in k or "input_quantizer" in k for k in keys), ( + "live quantizer state leaked into the exported state dict" + ) + + # The promotion is undone after export, leaving the live module unchanged. + unified_export_hf._remove_promoted_quantizer_tensors(model) + keys_after = set(model.state_dict().keys()) + assert not any( + k.endswith((".svdquant_lora_a", ".svdquant_lora_b", ".pre_quant_scale")) + for k in keys_after + )