diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index b59744282f6..1126a421390 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -18,6 +18,7 @@ from enum import Enum from typing import Any +import torch from diffusers import ( DiffusionPipeline, FluxPipeline, @@ -30,11 +31,19 @@ from diffusers import Flux2Pipeline except ImportError: Flux2Pipeline = None + +# Qwen-Image classes were added in a recent diffusers release; import lazily so +# this example still imports on older diffusers versions. +try: + from diffusers import QwenImagePipeline +except ImportError: + QwenImagePipeline = None from utils import ( filter_func_default, filter_func_flux_dev, filter_func_ltx2_vae, filter_func_ltx_video, + filter_func_qwen_image, filter_func_wan_vae, filter_func_wan_video, ) @@ -54,6 +63,7 @@ class ModelType(str, Enum): LTX2 = "ltx-2" WAN22_T2V_14b = "wan2.2-t2v-14b" WAN22_T2V_5b = "wan2.2-t2v-5b" + QWEN_IMAGE = "qwen-image" _FILTER_FUNC_MAP: dict[ModelType, Callable[[str], bool]] = { @@ -63,6 +73,7 @@ class ModelType(str, Enum): ModelType.LTX2: filter_func_ltx_video, ModelType.WAN22_T2V_14b: filter_func_wan_video, ModelType.WAN22_T2V_5b: filter_func_wan_video, + ModelType.QWEN_IMAGE: filter_func_qwen_image, } _VAE_FILTER_FUNC_MAP: dict[tuple[ModelType, str], Callable[[str], bool]] = { @@ -95,6 +106,7 @@ def get_model_filter_func( ModelType.LTX2: "Lightricks/LTX-2", ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers", ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers", + ModelType.QWEN_IMAGE: "Qwen/Qwen-Image", } MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = { @@ -109,6 +121,7 @@ def get_model_filter_func( ModelType.LTX2: None, ModelType.WAN22_T2V_14b: WanPipeline, ModelType.WAN22_T2V_5b: WanPipeline, + ModelType.QWEN_IMAGE: QwenImagePipeline, } # Shared dataset configurations @@ -226,6 +239,23 @@ def get_model_filter_func( ), }, }, + ModelType.QWEN_IMAGE: { + "backbone": "transformer", + "dataset": _SD_PROMPTS_DATASET, + "inference_extra_args": { + "height": 1024, + "width": 1024, + }, + # Quantize only ``transformer_blocks``; keep the first 2 and last 2 blocks + # (and everything outside ``transformer_blocks``) in original precision. + # Applied before calibration via ``build_block_range_quant_cfg`` so SVDQuant + # never mutates the excluded blocks' weights. + "block_range": { + "exclude_first_n": 2, + "exclude_last_n": 2, + "block_module": "transformer_blocks", + }, + }, } @@ -272,3 +302,70 @@ def parse_extra_params( i += 1 return extra_params + + +def build_block_range_quant_cfg( + backbone: torch.nn.Module, + exclude_first_n: int, + exclude_last_n: int, + block_module: str = "transformer_blocks", +) -> list[dict[str, Any]]: + """Build ordered ``quant_cfg`` rules for a transformer-block-only recipe. + + The rules quantize only the linears under ``block_module`` while keeping the + first ``exclude_first_n`` and last ``exclude_last_n`` blocks -- and everything + outside ``block_module`` -- in original precision. + + The rules are meant to be appended to the ``quant_cfg`` list consumed by + ``mtq.quantize`` so the selection is applied BEFORE calibration. This is + required for SVDQuant, whose calibration subtracts a low-rank residual from + the weights of every *enabled* linear: disabling the excluded blocks only + after calibration would leave their weights mutated instead of bit-identical + to the original precision. + + Rules are applied in order with later rules overriding earlier ones: + 1. disable every linear weight/input quantizer, + 2. re-enable only those under ``block_module`` (``enable`` is a top-level + QuantizerCfgEntry toggle; a ``None`` cfg keeps the base preset's quant params), + 3. disable the first/last ``n`` blocks. + + Raises: + ValueError: if the backbone has no ``block_module`` list, or it has fewer + than ``exclude_first_n + exclude_last_n + 2`` blocks (it requires at + least two quantized middle blocks). + """ + blocks = getattr(backbone, block_module, None) + if blocks is None or not hasattr(blocks, "__len__"): + raise ValueError( + f"Backbone {type(backbone).__name__} has no '{block_module}' module list; " + "cannot build the transformer-block-range recipe." + ) + num_blocks = len(blocks) + # Require at least two quantized middle blocks so the recipe actually + # quantizes something (excluding first/last alone could otherwise leave 0-1 + # quantized blocks). For the default 2+2 recipe this means n >= 6. + min_blocks = exclude_first_n + exclude_last_n + 2 + if num_blocks < min_blocks: + raise ValueError( + f"'{block_module}' has only {num_blocks} block(s); excluding the first " + f"{exclude_first_n} and last {exclude_last_n} requires at least {min_blocks} blocks " + f"(at least 2 quantized middle blocks)." + ) + + excluded = sorted( + set(range(exclude_first_n)) | set(range(num_blocks - exclude_last_n, num_blocks)) + ) + # `enable` is a top-level QuantizerCfgEntry field (independent of `cfg`); a `None` + # cfg leaves the base preset's quant params untouched, so disabling then + # re-enabling restores the original (FP8/NVFP4/...) attributes. Putting `enable` + # under `cfg` is rejected by the QuantizerAttributeConfig validator. + rules: list[dict[str, Any]] = [ + {"quantizer_name": "*weight_quantizer", "enable": False}, + {"quantizer_name": "*input_quantizer", "enable": False}, + {"quantizer_name": f"*{block_module}.*weight_quantizer", "enable": True}, + {"quantizer_name": f"*{block_module}.*input_quantizer", "enable": True}, + ] + for idx in excluded: + rules.append({"quantizer_name": f"*{block_module}.{idx}.*weight_quantizer", "enable": False}) + rules.append({"quantizer_name": f"*{block_module}.{idx}.*input_quantizer", "enable": False}) + return rules diff --git a/examples/diffusers/quantization/pipeline_manager.py b/examples/diffusers/quantization/pipeline_manager.py index 85e335ba787..5496f131ba2 100644 --- a/examples/diffusers/quantization/pipeline_manager.py +++ b/examples/diffusers/quantization/pipeline_manager.py @@ -61,6 +61,12 @@ def create_pipeline_from( """ pipeline_cls = MODEL_PIPELINE[model_type] if pipeline_cls is None: + if model_type == ModelType.QWEN_IMAGE: + raise ImportError( + "Qwen-Image requires a diffusers version that provides " + "QwenImagePipeline. Please upgrade diffusers (e.g. " + "`pip install -U diffusers`) to a release that includes Qwen-Image." + ) raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.") model_id = ( MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path @@ -99,6 +105,12 @@ def create_pipeline(self) -> Any: pipeline_cls = MODEL_PIPELINE[self.config.model_type] if pipeline_cls is None: + if self.config.model_type == ModelType.QWEN_IMAGE: + raise ImportError( + "Qwen-Image requires a diffusers version that provides " + "QwenImagePipeline. Please upgrade diffusers (e.g. " + "`pip install -U diffusers`) to a release that includes Qwen-Image." + ) raise ValueError( f"Model type {self.config.model_type.value} does not use diffusers pipelines." ) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 299a101172a..325d39e5bd7 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -32,8 +32,13 @@ set_quant_config_attr, ) from diffusers import DiffusionPipeline -from models_utils import MODEL_DEFAULTS, ModelType, get_model_filter_func, parse_extra_params -from onnx_utils.export import generate_fp8_scales, modelopt_export_sd +from models_utils import ( + MODEL_DEFAULTS, + ModelType, + build_block_range_quant_cfg, + get_model_filter_func, + parse_extra_params, +) from pipeline_manager import PipelineManager from quantize_config import ( CalibrationConfig, @@ -163,6 +168,28 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: } ) + # Apply the transformer-block-range recipe (e.g. Qwen-Image) BEFORE + # calibration. This restricts quantization to `transformer_blocks` and + # excludes the first/last N blocks. It must run before calibration so that + # SVDQuant does not mutate the weights of the excluded blocks. The recipe + # is format-agnostic (applies to FP8/NVFP4/SVDQuant alike). + block_range = MODEL_DEFAULTS.get(self.model_config.model_type, {}).get("block_range") + if block_range is not None: + recipe_rules = build_block_range_quant_cfg( + backbone, + exclude_first_n=block_range.get("exclude_first_n", 2), + exclude_last_n=block_range.get("exclude_last_n", 2), + block_module=block_range.get("block_module", "transformer_blocks"), + ) + self.logger.info( + f"Applying block-range recipe ({len(recipe_rules)} rules) for " + f"{self.model_config.model_type.value}: quantize only " + f"'{block_range.get('block_module', 'transformer_blocks')}' excluding " + f"first {block_range.get('exclude_first_n', 2)} / last " + f"{block_range.get('exclude_last_n', 2)} blocks." + ) + quant_cfg_list.extend(recipe_rules) + quant_config = {**base_cfg, "quant_cfg": quant_cfg_list} set_quant_config_attr( quant_config, @@ -291,6 +318,11 @@ def export_onnx( if not self.config.onnx_dir: return + # onnx_graphsurgeon (pulled in by onnx_utils.export) is an optional dependency + # only needed for the ONNX export path; import lazily so the HF-checkpoint + # export runs without it installed. + from onnx_utils.export import generate_fp8_scales, modelopt_export_sd + self.logger.info(f"Starting ONNX export to {self.config.onnx_dir}") if quant_format == QuantFormat.FP8 and self._has_conv_layers(backbone): @@ -379,6 +411,10 @@ def export_hf_ckpt(self, pipe: Any, model_config: ModelConfig | None = None) -> f"Invalid padding_strategy: {padding!r}. Expected 'row' or 'row_col'." ) kwargs["padding_strategy"] = padding + # The diffusion transformer is large (~20B params); the unified export's + # layerwise-metadata post-processing does not support sharded safetensors, so + # save each component as a single file (no *.safetensors.index.json). + kwargs.setdefault("max_shard_size", "200GB") export_hf_checkpoint(pipe, export_dir=self.config.hf_ckpt_dir, **kwargs) self.logger.info("HuggingFace checkpoint export completed successfully") @@ -542,6 +578,15 @@ def create_argument_parser() -> argparse.ArgumentParser: export_group.add_argument( "--restore-from", type=str, help="Path to restore from previous checkpoint" ) + export_group.add_argument( + "--sanity-image-path", + type=str, + default=None, + help="If set, generate one image from the in-memory quantized pipeline (after " + "quantization, before the weights are packed for export) and save it here. This is " + "a quick functional sanity check of quantized inference; it does NOT reload the " + "exported checkpoint.", + ) export_group.add_argument( "--trt-high-precision-dtype", type=str, @@ -681,6 +726,29 @@ def forward_loop(mod): pipeline_manager.print_quant_summary() + # Optional functional sanity check: generate one image from the in-memory + # quantized pipeline. This runs BEFORE export (while weights are still + # fake-quantized and runnable, not yet packed) and does not reload the + # exported checkpoint. + if args.sanity_image_path: + try: + logger.info(f"Generating sanity image to {args.sanity_image_path}") + inference_args = MODEL_DEFAULTS.get(model_type, {}).get("inference_extra_args", {}) + result = pipe( + prompt="A high-quality photo of a cat wearing sunglasses", + num_inference_steps=calib_config.n_steps, + **inference_args, + ) + sanity_path = Path(args.sanity_image_path) + sanity_path.parent.mkdir(parents=True, exist_ok=True) + result.images[0].save(str(sanity_path)) + logger.info("Sanity image saved successfully") + except Exception as sanity_error: + # A requested sanity image is a positive success criterion: if it + # cannot be produced, fail loudly rather than reporting success. + logger.error(f"Sanity image generation failed: {sanity_error}", exc_info=True) + raise + for backbone_name, backbone in pipeline_manager.iter_backbones(): export_manager.export_onnx( pipe, diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index d102e83e068..c3cfdcd5cdd 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -111,6 +111,30 @@ def filter_func_wan_video(name: str) -> bool: return pattern.match(name) is not None +# Qwen-Image's transformer has 60 ``transformer_blocks``. The recipe quantizes +# only those blocks while keeping the first two and last two -- and everything +# outside ``transformer_blocks`` -- in original precision. The model-agnostic, +# config-driven form of this recipe (deriving the block count from the model) +# lives in quantize.py; this name-only filter covers the plain FP8/NVFP4 path +# for the full 60-block Qwen-Image transformer. +QWEN_IMAGE_NUM_TRANSFORMER_BLOCKS = 60 +_QWEN_IMAGE_BLOCK_RE = re.compile(r"(?:^|\.)transformer_blocks\.(\d+)(?:\.|$)") + + +def filter_func_qwen_image(name: str) -> bool: + """Filter function specifically for Qwen-Image models. + + Returns ``True`` for modules to keep in original precision (quantization + disabled): everything outside ``transformer_blocks``, plus the first two and + last two transformer blocks. + """ + match = _QWEN_IMAGE_BLOCK_RE.search(name) + if match is None: + return True + block_idx = int(match.group(1)) + return block_idx < 2 or block_idx >= QWEN_IMAGE_NUM_TRANSFORMER_BLOCKS - 2 + + def load_calib_prompts( batch_size, calib_data_path: str | Path = "Gustavosta/Stable-Diffusion-Prompts", diff --git a/modelopt/torch/export/convert_hf_config.py b/modelopt/torch/export/convert_hf_config.py index 06e5923a30f..6f7dedb97c8 100644 --- a/modelopt/torch/export/convert_hf_config.py +++ b/modelopt/torch/export/convert_hf_config.py @@ -62,6 +62,19 @@ def _quant_algo_to_group_config(quant_algo: str, group_size: int | None = None) return { "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": gs}, } + elif quant_algo == "NVFP4_SVD": + gs = group_size or 16 + return { + "input_activations": { + "dynamic": False, + "num_bits": 4, + "type": "float", + "group_size": gs, + }, + "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": gs}, + "has_zero_point": False, + "pre_quant_scale": True, + } elif quant_algo in ("NVFP4_AWQ", "W4A8_AWQ"): gs = group_size or 128 return { @@ -196,6 +209,30 @@ def convert_hf_quant_config_format(input_config: dict[str, Any]) -> dict[str, An "targets": ["Linear"], } new_config["config_groups"] = {"group_0": config_group_details} + elif quant_algo_value == "NVFP4_SVD": + # NVFP4 + SVDQuant: NVFP4 weights/activations plus an AWQ-style + # pre_quant_scale and a low-rank residual (svdquant_lora_a/b) stored as + # .pre_quant_scale / .svdquant_lora_{a,b} in the + # safetensors. The config mirrors NVFP4 with a pre_quant_scale flag and + # the LoRA rank so consumers can reconstruct + # ``y = NVFP4_GEMM(x) + (x @ lora_a^T) @ lora_b^T``. + group_size = original_quantization_details.get("group_size", 16) + config_group_details = { + "input_activations": { + "dynamic": False, + "num_bits": 4, + "type": "float", + "group_size": group_size, + }, + "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": group_size}, + "has_zero_point": False, + "pre_quant_scale": True, + "targets": ["Linear"], + } + lora_rank = original_quantization_details.get("lora_rank") + if lora_rank is not None: + config_group_details["lora_rank"] = lora_rank + new_config["config_groups"] = {"group_0": config_group_details} elif quant_algo_value == "MIXED_PRECISION": quantized_layers = original_quantization_details.get("quantized_layers", {}) diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py index 9620c97c10e..075f25a9101 100644 --- a/modelopt/torch/export/diffusers_utils.py +++ b/modelopt/torch/export/diffusers_utils.py @@ -142,6 +142,11 @@ def _is_model_type(module_path: str, class_name: str, fallback: bool) -> bool: "UNet2DConditionModel", "unet" in model_class_name.lower(), ) + is_qwen = _is_model_type( + "diffusers.models.transformers", + "QwenImageTransformer2DModel", + "qwen" in model_class_name.lower(), + ) cfg = getattr(model, "config", None) @@ -321,6 +326,50 @@ def _wan_inputs() -> dict[str, torch.Tensor]: "return_dict": False, } + def _qwen_inputs() -> dict[str, Any]: + # QwenImageTransformer2DModel does NOT take the standard + # (hidden_states[B,C,H,W], timestep, encoder_hidden_states) triple. It expects + # *packed* latents [B, (H//2)*(W//2), in_channels] plus encoder_hidden_states, + # encoder_hidden_states_mask, img_shapes, and optional guidance. + # Timesteps are continuous in [0, 1] (not the diffusers [0, 1000] scale). + in_channels = getattr(cfg, "in_channels", 64) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 3584) + guidance_embeds = getattr(cfg, "guidance_embeds", False) + + # Small packed spatial grid (already divided by the 2x2 patch size). + packed_h = packed_w = 4 + img_seq_len = packed_h * packed_w + text_seq_len = 8 + + dummy_inputs: dict[str, Any] = { + "hidden_states": torch.randn( + batch_size, img_seq_len, in_channels, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "encoder_hidden_states_mask": torch.ones( + batch_size, text_seq_len, device=device, dtype=torch.int64 + ), + "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size), + "img_shapes": [[(1, packed_h, packed_w)]] * batch_size, + "return_dict": False, + } + if guidance_embeds: + dummy_inputs["guidance"] = torch.tensor([4.0], device=device, dtype=torch.float32) + + # Only pass kwargs the installed QwenImageTransformer2DModel.forward accepts + # (signatures vary across diffusers versions); prevents the strict QKV-fusion + # dummy forward from failing on an unexpected keyword argument. + import inspect + + try: + accepted = set(inspect.signature(model.forward).parameters) + dummy_inputs = {k: v for k, v in dummy_inputs.items() if k in accepted} + except (TypeError, ValueError): + pass + return dummy_inputs + def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None: # Try generic transformer handling for other model types # Check if model has common transformer attributes @@ -366,6 +415,7 @@ def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None: ("dit", is_dit, _dit_inputs), ("wan", is_wan, _wan_inputs), ("unet", is_unet, _unet_inputs), + ("qwen", is_qwen, _qwen_inputs), ] for _, matches, build_inputs in model_input_builders: @@ -685,15 +735,20 @@ def hide_quantizers_from_state_dict(model: nn.Module): # Store references to quantizers that we'll temporarily remove quantizer_backup: dict[str, dict[str, nn.Module]] = {} - for name, module in model.named_modules(): - if is_quantlinear(module): - backup = {} - for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: - if hasattr(module, attr): - backup[attr] = getattr(module, attr) - delattr(module, attr) - if backup: - quantizer_backup[name] = backup + # Remove every quantizer submodule from *all* modules, not only recognized + # quant-linears: enabled input quantizers can also live on non-linear modules + # (e.g. norm layers whose activations were calibrated), and their ``_amax`` + # buffers must not leak into the saved checkpoint. Snapshot the module list + # first since we mutate the module tree while iterating. + for name, module in list(model.named_modules()): + backup = {} + for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: + child = getattr(module, attr, None) + if isinstance(child, nn.Module): + backup[attr] = child + delattr(module, attr) + if backup: + quantizer_backup[name] = backup try: yield diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index d8ddf442924..e35ead6253b 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -749,9 +749,13 @@ def process_layer_quant_config(layer_config_dict): "group_size": block_size_value, } elif v == "nvfp4_svdquant": + # SVDQuant builds on the AWQ-style pre_quant_scale smoothing, so its + # config mirrors nvfp4_awq (group_size + pre_quant_scale flag). layer_config = { "quant_algo": "NVFP4_SVD", "group_size": block_size_value, + "has_zero_point": False, + "pre_quant_scale": True, } elif v == "mxfp8": layer_config = { diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index ef5757aa0cb..abbea9fc37e 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -960,7 +960,9 @@ def _export_transformers_checkpoint( def _fuse_qkv_linears_diffusion( - model: nn.Module, dummy_forward_fn: Callable[[], None] | None = None + model: nn.Module, + dummy_forward_fn: Callable[[], None] | None = None, + strict: bool = False, ) -> None: """Fuse QKV linear layers that share the same input for diffusion models. @@ -971,7 +973,8 @@ def _fuse_qkv_linears_diffusion( Note: This is a simplified version for diffusion models that: - Handles QKV fusion (shared input detection) - Filters to only fuse actual QKV projection layers (not AdaLN, FFN, etc.) - - Skips pre_quant_scale handling (TODO for future) + - Skips pre_quant_scale *fusion* (the export path promotes pre_quant_scale to + module-level keys separately; see _promote_quantizer_tensors_to_module) - Skips FFN fusion with layernorm (TODO for future) Args: @@ -994,6 +997,11 @@ def _fuse_qkv_linears_diffusion( model, dummy_forward_fn, collect_layernorms=False ) except Exception as e: + if strict: + raise RuntimeError( + f"QKV fusion dummy forward failed for {type(model).__name__}; a working " + f"dummy forward is required to export this model correctly. Original error: {e}" + ) from e print(f"Warning: Failed to run dummy forward for QKV fusion: {e}") print("Skipping QKV fusion. Quantization may still work but amax values won't be unified.") return @@ -1013,6 +1021,73 @@ def _fuse_qkv_linears_diffusion( ) +def _detect_svdquant_rank(component: nn.Module) -> int | None: + """Return the SVDQuant low-rank dimension from the first SVDQuant linear, if any. + + ``svdquant_lora_a`` has shape ``(rank, in_features)``, so its first dimension + is the low-rank size. + """ + for _, sub_module in component.named_modules(): + weight_quantizer = getattr(sub_module, "weight_quantizer", None) + lora_a = getattr(weight_quantizer, "svdquant_lora_a", None) + if lora_a is not None: + return int(lora_a.shape[0]) + return None + + +def _promote_quantizer_tensors_to_module(component: nn.Module) -> None: + """Promote quantizer-owned export tensors onto their parent linear module. + + The diffusers export path saves via ``save_pretrained`` inside + :func:`hide_quantizers_from_state_dict` (which deletes the ``weight_quantizer`` + / ``input_quantizer`` submodules) and -- unlike the transformers path -- does + NOT run :func:`postprocess_state_dict`. Without this step the AWQ smoothing + scale and the SVDQuant low-rank factors would be dropped from the exported + checkpoint. We register them as module buffers under clean, AWQ-aligned keys + so they are embedded in the component's main safetensors: + + - ``input_quantizer._pre_quant_scale`` -> ``.pre_quant_scale`` + (the same key the transformers/AWQ path produces via postprocess_state_dict) + - ``weight_quantizer.svdquant_lora_a`` -> ``.svdquant_lora_a`` + - ``weight_quantizer.svdquant_lora_b`` -> ``.svdquant_lora_b`` + + This runs after :func:`_process_quantized_modules` (which leaves these + quantizer buffers in place) and before ``save_pretrained``. + """ + for _, sub_module in component.named_modules(): + if not is_quantlinear(sub_module): + continue + + # register_buffer overwrites an existing buffer of the same name, so a + # repeated export refreshes (rather than keeps stale) promoted tensors. + input_quantizer = getattr(sub_module, "input_quantizer", None) + pre_quant_scale = getattr(input_quantizer, "_pre_quant_scale", None) + if pre_quant_scale is not None: + sub_module.register_buffer("pre_quant_scale", pre_quant_scale.detach().clone()) + + weight_quantizer = getattr(sub_module, "weight_quantizer", None) + lora_a = getattr(weight_quantizer, "svdquant_lora_a", None) + lora_b = getattr(weight_quantizer, "svdquant_lora_b", None) + if lora_a is not None and lora_b is not None: + sub_module.register_buffer("svdquant_lora_a", lora_a.detach().clone()) + sub_module.register_buffer("svdquant_lora_b", lora_b.detach().clone()) + + +def _remove_promoted_quantizer_tensors(component: nn.Module) -> None: + """Undo :func:`_promote_quantizer_tensors_to_module`. + + Removes the temporary module-level export buffers (``svdquant_lora_a/b`` and + ``pre_quant_scale``) so the live module is unchanged after export, keeping + repeated export / post-export module reuse correct. The quantizer-owned tensors + (``weight_quantizer.svdquant_lora_a/b``, ``input_quantizer._pre_quant_scale``) + are left untouched. + """ + for _, sub_module in component.named_modules(): + for buffer_name in ("svdquant_lora_a", "svdquant_lora_b", "pre_quant_scale"): + if buffer_name in getattr(sub_module, "_buffers", {}): + del sub_module._buffers[buffer_name] + + def _export_diffusers_checkpoint( pipe: Any, dtype: torch.dtype | None, @@ -1041,7 +1116,7 @@ def _export_diffusers_checkpoint( """ export_dir = Path(export_dir) - # Step 1: Get all pipeline components (nn.Module, tokenizers, schedulers, etc.) + # Get all pipeline components (nn.Module, tokenizers, schedulers, etc.) all_components = get_diffusion_components(pipe, components) if not all_components: @@ -1063,7 +1138,7 @@ def _export_diffusers_checkpoint( except Exception: is_diffusers_pipe = False - # Step 3: Export each nn.Module component with quantization handling + # Export each nn.Module component with quantization handling for component_name, component in module_components.items(): is_quantized = has_quantized_modules(component) status = "quantized" if is_quantized else "non-quantized" @@ -1082,20 +1157,38 @@ def _export_diffusers_checkpoint( component_dtype = dtype if dtype is not None else infer_dtype_from_model(component) if is_quantized: - # Step 3.5: Fuse QKV linears that share the same input (unify amax values) + # Fuse QKV linears that share the same input (unify amax values) # This is similar to requantize_resmooth_fused_llm_layers but simplified for diffusion - # TODO: Add pre_quant_scale handling and FFN fusion for AWQ-style quantization + # TODO: Add FFN fusion for AWQ-style quantization (pre_quant_scale is + # promoted to module keys at export by _promote_quantizer_tensors_to_module below) print(f" Running QKV fusion for {component_name}...") - _fuse_qkv_linears_diffusion(component) + # Qwen-Image's packed-latent forward signature is non-standard; if the + # dummy forward fails for it, fail loudly rather than silently skipping + # fusion (which would export un-unified amax values). + is_qwen_component = "qwen" in type(component).__name__.lower() + _fuse_qkv_linears_diffusion(component, strict=is_qwen_component) - # Step 4: Process quantized modules (convert weights, register scales) + # Process quantized modules (convert weights, register scales) _process_quantized_modules(component, component_dtype, is_modelopt_qlora=False) - # Step 5: Build quantization config + # Promote quantizer-owned tensors (AWQ pre_quant_scale and SVDQuant + # LoRA factors) onto the module so they survive + # hide_quantizers_from_state_dict and are embedded in the component's + # main safetensors under clean, AWQ-aligned keys. + _promote_quantizer_tensors_to_module(component) + + # Build quantization config quant_config = get_quant_config(component, is_modelopt_qlora=False) + if quant_config: + quantization_details = quant_config.get("quantization", {}) + # Record the SVDQuant low-rank size so consumers know the LoRA shape. + if quantization_details.get("quant_algo") == "NVFP4_SVD": + svdquant_rank = _detect_svdquant_rank(component) + if svdquant_rank is not None: + quantization_details["lora_rank"] = svdquant_rank hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None - # Step 6: Save the component + # Save the component # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter # - for non-diffusers modules (e.g., LTX-2 transformer), fall back to torch.save if hasattr(component, "save_pretrained"): @@ -1105,7 +1198,7 @@ def _export_diffusers_checkpoint( with hide_quantizers_from_state_dict(component): _save_component_state_dict_safetensors(component, component_export_dir) - # Step 7: Post-process — merge, metadata, padding, swizzle + # Post-process — merge, metadata, padding, swizzle _postprocess_safetensors( component_export_dir, pipe, @@ -1113,7 +1206,7 @@ def _export_diffusers_checkpoint( **kwargs, ) - # Step 8: Update config.json with quantization info + # Update config.json with quantization info if hf_quant_config is not None: config_path = component_export_dir / "config.json" if config_path.exists(): @@ -1122,13 +1215,17 @@ def _export_diffusers_checkpoint( config_data["quantization_config"] = hf_quant_config with open(config_path, "w") as file: json.dump(config_data, file, indent=4) + + # Drop the temporary promoted export buffers so the live module is + # unchanged after export (supports repeated export / module reuse). + _remove_promoted_quantizer_tensors(component) # Non-quantized component: just save as-is elif hasattr(component, "save_pretrained"): component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) else: _save_component_state_dict_safetensors(component, component_export_dir) - # Step 9: Update config.json with sparse attention info (both quantized and non-quantized) + # Update config.json with sparse attention info (both quantized and non-quantized) if export_sparse_attention_config is not None: sparse_attn_config = export_sparse_attention_config(component) if sparse_attn_config is not None: @@ -1143,7 +1240,7 @@ def _export_diffusers_checkpoint( print(f" Saved to: {component_export_dir}") - # Step 4: Export non-nn.Module components (tokenizers, schedulers, feature extractors, etc.) + # Export non-nn.Module components (tokenizers, schedulers, feature extractors, etc.) if is_diffusers_pipe: for component_name, component in all_components.items(): # Skip nn.Module components (already handled above) @@ -1171,7 +1268,7 @@ def _export_diffusers_checkpoint( print(f" Saved to: {component_export_dir}") - # Step 5: For pipelines, also save model_index.json + # For pipelines, also save model_index.json if is_diffusers_pipe: model_index_path = export_dir / "model_index.json" is_partial_export = components is not None diff --git a/tests/_test_utils/torch/diffusers_models.py b/tests/_test_utils/torch/diffusers_models.py index a42ebdd8ffb..caf4966399f 100644 --- a/tests/_test_utils/torch/diffusers_models.py +++ b/tests/_test_utils/torch/diffusers_models.py @@ -289,6 +289,14 @@ def get_tiny_qwen_image_transformer(**config_kwargs): "axes_dims_rope": (8, 4, 4), # sums to attention_head_dim (16) } kwargs.update(**config_kwargs) + # Drop kwargs the installed diffusers QwenImageTransformer2DModel doesn't accept. + # `pooled_projection_dim` is present in the published config.json but was removed + # from the constructor in newer diffusers: from_pretrained tolerates the extra + # config key, but a direct constructor call raises TypeError. + import inspect + + accepted = set(inspect.signature(QwenImageTransformer2DModel.__init__).parameters) + kwargs = {k: v for k, v in kwargs.items() if k in accepted} return QwenImageTransformer2DModel(**kwargs) @@ -311,21 +319,46 @@ def get_tiny_qwen_image_vae(**config_kwargs): return AutoencoderKLQwenImage(**kwargs) +def _build_local_qwen2_tokenizer(out_dir: Path): + """Build a tiny, fully offline byte-level Qwen2 tokenizer (no Hub access). + + Uses the GPT-2/Qwen byte->unicode mapping for the 256 single-byte tokens plus + Qwen's core special tokens, with an empty merge table (pure byte-level + fallback). This is enough to tokenize calibration prompts so the pipeline runs + end-to-end; it is not meant for high-quality text. + """ + import json + + import transformers + from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode + + out_dir.mkdir(parents=True, exist_ok=True) + vocab = {token: idx for idx, token in enumerate(bytes_to_unicode().values())} + for special in ("<|endoftext|>", "<|im_start|>", "<|im_end|>"): + vocab.setdefault(special, len(vocab)) + (out_dir / "vocab.json").write_text(json.dumps(vocab)) + (out_dir / "merges.txt").write_text("#version: 0.2\n") + + return transformers.Qwen2Tokenizer( + vocab_file=str(out_dir / "vocab.json"), + merges_file=str(out_dir / "merges.txt"), + unk_token="<|endoftext|>", + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", + ) + + def create_tiny_qwen_image_pipeline_dir(tmp_path: Path) -> Path: - """Create and save a tiny Qwen-Image pipeline to a directory (SKETCH). - - Mirrors ``create_tiny_wan22_pipeline_dir``. Needs in-container validation; the - fragile piece is the Qwen2.5-VL text encoder. This prefers a tiny-random HF model - (as Wan uses ``hf-internal-testing/tiny-random-t5``); if that id drifts or the - config schema differs across transformers versions, copy the text-encoder - construction from diffusers' own QwenImage fast test - (``tests/pipelines/qwenimage/test_qwenimage.py``). - - For the DMD2 mock-data training path the transformer consumes the dataloader's - embeddings rather than the text encoder, so the bundled tiny text encoder only - needs to load; its hidden size is intentionally decoupled from the transformer's - ``joint_attention_dim`` (set the dataloader's ``text_embed_dim`` to match instead). - The saved dir loads with ``QwenImagePipeline.from_pretrained(path)``. + """Create and save a tiny, fully offline Qwen-Image pipeline to a directory. + + Mirrors diffusers' ``QwenImagePipelineFastTests.get_dummy_components`` but with + no Hub access: the Qwen2.5-VL text encoder is built inline from a tiny + ``Qwen2_5_VLConfig``, and the tokenizer is built locally by + ``_build_local_qwen2_tokenizer`` (byte-level vocab written to a temp dir). The + transformer uses ``num_layers=6`` so the first-2/last-2 block-range recipe is + valid, and its ``joint_attention_dim`` matches the text encoder ``hidden_size`` + (16) so the pipeline runs end-to-end during quantization calibration. The saved + dir loads with ``QwenImagePipeline.from_pretrained(path)``. """ if QwenImageTransformer2DModel is None or AutoencoderKLQwenImage is None: pytest.skip("QwenImage diffusers classes not available in this diffusers version.") @@ -333,27 +366,52 @@ def create_tiny_qwen_image_pipeline_dir(tmp_path: Path) -> Path: transformers = pytest.importorskip("transformers") - # Tiny Qwen2.5-VL text encoder + matching Qwen2 tokenizer (loaded, but bypassed - # during DMD2 mock-data training). - # NOTE (validated 2026-06-06): the hf-internal-testing id below does NOT exist on the - # Hub, so this fixture currently skips. To make the recipe e2e runnable in CI, - # construct the encoder inline from a tiny ``Qwen2_5_VLConfig`` (nested text + vision - # config) — mirror diffusers' ``QwenImagePipelineFastTests.get_dummy_components`` in - # ``tests/pipelines/qwenimage/test_qwenimage.py``. - tiny_id = "hf-internal-testing/tiny-random-Qwen2_5_VLForConditionalGeneration" - try: - text_encoder = transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(tiny_id) - tokenizer = transformers.Qwen2Tokenizer.from_pretrained(tiny_id) - except Exception as exc: # pragma: no cover - depends on hub availability / version - pytest.skip( - f"tiny Qwen2.5-VL text encoder unavailable ({exc}); " - "copy the fixture from diffusers' QwenImage fast test" - ) + # Tiny Qwen2.5-VL text encoder, built offline from a tiny config (no Hub model + # load), mirroring diffusers' QwenImagePipelineFastTests.get_dummy_components. + qwen_vl_config = transformers.Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + hidden_size=16, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + ) + text_encoder = transformers.Qwen2_5_VLForConditionalGeneration(qwen_vl_config).eval() + + # Deterministic local byte-level Qwen2 tokenizer (built offline; no Hub, no skip). + tokenizer = _build_local_qwen2_tokenizer(tmp_path / "qwen_tokenizer") torch.manual_seed(0) - transformer = get_tiny_qwen_image_transformer() + # num_layers=6 so the first-2/last-2 block-range recipe (which needs >=6 blocks) + # is valid; joint_attention_dim must match the text encoder hidden_size (16). + transformer = get_tiny_qwen_image_transformer( + num_layers=6, + in_channels=16, + out_channels=4, + joint_attention_dim=16, + num_attention_heads=3, + ) torch.manual_seed(0) - vae = get_tiny_qwen_image_vae() + vae = get_tiny_qwen_image_vae(z_dim=4, latents_mean=[0.0] * 4, latents_std=[1.0] * 4) scheduler = FlowMatchEulerDiscreteScheduler( base_image_seq_len=256, diff --git a/tests/examples/diffusers/conftest.py b/tests/examples/diffusers/conftest.py index e704f6d5879..625f9bc3415 100644 --- a/tests/examples/diffusers/conftest.py +++ b/tests/examples/diffusers/conftest.py @@ -35,9 +35,11 @@ def tiny_wan22_path(tmp_path_factory): def tiny_qwen_image_path(tmp_path_factory): """Create a tiny Qwen-Image pipeline and return its path (built once per session). - SKETCH fixture for the recipe-level DMD2 e2e (``test_fastgen_recipe_e2e.py``). - See ``create_tiny_qwen_image_pipeline_dir`` for caveats — notably the tiny - Qwen2.5-VL text encoder, which needs in-container validation. + Used by the diffusers Qwen export tests and the recipe-level DMD2 e2e + (``test_fastgen_recipe_e2e.py``). The pipeline is built fully offline by + ``create_tiny_qwen_image_pipeline_dir`` (inline tiny Qwen2.5-VL text encoder + + local byte-level tokenizer); it skips only when the diffusers Qwen classes are + unavailable. """ try: from _test_utils.torch.diffusers_models import create_tiny_qwen_image_pipeline_dir diff --git a/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py b/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py index 88821bbf8f7..9e2e30a160c 100644 --- a/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py +++ b/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from pathlib import Path from typing import NamedTuple @@ -130,6 +131,151 @@ def test_diffusers_hf_ckpt_export(model: DiffuserHfExportModel, tmp_path: Path) assert len(weight_files) > 0, f"No weight files (.safetensors or .bin) found in {hf_ckpt_dir}" +class QwenHfExportModel(NamedTuple): + format_type: str + quant_algo: str + is_svdquant: bool + + def quantize_and_export_hf(self, tiny_qwen_image_path: str, tmp_path: Path) -> Path: + hf_ckpt_dir = tmp_path / f"qwen_{self.format_type}_{self.quant_algo}_hf_ckpt" + cmd_args = [ + "python", + "quantize.py", + "--model", + "qwen-image", + "--override-model-path", + str(tiny_qwen_image_path), + "--format", + self.format_type, + "--quant-algo", + self.quant_algo, + "--collect-method", + "default", + "--model-dtype", + "BFloat16", + "--trt-high-precision-dtype", + "BFloat16", + "--calib-size", + "2", + "--batch-size", + "1", + "--n-steps", + "2", + "--hf-ckpt-dir", + str(hf_ckpt_dir), + ] + if self.is_svdquant: + cmd_args.extend(["--lowrank", "8"]) + run_example_command(cmd_args, "diffusers/quantization") + return hf_ckpt_dir + + +def _module_prefixes(keys: set[str], suffix: str) -> set[str]: + """Module paths (key minus suffix) for every key ending in ``suffix``.""" + return {k[: -len(suffix)] for k in keys if k.endswith(suffix)} + + +def _block_indices(prefixes: set[str]) -> set[int]: + """transformer_blocks indices referenced by a set of module prefixes.""" + import re + + indices = set() + for prefix in prefixes: + match = re.search(r"transformer_blocks\.(\d+)\.", prefix) + if match: + indices.add(int(match.group(1))) + return indices + + +# Tiny Qwen fixture has 6 transformer blocks; the recipe excludes the first 2 and +# last 2, so only blocks 2 and 3 are quantized. +_QWEN_QUANTIZED_BLOCKS = {2, 3} +_QWEN_LORA_RANK = 8 + + +@pytest.mark.parametrize( + "qwen_model", + [ + pytest.param(QwenHfExportModel("fp8", "max", False), marks=minimum_sm(89)), + pytest.param(QwenHfExportModel("fp4", "max", False), marks=minimum_sm(89)), + pytest.param(QwenHfExportModel("fp4", "svdquant", True), marks=minimum_sm(89)), + ], + ids=["qwen_fp8_max", "qwen_nvfp4_max", "qwen_nvfp4_svdquant"], +) +def test_qwen_image_hf_ckpt_export( + qwen_model: QwenHfExportModel, tiny_qwen_image_path: str, tmp_path: Path +) -> None: + from safetensors import safe_open + + hf_ckpt_dir = qwen_model.quantize_and_export_hf(tiny_qwen_image_path, tmp_path) + assert hf_ckpt_dir.exists(), f"HF checkpoint directory was not created: {hf_ckpt_dir}" + + # The transformer is the quantized component. + transformer_dir = hf_ckpt_dir / "transformer" + config_path = transformer_dir / "config.json" + assert config_path.exists(), f"no transformer/config.json in {hf_ckpt_dir}" + quant_config = json.loads(config_path.read_text()).get("quantization_config") + assert quant_config is not None, "missing quantization_config" + assert quant_config.get("quant_method") == "modelopt" + + keys: set[str] = set() + lora_tensors: dict[str, "object"] = {} + safetensors_files = sorted(transformer_dir.rglob("*.safetensors")) + assert safetensors_files, f"no safetensors in {transformer_dir}" + for path in safetensors_files: + with safe_open(str(path), framework="pt") as handle: + for key in handle.keys(): + keys.add(key) + if key.endswith(".svdquant_lora_a") or key.endswith(".svdquant_lora_b"): + lora_tensors[key] = handle.get_tensor(key) + + # No live quantizer state should leak into the exported checkpoint. + assert not any("weight_quantizer" in k for k in keys), "quantizer keys leaked into export" + assert not any("input_quantizer._amax" in k for k in keys) + + # Recipe: only the middle transformer blocks are quantized — first-2/last-2 of + # transformer_blocks are excluded, and nothing outside transformer_blocks. + weight_scale_prefixes = _module_prefixes(keys, ".weight_scale") + assert weight_scale_prefixes, "no quantized linears found in export" + assert all("transformer_blocks." in p for p in weight_scale_prefixes), ( + f"a non-transformer_blocks module was quantized: {weight_scale_prefixes}" + ) + assert _block_indices(weight_scale_prefixes) == _QWEN_QUANTIZED_BLOCKS, ( + f"expected only blocks {_QWEN_QUANTIZED_BLOCKS} quantized" + ) + + if qwen_model.is_svdquant: + a_prefixes = _module_prefixes(keys, ".svdquant_lora_a") + b_prefixes = _module_prefixes(keys, ".svdquant_lora_b") + pqs_prefixes = _module_prefixes(keys, ".pre_quant_scale") + assert a_prefixes, "no promoted svdquant_lora_a keys" + # Every promoted linear carries lora_a, lora_b, and pre_quant_scale, and + # every quantized linear is promoted (the sets are identical). + assert a_prefixes == b_prefixes == pqs_prefixes == weight_scale_prefixes + assert _block_indices(a_prefixes) == _QWEN_QUANTIZED_BLOCKS + # Rank-consistent shapes; lora_a=[rank, in], lora_b=[out, rank], rank == --lowrank. + for key, tensor in lora_tensors.items(): + if key.endswith(".svdquant_lora_a"): + assert tensor.shape[0] == _QWEN_LORA_RANK + else: + assert tensor.shape[1] == _QWEN_LORA_RANK + # NVFP4 secondary scales are present. + assert any(k.endswith(".weight_scale_2") for k in keys) + # config schema (modeled on nvfp4_awq). + assert quant_config.get("quant_algo") == "NVFP4_SVD" + group = next(iter(quant_config.get("config_groups", {}).values()), {}) + assert group.get("lora_rank") == _QWEN_LORA_RANK + assert group.get("pre_quant_scale") is True + assert group.get("has_zero_point") is False + assert quant_config.get("ignore"), "expected excluded modules in 'ignore'" + else: + # Plain FP8/NVFP4: weight scales present, no SVDQuant tensors. + assert weight_scale_prefixes, "no weight_scale in export" + assert not any(k.endswith(".svdquant_lora_a") for k in keys) + if qwen_model.format_type == "fp4": + assert any(k.endswith(".weight_scale_2") for k in keys) + + class Wan22HfExportModel(NamedTuple): model: str backbone: str | None diff --git a/tests/unit/torch/export/test_export_diffusers.py b/tests/unit/torch/export/test_export_diffusers.py index 1a7a3495158..2f1e5085a3e 100644 --- a/tests/unit/torch/export/test_export_diffusers.py +++ b/tests/unit/torch/export/test_export_diffusers.py @@ -117,3 +117,54 @@ def test_flux2_dummy_inputs_shape(): # guidance_embeds defaults to True for Flux2 assert "guidance" in inputs + + +def test_svdquant_diffusers_export_promotes_clean_keys(): + """Fast CPU check of the diffusers SVDQuant export promotion. + + SVDQuant calibration stores the low-rank factors on ``weight_quantizer`` and the + smoothing scale on ``input_quantizer``; the diffusers export promotes both to + clean module-level keys (``svdquant_lora_a/b``, ``pre_quant_scale``) and hides the + quantizers, so the saved state dict carries no live quantizer tensors. The full + NVFP4 end-to-end coverage lives in the GPU test + ``tests/examples/diffusers/test_export_diffusers_hf_ckpt.py`` (``qwen_nvfp4_svdquant``). + """ + import copy + + import torch.nn as nn + + import modelopt.torch.quantization as mtq + from modelopt.torch.export.diffusers_utils import hide_quantizers_from_state_dict + + torch.manual_seed(0) + model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64)) + + quant_config = copy.deepcopy(mtq.INT8_SMOOTHQUANT_CFG) + quant_config["algorithm"] = {"method": "svdquant", "lowrank": 8} + mtq.quantize(model, quant_config, lambda m: m(torch.randn(8, 64))) + + # Calibration populated the quantizer-owned SVDQuant tensors. + linear = model[0] + assert linear.weight_quantizer.svdquant_lora_a is not None + assert linear.weight_quantizer.svdquant_lora_b is not None + assert getattr(linear.input_quantizer, "_pre_quant_scale", None) is not None + + # Export promotes them to clean module-level keys and hides the quantizers. + unified_export_hf._promote_quantizer_tensors_to_module(model) + with hide_quantizers_from_state_dict(model): + keys = set(model.state_dict().keys()) + + assert any(k.endswith(".svdquant_lora_a") for k in keys) + assert any(k.endswith(".svdquant_lora_b") for k in keys) + assert any(k.endswith(".pre_quant_scale") for k in keys) + assert not any("weight_quantizer" in k or "input_quantizer" in k for k in keys), ( + "live quantizer state leaked into the exported state dict" + ) + + # The promotion is undone after export, leaving the live module unchanged. + unified_export_hf._remove_promoted_quantizer_tensors(model) + keys_after = set(model.state_dict().keys()) + assert not any( + k.endswith((".svdquant_lora_a", ".svdquant_lora_b", ".pre_quant_scale")) + for k in keys_after + )