diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 49c58586674..b218cdbb5bb 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,6 +8,7 @@ Changelog - Add the ``day0-release`` agent skill (``.agents/skills/day0-release/``), a deterministic end-to-end driver that chains the PTQ → evaluation → comparison skills (the evaluation stage deploys the checkpoint itself) with an enforced gate after each stage and returns a publish decision (ACCEPT / REGRESSION / ANOMALOUS / INFEASIBLE). Ships three GPU-free, unit-tested gate scripts (``gate_ptq.py``, ``gate_run.py``, ``gate_compare.py``) that validate checkpoint coverage, evaluation-run completeness, and baseline-vs-candidate accuracy threshold. v1 reports and stops on regression; the recipe-search loop is deferred. - Add **streaming** speculative-decoding training (EAGLE3 / DFlash): the draft trains on base-model hidden states produced on the fly by a co-located ``vllm serve`` (no disk dump), moved trainer-side over NIXL RDMA, scaling to multi-node (dedicated serve replicas + DDP trainers). New launcher examples for NVFP4 Kimi-K2.5 / K2.6 on GB200/aarch64 under ``tools/launcher/examples/moonshotai/``. +- Add tied-weight PTQ and HF-checkpoint export support for block-diffusion encoder-decoder LLMs (e.g. DiffusionGemma) whose encoder/decoder stacks share parameters via HF ``_tied_weights_keys``. ``_export_quantized_weight`` and ``_export_fused_experts`` now alias bit-identical packed ``weight`` / ``weight_scale`` / ``weight_scale_2`` buffers across modules sharing a source weight ``data_ptr()`` so the downstream ``postprocess_state_dict`` dedup catches them (~42% storage reduction on ``nvfp4_experts_only`` for tied 26B MoE checkpoints). New ``sync_tied_input_amax`` helper max-merges per-side ``input_quantizer.amax`` across tied modules before export so single-backbone consumers that load one ``input_scale`` per parameter don't clip either side. 
Opt-in ``--canonical_tied_naming`` flag (default off) reorders the state_dict so canonical-side keys per HF's ``_tied_weights_keys`` declaration win the data_ptr dedup. ``default_disabled_quantizers`` gains a ``*self_conditioning*`` wildcard companion to the upstream vision excludes (PR #1691). ``hf_ptq.py`` also unwraps ``ModelOutput`` dataclasses from ``.generate()`` so the preview decode works on diffusion models. Non-tied models see no behavioral change. 0.45 (2026-06-xx) ^^^^^^^^^^^^^^^^^ diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index d36754a8d42..1d3c196d428 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -806,7 +806,13 @@ def is_model_on_gpu(model) -> bool: def is_enc_dec(model_type) -> bool: - """Return if the model is a encoder-decoder model.""" + """Return whether the model_type uses encoder-decoder-style preview decode. + + Controls whether ``hf_ptq.py`` slices off the prompt prefix from + ``.generate()`` output. ``diffusion_gemma`` is structurally encoder-decoder + but returns prompt+canvas concatenated, so it stays OFF this list (AR-style + decode applies). + """ return model_type in ["t5", "bart", "whisper"] diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index afb725988c8..14215853eaa 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -774,6 +774,7 @@ def export_quantized( full_model, export_dir=export_path, extra_state_dict=mtp_state_dict, + canonical_tied_naming=args.canonical_tied_naming, ) if args.qformat == "w4a16_nvfp4": @@ -941,6 +942,11 @@ def input_decode(input_ids): raise ValueError("The processor or tokenizer must be set") def output_decode(generated_ids, input_shape): + # Some `.generate()` returns a ModelOutput dataclass (e.g. DiffusionGemma); + # unwrap to the token tensor so downstream slicing works uniformly. 
+ if hasattr(generated_ids, "sequences"): + generated_ids = generated_ids.sequences + if is_enc_dec(model_type): if processor is not None and isinstance(processor, WhisperProcessor): return processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] @@ -1252,6 +1258,19 @@ def parse_args() -> argparse.Namespace: default=512, ) parser.add_argument("--export_path", default="exported_model") + parser.add_argument( + "--canonical_tied_naming", + type=lambda s: s.lower() in ("1", "true", "yes"), + default=False, + help=( + "If True, reorder the exported state_dict so tied-weight aliases " + "dedup to the canonical side declared in the model's HF " + "_tied_weights_keys (e.g. decoder-side for DiffusionGemma4). Off " + "by default to avoid renaming exported keys for models whose " + "downstream consumers expect the legacy (registration-order) " + "winner." + ), + ) parser.add_argument( "--dataset", help=( diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py index 3bd72d9de91..9c49cae0cf1 100755 --- a/modelopt/torch/export/model_utils.py +++ b/modelopt/torch/export/model_utils.py @@ -33,6 +33,9 @@ "Qwen3Next": "qwen3next", "QWen": "qwen", "RecurrentGemma": "recurrentgemma", + # DiffusionGemma must come before "Gemma" — get_model_type substring-matches + # in order, and "gemma" is a substring of "diffusiongemma". + "DiffusionGemma": "diffusion_gemma", "Gemma3": "gemma3", "Gemma2": "gemma2", "Gemma": "gemma", diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index e325e5346f1..3b0e49fe54a 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -42,6 +42,13 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: {E}.gate_proj.weight, {E}.gate_proj.weight_scale, ... {E}.up_proj.weight, {E}.up_proj.weight_scale, ... {E}.down_proj.weight, {E}.down_proj.weight_scale, ... 
+ + Tied-experts dedup: when multiple fused-expert modules share their 3-D + source params via HF ``_tied_weights_keys``, the unpacking creates fresh + per-expert tensors that break the tie. We cache the source ``data_ptr()`` + at entry and on a later cache hit alias the per-expert ``weight`` / + ``weight_scale`` / ``weight_scale_2`` / ``input_scale`` back to the + prior module so downstream dedup catches them. """ from modelopt.torch.export.unified_export_hf import _export_quantized_weight from modelopt.torch.quantization.plugins.huggingface import _get_fused_expert_intermediate_dim @@ -49,6 +56,10 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: n = module.num_experts expert_dim = _get_fused_expert_intermediate_dim(module) + # Capture source tensor identities BEFORE unpacking (the source + # attrs are deleted at the end of this function). + _source_key = (module.gate_up_proj.data_ptr(), module.down_proj.data_ptr()) + # 1. Shared input quantizers — one per projection type, shared across all experts. gate_up_input_q = module.gate_up_proj_input_quantizer down_input_q = module.down_proj_input_quantizer @@ -178,6 +189,46 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: if hasattr(module, attr): delattr(module, attr) + # 5. Tied-experts dedup: if this module's source params have been seen + # before, alias the bit-identical per-expert buffers (weight, + # weight_scale, weight_scale_2, input_scale) to the previously-unpacked + # module. input_scale is safe to alias because sync_tied_input_amax + # runs earlier in _export_transformers_checkpoint and max-merges the + # shared input_quantizer amaxes across tied fused-experts modules, so + # both sides now derive bit-identical input_scale values. 
+ _cache = _export_fused_experts.__dict__.setdefault("_tied_unpacked_cache", {}) + _prior = _cache.get(_source_key) + if _prior is not None and _prior is not module: + for _idx in range(n): + _cur_expert = getattr(module, str(_idx), None) + _prior_expert = getattr(_prior, str(_idx), None) + if _cur_expert is None or _prior_expert is None: + continue + for _proj_name in ("gate_proj", "up_proj", "down_proj"): + _cur_proj = getattr(_cur_expert, _proj_name, None) + _prior_proj = getattr(_prior_expert, _proj_name, None) + if _cur_proj is None or _prior_proj is None: + continue + # Alias the weight (Parameter) so both sides reference the + # same nn.Parameter → same data_ptr() → existing dedup + # in postprocess_state_dict will drop the duplicate. + if hasattr(_prior_proj, "weight"): + _cur_proj.weight = _prior_proj.weight + # Alias the bit-identical scale buffers (including + # input_scale, made safe by sync_tied_input_amax pre-export + # merging). Re-register to ensure data_ptr() matches the + # prior side's tensor. + for _attr in ("weight_scale", "weight_scale_2", "input_scale"): + if not hasattr(_prior_proj, _attr): + continue + if _attr in _cur_proj._buffers: + del _cur_proj._buffers[_attr] + elif hasattr(_cur_proj, _attr): + delattr(_cur_proj, _attr) + _cur_proj.register_buffer(_attr, getattr(_prior_proj, _attr)) + else: + _cache[_source_key] = module + def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None): """Collect expert_token_count from all quantized MoE layers and save as an HTML table. diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index ef5757aa0cb..89973facce1 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -520,6 +520,14 @@ def _export_quantized_weight( The export includes converting weight tensor to correct quantized values and quantized dtype, and registering scaling factors. 
+ + Tied-weight dedup: the setattr below replaces ``.weight`` with a fresh + ``nn.Parameter`` wrapping packed bytes, breaking any HF-level tie. + We capture ``weight.data_ptr()`` before the replacement and consult a + function-static cache at the end; on cache hit, ``weight`` / ``weight_scale`` / + ``weight_scale_2`` / ``input_scale`` are re-pointed at the prior module so the + downstream data_ptr dedup catches them. Uses memory identity only — no + ``_tied_weights_keys`` lookup, no-op for non-tied modules. """ quantization_format = get_quantization_format(sub_module) if quantization_format == QUANTIZATION_NONE: @@ -528,6 +536,13 @@ def _export_quantized_weight( block_size = get_weight_block_size(sub_module, weight_name) quantizer_attrs = quantizer_attr_names(weight_name) weight: nn.Parameter = getattr(sub_module, weight_name) + + # Capture source identity BEFORE any tensor-creating operation below. + # For HF-tied weights this matches across all modules sharing the + # underlying Parameter; the cache lookup at the end of this function + # uses it to detect ties whose Python identity is about to be broken + # by the setattr on `weight_name` further down. + _tied_source_data_ptr = weight.data_ptr() weight_quantizer: TensorQuantizer | SequentialQuantizer = getattr( sub_module, quantizer_attrs.weight_quantizer ) @@ -703,9 +718,177 @@ def _export_quantized_weight( if weight_scale is not None: sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale) + # Tied-weight dedup: if a previously-processed module shared the same + # source weight memory, alias the packed weight + scale buffers so the + # downstream data_ptr dedup in postprocess_state_dict can collapse them. + # input_scale is safe to alias because sync_tied_input_amax (earlier in + # this export) already max-merged the per-side amaxes. 
+ _cache = _export_quantized_weight.__dict__.setdefault("_tied_weight_alias_cache", {}) + _prior = _cache.get(_tied_source_data_ptr) + if _prior is not None and _prior is not sub_module: + if hasattr(_prior, weight_name): + setattr(sub_module, weight_name, getattr(_prior, weight_name)) + for _attr in ( + quantizer_attrs.weight_scale, + quantizer_attrs.weight_scale_2, + quantizer_attrs.input_scale, + ): + if _attr is None or not hasattr(_prior, _attr): + continue + if _attr in sub_module._buffers: + del sub_module._buffers[_attr] + elif hasattr(sub_module, _attr): + delattr(sub_module, _attr) + sub_module.register_buffer(_attr, getattr(_prior, _attr)) + else: + _cache[_tied_source_data_ptr] = sub_module + torch.cuda.empty_cache() +def _collect_canonical_tied_patterns( + model: nn.Module, +) -> tuple[list[re.Pattern], list[str]]: + """Walk the model and collect canonical-side tied-weight matchers. + + Patterns are submodule-prefixed regexes from each module's + ``_tied_weights_keys`` dict-style declaration (the prefix matters + for nested models where the dict lives on an inner submodule). + Side substrings are dot-separated tokens that appear only on the + canonical side of those declarations — needed because modelopt's + per-expert unpacking creates post-export keys (e.g. + ``…experts.Y.gate_proj.input_scale``) that HF's regexes never knew + about. List-style (legacy) declarations are skipped. + """ + patterns: list[re.Pattern] = [] + alias_token_set: set[str] = set() + canonical_token_set: set[str] = set() + + def _tokens(s: str) -> set[str]: + """Identifiers in a regex string, with regex specials as separators.""" + return {tok for tok in re.split(r"[^A-Za-z0-9_]+", s) if tok} + + for name, submodule in model.named_modules(): + tied = getattr(submodule, "_tied_weights_keys", None) + if not isinstance(tied, dict) or not tied: + continue + prefix = f"{name}." 
if name else "" + for alias_pat, canonical_pat in tied.items(): + patterns.append(re.compile(prefix + canonical_pat)) + alias_token_set.update(_tokens(prefix + alias_pat)) + canonical_token_set.update(_tokens(prefix + canonical_pat)) + + # Tokens unique to the canonical side become substring matchers. + side_substrings = sorted(canonical_token_set - alias_token_set) + return patterns, side_substrings + + +def _reorder_canonical_first(state_dict: dict, model: nn.Module) -> dict: + r"""Reorder ``state_dict`` so canonical-side tied keys iterate first. + + Lets the downstream first-wins data_ptr dedup keep canonical names. + Uses both regex patterns and substring matchers from + :func:`_collect_canonical_tied_patterns`. No-op when the model + declares no dict-style ``_tied_weights_keys``. + """ + canonical_patterns, side_substrings = _collect_canonical_tied_patterns(model) + if not canonical_patterns and not side_substrings: + return state_dict + + def _has_side_substring(key: str) -> bool: + # Require the token to appear as a proper dot-separated path + # component, not just as a substring of an unrelated identifier. + for tok in side_substrings: + if ( + f".{tok}." in key + or key.startswith(f"{tok}.") + or key.endswith(f".{tok}") + or key == tok + ): + return True + return False + + head: dict = {} + tail: dict = {} + for k, v in state_dict.items(): + if any(p.search(k) for p in canonical_patterns) or _has_side_substring(k): + head[k] = v + else: + tail[k] = v + head.update(tail) + return head + + +def sync_tied_input_amax(model: nn.Module) -> int: + """Max-merge input_quantizer amaxes across modules sharing a weight ``data_ptr``. + + Closes the loop on ``input_scale`` for HF-tied modules whose forward + paths see different activation distributions (encoder vs decoder in + YOCO-style models). Must run BEFORE per-module export so the merged + amax flows into ``input_scale`` derivation. 
Handles both dense + Linears (keyed by ``weight.data_ptr()``) and fused MoE (keyed by + ``(gate_up_proj, down_proj)`` data_ptr tuple). Returns the number of + merges performed (a tied fused-MoE group can count twice). + """ + from collections import defaultdict + + by_dp: dict = defaultdict(list) + for _, m in model.named_modules(): + # Fused MoE: 3-D source tensors with shared input quantizers + if ( + hasattr(m, "gate_up_proj_input_quantizer") + and hasattr(m, "gate_up_proj") + and hasattr(m, "down_proj") + and m.gate_up_proj.dim() == 3 + ): + key = ("moe", m.gate_up_proj.data_ptr(), m.down_proj.data_ptr()) + by_dp[key].append(m) + # Dense quantized Linear with an input_quantizer + elif ( + hasattr(m, "input_quantizer") + and hasattr(m, "weight") + and isinstance(m.weight, torch.nn.Parameter) + ): + by_dp[("dense", m.weight.data_ptr())].append(m) + + def _merge(quantizers: list) -> bool: + """Max-merge amaxes across the quantizer list. Returns True on merge.""" + valid = [ + q + for q in quantizers + if q is not None + and getattr(q, "is_enabled", False) + and getattr(q, "_amax", None) is not None + and not q._amax.is_meta + ] + if len(valid) < 2: + return False + # Require scalar (per-tensor) amax — matches preprocess_linear_fusion. + if any(q._amax.numel() != 1 for q in valid): + warnings.warn( + "sync_tied_input_amax: non-scalar input_quantizer amax encountered " + "in a tied group; skipping. Only per-tensor input quantizers are " + "supported for tied-modules merging." 
+ ) + return False + merged = torch.max(torch.stack([q.amax for q in valid])) + for q in valid: + q.amax = merged.clone() + return True + + synced = 0 + for key, modules in by_dp.items(): + if len(modules) < 2: + continue + if key[0] == "moe": + for q_name in ("gate_up_proj_input_quantizer", "down_proj_input_quantizer"): + if _merge([getattr(m, q_name, None) for m in modules]): + synced += 1 + elif _merge([m.input_quantizer for m in modules]): + synced += 1 + return synced + + def _process_quantized_modules( model: nn.Module, dtype: torch.dtype, @@ -815,7 +998,11 @@ def _process_quantized_modules( def _export_transformers_checkpoint( - model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs + model: nn.Module, + dtype: torch.dtype | None = None, + is_modelopt_qlora: bool = False, + canonical_tied_naming: bool = False, + **kwargs, ) -> tuple[dict[str, Any], dict[str, Any]]: """Exports the torch model to the packed checkpoint with original HF naming. @@ -935,6 +1122,15 @@ def _export_transformers_checkpoint( f"Taking element-wise max of amaxes for serving-engine fusion." ) + # Merge per-side input_quantizer amaxes BEFORE _process_quantized_modules, + # so the merged value flows into input_scale derivation downstream. + synced_input = sync_tied_input_amax(model) + if synced_input: + print( + f"sync_tied_input_amax: max-merged input_quantizer amaxes across " + f"{synced_input} tied module group(s)" + ) + # Process all quantized modules and export weights _process_quantized_modules(model, dtype, is_modelopt_qlora) @@ -952,6 +1148,16 @@ def _export_transformers_checkpoint( # We define kv cache scale as amax / 448 for both FP8 and NVFP4 KV cache quantization. kv_cache_max_bound = 448 kv_cache_format = quant_config["quantization"]["kv_cache_quant_algo"] + + # Optionally reorder so canonical-side tied keys (per HF's + # _tied_weights_keys) iterate first into postprocess_state_dict's + # first-wins data_ptr dedup. 
Off by default to avoid renaming exported + # keys for models whose downstream consumers expect the legacy + # (registration-order) winner; opt in for models where matching HF's + # own naming convention matters (e.g. DiffusionGemma4 → decoder names). + if canonical_tied_naming: + quantized_state_dict = _reorder_canonical_first(quantized_state_dict, model) + quantized_state_dict = postprocess_state_dict( quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora ) @@ -1289,6 +1495,7 @@ def export_hf_checkpoint( components: list[str] | None = None, extra_state_dict: dict[str, torch.Tensor] | None = None, max_shard_size: int | str = "10GB", + canonical_tied_naming: bool = False, **kwargs, ): """Export quantized HuggingFace model checkpoint (transformers or diffusers). @@ -1308,6 +1515,11 @@ def export_hf_checkpoint( to export. If None, all quantized components are exported. extra_state_dict: Extra state dictionary to add to the exported model. max_shard_size: Maximum size of each safetensors shard file. Defaults to "10GB". + canonical_tied_naming: If True, reorder the state_dict so tied-weight + aliases dedup to the canonical side declared in the model's HF + ``_tied_weights_keys`` (e.g. decoder-side for DiffusionGemma4). + Off by default to avoid renaming exported keys for models whose + downstream consumers expect the legacy (registration-order) winner. **kwargs: Runtime-specific post-processing options forwarded to :func:`_postprocess_safetensors` for diffusion model exports. See its docstring for supported keys. @@ -1330,7 +1542,9 @@ def export_hf_checkpoint( return try: - post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype) + post_state_dict, hf_quant_config = _export_transformers_checkpoint( + model, dtype, canonical_tied_naming=canonical_tied_naming + ) # Only treat the export as quantized when at least one quant_algo field is set. 
# get_quant_config always returns a dict (even for sparsity-only or unmodified models), diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index e9b53897fee..9dd608fee9e 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -1208,7 +1208,10 @@ def create_forward_loop( def model_type_is_enc_dec(model): - enc_dec_model_list = ["t5", "bart", "whisper"] + # Substring match against `model.__class__.__name__.lower()` — entries are + # the lowercased class-name form (no underscores). Calibration then uses + # `model.generate` to run the full denoising loop. + enc_dec_model_list = ["t5", "bart", "whisper", "diffusiongemma"] return any(model_name in model.__class__.__name__.lower() for model_name in enc_dec_model_list) diff --git a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml index e2efcb5142d..776ceeb9c72 100644 --- a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml +++ b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml @@ -73,3 +73,9 @@ - parent_class: 'nn.Embedding' quantizer_name: '*' enable: false + # Diffusion self-conditioning network: text-only and not exercised by + # typical calibration; without exclusion its TensorQuantizers never see + # input and export crashes with "AttributeError: '...' has no attribute + # '_amax'". Companion to the vision excludes above. + - quantizer_name: '*self_conditioning*' + enable: false diff --git a/tests/_test_utils/torch/quantization/tied_modules.py b/tests/_test_utils/torch/quantization/tied_modules.py new file mode 100644 index 00000000000..8ea76d2d459 --- /dev/null +++ b/tests/_test_utils/torch/quantization/tied_modules.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Factories for tied-weight test scenarios. + +These build small synthetic modules whose ``.weight`` :class:`nn.Parameter` is +shared between two sibling modules — mimicking HuggingFace's +``_tied_weights_keys`` machinery — for unit-testing the export-time dedup, +canonical-side naming, and per-side ``input_quantizer.amax`` merge logic in +the HF export path. + +Every factory returns CPU-resident, float32-default modules; no GPU required. +Each factory asserts its own post-conditions before returning, so a broken +tie surfaces as a clear factory-side error rather than as a downstream test +failure with an ambiguous cause. +""" + +import re + +import torch.nn as nn + + +def make_tied_linear_pair( + in_features: int = 16, + out_features: int = 32, + bias: bool = False, +) -> tuple[nn.Linear, nn.Linear]: + """Two :class:`nn.Linear` modules whose ``.weight`` Parameter is shared. + + Mimics what HuggingFace's :meth:`PreTrainedModel.tie_weights` does after + ``__init__``: one extra ``setattr`` so that both modules' ``.weight`` + attributes resolve to the same :class:`nn.Parameter` and therefore the + same underlying storage. The modules are otherwise independent — separate + biases (if requested), separate forward/training state, separate + quantizer slots when ``mtq.quantize`` inserts them later. 
+ """ + enc = nn.Linear(in_features, out_features, bias=bias) + dec = nn.Linear(in_features, out_features, bias=bias) + dec.weight = enc.weight # mimics HF tie_weights() + + # Post-conditions — fail loudly if the tie was somehow lost. + assert enc.weight is dec.weight, "Linear weights not tied (object identity)" + assert enc.weight.data_ptr() == dec.weight.data_ptr(), ( + "Linear weights tied at object level but storage diverged" + ) + return enc, dec + + +def tie_fused_experts_3d_params(enc: nn.Module, dec: nn.Module) -> None: + """Tie ``gate_up_proj`` and ``down_proj`` between two fused-experts modules. + + Mutates ``dec`` in place. After calling, ``dec.gate_up_proj`` IS + ``enc.gate_up_proj`` (same :class:`nn.Parameter`) and likewise for + ``down_proj``. Used by MoE-dedup tests together with the + ``_SyntheticFusedExperts`` fixture defined in + ``tests/unit/torch/quantization/plugins/test_fused_experts.py``. + """ + dec.gate_up_proj = enc.gate_up_proj + dec.down_proj = enc.down_proj + + assert enc.gate_up_proj is dec.gate_up_proj, "gate_up_proj not tied" + assert enc.down_proj is dec.down_proj, "down_proj not tied" + assert enc.gate_up_proj.data_ptr() == dec.gate_up_proj.data_ptr() + assert enc.down_proj.data_ptr() == dec.down_proj.data_ptr() + + +def wrap_in_parent_with_tied_keys( + enc: nn.Module, + dec: nn.Module, + *, + decoder_canonical: bool = True, + weight_attr: str = "weight", +) -> nn.Module: + """Wrap two tied modules in a parent that declares HF ``_tied_weights_keys``. + + Returns a parent :class:`nn.Module` with: + + - ``parent.encoder = enc`` — registered as a submodule (alias side). + - ``parent.decoder = dec`` — registered as a submodule (canonical side + when ``decoder_canonical=True``, the default and DiffusionGemma-like case). + - ``parent._tied_weights_keys``: dict-style ``{alias_regex: canonical}`` + when ``decoder_canonical=True``, list-style (legacy, no canonical/alias + distinction) when ``decoder_canonical=False``. 
+ + Used by tests for :func:`_collect_canonical_tied_patterns` and + :func:`_reorder_canonical_first`. The legacy list-style branch exercises + the "no patterns extracted" negative case. + """ + parent = nn.Module() + parent.encoder = enc + parent.decoder = dec + + if decoder_canonical: + # Dict-style: regex pattern → canonical path. Mimics HF's per-class + # ``_tied_weights_keys`` declaration for an encoder/decoder model. + parent._tied_weights_keys = { + rf"^encoder\.{re.escape(weight_attr)}$": f"decoder.{weight_attr}", + } + else: + # Legacy list-style: just a list of tied paths, no canonical info. + parent._tied_weights_keys = [f"encoder.{weight_attr}"] + + return parent diff --git a/tests/unit/torch/export/test_unified_export_hf.py b/tests/unit/torch/export/test_unified_export_hf.py new file mode 100644 index 00000000000..1d08a8ef620 --- /dev/null +++ b/tests/unit/torch/export/test_unified_export_hf.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for tied-weight helpers in unified_export_hf.""" + +from collections import OrderedDict + +import torch +from _test_utils.torch.quantization.tied_modules import ( + make_tied_linear_pair, + wrap_in_parent_with_tied_keys, +) + +import modelopt.torch.quantization as mtq +from modelopt.torch.export.unified_export_hf import ( + _collect_canonical_tied_patterns, + _export_quantized_weight, + _reorder_canonical_first, + sync_tied_input_amax, +) + + +def test_collect_canonical_tied_patterns_dict_style(): + """Dict-style _tied_weights_keys yields regex patterns + canonical-side substrings.""" + enc, dec = make_tied_linear_pair() + parent = wrap_in_parent_with_tied_keys(enc, dec, decoder_canonical=True) + + patterns, side_substrings = _collect_canonical_tied_patterns(parent) + + assert len(patterns) >= 1 + # "decoder" is in the canonical RHS but not the alias LHS — must auto-derive. + # "encoder" is alias-only and must NOT be returned as canonical (would invert dedup). + assert "decoder" in side_substrings + assert "encoder" not in side_substrings + + +def test_collect_canonical_tied_patterns_list_style_yields_no_canonical_info(): + """Legacy list-style _tied_weights_keys carries no canonical/alias info — returns empty.""" + enc, dec = make_tied_linear_pair() + parent = wrap_in_parent_with_tied_keys(enc, dec, decoder_canonical=False) + + patterns, side_substrings = _collect_canonical_tied_patterns(parent) + + assert patterns == [] + assert side_substrings == [] + + +def test_reorder_canonical_first_puts_decoder_keys_before_encoder_keys(): + """_reorder_canonical_first moves canonical-side state_dict keys ahead of alias-side keys.""" + enc, dec = make_tied_linear_pair() + parent = wrap_in_parent_with_tied_keys(enc, dec, decoder_canonical=True) + + sd = OrderedDict( + [ + ("encoder.weight", torch.zeros(1)), + ("unrelated.foo", torch.zeros(1)), + ("decoder.weight", torch.zeros(1)), + ] + ) + + reordered = _reorder_canonical_first(sd, parent) + keys = 
list(reordered.keys()) + + assert keys.index("decoder.weight") < keys.index("encoder.weight") + assert set(reordered) == set(sd) # no drops or additions + + +def _quantize_and_get_input_quantizers(parent): + """Insert FP8 quantizers via no-op forward_loop and return both input_quantizers.""" + mtq.quantize(parent, mtq.FP8_DEFAULT_CFG, forward_loop=lambda m: None) + return parent.encoder.input_quantizer, parent.decoder.input_quantizer + + +def test_sync_tied_input_amax_max_merges_tied_module_amaxes_in_place(): + """Tied Linears with divergent input_quantizer.amax get both sides overwritten with the max.""" + enc, dec = make_tied_linear_pair() + parent = wrap_in_parent_with_tied_keys(enc, dec, decoder_canonical=True) + enc_q, dec_q = _quantize_and_get_input_quantizers(parent) + + enc_q.amax = torch.tensor(2.0) + dec_q.amax = torch.tensor(5.0) + + sync_tied_input_amax(parent) + + expected = torch.tensor(5.0) + assert torch.allclose(enc_q.amax, expected) + assert torch.allclose(dec_q.amax, expected) + + +def test_sync_tied_input_amax_no_op_for_untied_modules(): + """Untied Linears keep their per-side amaxes — the helper is a no-op when there's no tie.""" + parent = torch.nn.Module() + parent.encoder = torch.nn.Linear(16, 32, bias=False) + parent.decoder = torch.nn.Linear(16, 32, bias=False) + enc_q, dec_q = _quantize_and_get_input_quantizers(parent) + + enc_q.amax = torch.tensor(2.0) + dec_q.amax = torch.tensor(5.0) + + sync_tied_input_amax(parent) + + assert torch.allclose(enc_q.amax, torch.tensor(2.0)) + assert torch.allclose(dec_q.amax, torch.tensor(5.0)) + + +def _clear_export_quantized_weight_cache(): + """Clear the function-static alias cache; isolates each test from prior session state.""" + _export_quantized_weight.__dict__.pop("_tied_weight_alias_cache", None) + + +def _calibrate_through_both_children(parent): + """Insert NVFP4 quantizers and run a one-shot forward through both children for calibration.""" + + def forward_loop(m): + x = torch.randn(2, 16) + 
m.encoder(x) + m.decoder(x) + + mtq.quantize(parent, mtq.NVFP4_DEFAULT_CFG, forward_loop=forward_loop) + + +def test_export_quantized_weight_aliases_packed_weight_for_tied_linears(): + """Tied Linears share data_ptr for packed .weight and scale buffers after export.""" + _clear_export_quantized_weight_cache() + + enc, dec = make_tied_linear_pair() + parent = wrap_in_parent_with_tied_keys(enc, dec) + _calibrate_through_both_children(parent) + + _export_quantized_weight(enc, torch.float16, "weight") + _export_quantized_weight(dec, torch.float16, "weight") + + assert enc.weight.data_ptr() == dec.weight.data_ptr() + for scale_attr in ("weight_scale", "weight_scale_2"): + if hasattr(enc, scale_attr) and hasattr(dec, scale_attr): + assert getattr(enc, scale_attr).data_ptr() == getattr(dec, scale_attr).data_ptr() + + +def test_export_quantized_weight_no_alias_for_untied_linears(): + """Untied Linears keep independent data_ptrs after export — no false-positive aliasing.""" + _clear_export_quantized_weight_cache() + + parent = torch.nn.Module() + parent.encoder = torch.nn.Linear(16, 32, bias=False) + parent.decoder = torch.nn.Linear(16, 32, bias=False) + assert parent.encoder.weight.data_ptr() != parent.decoder.weight.data_ptr() + _calibrate_through_both_children(parent) + + _export_quantized_weight(parent.encoder, torch.float16, "weight") + _export_quantized_weight(parent.decoder, torch.float16, "weight") + + assert parent.encoder.weight.data_ptr() != parent.decoder.weight.data_ptr() + + +def test_export_quantized_weight_skips_alias_when_one_tied_side_is_unquantized(): + """Unquantized side early-returns; its .weight stays at the original shared Parameter.""" + _clear_export_quantized_weight_cache() + + enc, dec = make_tied_linear_pair() + parent = wrap_in_parent_with_tied_keys(enc, dec) + original_shared_data_ptr = enc.weight.data_ptr() + + _calibrate_through_both_children(parent) + # is_enabled is a read-only property; .disable() is the canonical bypass. 
+ dec.weight_quantizer.disable() + + _export_quantized_weight(enc, torch.float16, "weight") + _export_quantized_weight(dec, torch.float16, "weight") + + assert enc.weight.data_ptr() != original_shared_data_ptr # encoder got fresh packed + assert dec.weight.data_ptr() == original_shared_data_ptr # decoder untouched + assert enc.weight.data_ptr() != dec.weight.data_ptr() diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index ce23f7a51d5..ef2df36090e 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -22,6 +22,8 @@ pytest.importorskip("transformers") +from _test_utils.torch.quantization.tied_modules import tie_fused_experts_3d_params + import modelopt.torch.quantization as mtq from modelopt.torch.export.moe_utils import _export_fused_experts from modelopt.torch.export.quant_utils import get_quant_config @@ -514,6 +516,115 @@ def _spy_export(wrapper, dtype): QuantModuleRegistry.unregister(expert_type) +# --------------------------------------------------------------------------- +# Tests for tied-experts dedup in _export_fused_experts +# --------------------------------------------------------------------------- +def _build_two_moe_blocks(tie: bool) -> nn.Module: + """Build a parent with two _SyntheticSparseMoeBlock children, optionally with tied 3-D params.""" + parent = nn.Module() + parent.encoder = _SyntheticSparseMoeBlock() + parent.decoder = _SyntheticSparseMoeBlock() + if tie: + tie_fused_experts_3d_params(parent.encoder.experts, parent.decoder.experts) + return parent + + +def _moe_fp8_quant_cfg(): + """Custom inline FP8 cfg targeting the MoE-specific quantizer names.""" + return { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*gate_up_proj_input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + }, + {"quantizer_name": 
"*down_proj_input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + {"quantizer_name": "*gate_up_proj_weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_name": "*down_proj_weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + ], + "algorithm": "max", + } + + +def _calibrate_two_moe_blocks(parent): + """Fire one calibration batch through both encoder.experts and decoder.experts.""" + + def forward_loop(m): + torch.manual_seed(0) + x = torch.randn(1, 4, HIDDEN_DIM) + m.encoder(x) + m.decoder(x) + + mtq.quantize(parent, _moe_fp8_quant_cfg(), forward_loop=forward_loop) + + +def _clear_fused_experts_caches(): + """Clear function-static alias caches in both export entry points.""" + _export_fused_experts.__dict__.pop("_tied_unpacked_cache", None) + # _export_fused_experts internally calls _export_quantized_weight per per-expert + # wrapper; clear that cache too so each test sees a pristine state. + from modelopt.torch.export.unified_export_hf import _export_quantized_weight + + _export_quantized_weight.__dict__.pop("_tied_weight_alias_cache", None) + + +class TestExportFusedExpertsTiedDedup: + @staticmethod + def _cleanup_registry(mod_type): + if QuantModuleRegistry.get(mod_type) is not None: + QuantModuleRegistry.unregister(mod_type) + + def test_per_expert_buffers_share_data_ptr_for_tied_fused_experts(self): + """Two tied FusedExperts modules: every per-expert .weight + scale buffer shares data_ptr.""" + _clear_fused_experts_caches() + parent = _build_two_moe_blocks(tie=True) + expert_type = type(parent.encoder.experts) + self._cleanup_registry(expert_type) + try: + _calibrate_two_moe_blocks(parent) + + _export_fused_experts(parent.encoder.experts, torch.float16) + _export_fused_experts(parent.decoder.experts, torch.float16) + + for idx in range(NUM_EXPERTS): + enc_expert = getattr(parent.encoder.experts, str(idx)) + dec_expert = getattr(parent.decoder.experts, str(idx)) + for proj_name in ("gate_proj", "up_proj", "down_proj"): + enc_proj = 
getattr(enc_expert, proj_name) + dec_proj = getattr(dec_expert, proj_name) + assert enc_proj.weight.data_ptr() == dec_proj.weight.data_ptr() + for scale_attr in ("weight_scale", "weight_scale_2"): + if hasattr(enc_proj, scale_attr) and hasattr(dec_proj, scale_attr): + assert ( + getattr(enc_proj, scale_attr).data_ptr() + == getattr(dec_proj, scale_attr).data_ptr() + ) + finally: + self._cleanup_registry(expert_type) + + def test_per_expert_buffers_have_independent_data_ptrs_for_untied_fused_experts(self): + """Two untied FusedExperts modules: per-expert buffers stay independent (no false-positive alias).""" + _clear_fused_experts_caches() + parent = _build_two_moe_blocks(tie=False) + expert_type = type(parent.encoder.experts) + self._cleanup_registry(expert_type) + try: + _calibrate_two_moe_blocks(parent) + + _export_fused_experts(parent.encoder.experts, torch.float16) + _export_fused_experts(parent.decoder.experts, torch.float16) + + for idx in range(NUM_EXPERTS): + enc_expert = getattr(parent.encoder.experts, str(idx)) + dec_expert = getattr(parent.decoder.experts, str(idx)) + for proj_name in ("gate_proj", "up_proj", "down_proj"): + enc_proj = getattr(enc_expert, proj_name) + dec_proj = getattr(dec_expert, proj_name) + assert enc_proj.weight.data_ptr() != dec_proj.weight.data_ptr() + finally: + self._cleanup_registry(expert_type) + + # --------------------------------------------------------------------------- # Tests for force_eager_experts_impl_on_the_fly # ---------------------------------------------------------------------------