From 43a9abca6624f6d3e4080867179ff4ba681c0a66 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Thu, 11 Jun 2026 15:18:47 -0700 Subject: [PATCH] [6289151] Fix exported Step layer type metadata Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/export/plugins/__init__.py | 14 +++- .../export/plugins/hf_checkpoint_utils.py | 78 +++++++++++++++++++ modelopt/torch/export/unified_export_hf.py | 4 +- .../torch/export/test_hf_checkpoint_utils.py | 78 ++++++++++++++++++- 4 files changed, 169 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/export/plugins/__init__.py b/modelopt/torch/export/plugins/__init__.py index 99a755598ac..000681b8592 100644 --- a/modelopt/torch/export/plugins/__init__.py +++ b/modelopt/torch/export/plugins/__init__.py @@ -21,10 +21,18 @@ from .megatron_importer import * from .hf_spec_export import * + +with import_plugin("hf_checkpoint_utils"): + from .hf_checkpoint_utils import * + +if "sanitize_hf_config_for_deployment" not in globals(): + + def sanitize_hf_config_for_deployment(config_data, model): + """No-op fallback when Hugging Face checkpoint utilities are unavailable.""" + return None + + from .vllm_fakequant_hf import * with import_plugin("vllm_fakequant_megatron"): from .vllm_fakequant_megatron import * - -with import_plugin("hf_checkpoint_utils"): - from .hf_checkpoint_utils import * diff --git a/modelopt/torch/export/plugins/hf_checkpoint_utils.py b/modelopt/torch/export/plugins/hf_checkpoint_utils.py index e8c7bb945c7..977acb8b150 100644 --- a/modelopt/torch/export/plugins/hf_checkpoint_utils.py +++ b/modelopt/torch/export/plugins/hf_checkpoint_utils.py @@ -18,7 +18,9 @@ import json import os import shutil +import warnings from pathlib import Path +from typing import Any import torch from huggingface_hub import snapshot_download @@ -29,6 +31,82 @@ _HF_HUB_OFFLINE_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +def _as_nonnegative_int(value: Any) -> int | None: + """Return ``value`` as an int when it is a non-negative integer.""" + if isinstance(value, bool): + return None + if isinstance(value, int) and value >= 0: + return value + return None + + +def _count_mtp_layer_prefixes(prefixes: list[Any] | tuple[Any, ...]) -> int | None: + """Count actual MTP layer prefixes, excluding broad prefixes like ``mtp``.""" + layer_prefixes = { + prefix + for prefix in prefixes + if isinstance(prefix, str) + and (parts := prefix.split(".")) + and len(parts) >= 2 + and parts[-2] == "layers" + and parts[-1].isdigit() + } + return len(layer_prefixes) or None + + +def _get_num_nextn_predict_layers(config_data: dict[str, Any], model: Any) -> int | None: + """Get the number of next-token-prediction layers from config metadata.""" + num_nextn_predict_layers = _as_nonnegative_int(config_data.get("num_nextn_predict_layers")) + if num_nextn_predict_layers is not None: + return num_nextn_predict_layers + + model_config = getattr(model, "config", None) + if model_config is not None: + num_nextn_predict_layers = _as_nonnegative_int( + getattr(model_config, "num_nextn_predict_layers", None) + ) + if num_nextn_predict_layers is not None: + return num_nextn_predict_layers + + mtp_layer_prefixes = getattr(model, "_mtp_layer_prefixes", None) + if isinstance(mtp_layer_prefixes, (list, tuple)): + return _count_mtp_layer_prefixes(mtp_layer_prefixes) + + return None + + +def sanitize_hf_config_for_deployment(config_data: dict[str, Any], model: Any) -> None: + """Sanitize exported Hugging Face config metadata for deployment runtimes. + + Transformers 5.x validates that ``len(layer_types) == num_hidden_layers``. + Some checkpoints include MTP/next-token-prediction layer entries after the + main decoder layer entries. Those extra entries are not part of + ``num_hidden_layers`` and make deployment stacks fail while loading config. + """ + num_hidden_layers = _as_nonnegative_int(config_data.get("num_hidden_layers")) + layer_types = config_data.get("layer_types") + if num_hidden_layers is None or not isinstance(layer_types, list): + return + + num_layer_types = len(layer_types) + if num_layer_types == num_hidden_layers: + return + + num_nextn_predict_layers = _get_num_nextn_predict_layers(config_data, model) + if ( + num_layer_types > num_hidden_layers + and num_nextn_predict_layers == num_layer_types - num_hidden_layers + ): + warnings.warn( + "Trimming config.layer_types from " + f"{num_layer_types} to {num_hidden_layers} entries so it matches " + "num_hidden_layers; the removed entries correspond to " + "num_nextn_predict_layers.", + stacklevel=2, + ) + config_data["layer_types"] = layer_types[:num_hidden_layers] + + def _is_hf_hub_offline() -> bool: return os.environ.get("HF_HUB_OFFLINE", "").strip().upper() in _HF_HUB_OFFLINE_TRUE_VALUES diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index ef5757aa0cb..89b06338f00 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -90,7 +90,7 @@ ) from .model_utils import get_language_model_from_vl, is_multimodal_model from .moe_utils import _export_fused_experts -from .plugins import SpeculativeDecodingExporter, has_spec_opt +from .plugins import SpeculativeDecodingExporter, has_spec_opt, sanitize_hf_config_for_deployment from .quant_utils import ( fuse_prequant_layernorm, fuse_prequant_to_linear, @@ -1381,6 +1381,8 @@ def export_hf_checkpoint( with open(original_config) as file: config_data = json.load(file) + sanitize_hf_config_for_deployment(config_data, model) + if hf_quant_config is not None: config_data["quantization_config"] = hf_quant_config diff --git a/tests/unit/torch/export/test_hf_checkpoint_utils.py b/tests/unit/torch/export/test_hf_checkpoint_utils.py index 33d17eebb3d..2e571714560 100644 --- a/tests/unit/torch/export/test_hf_checkpoint_utils.py +++ b/tests/unit/torch/export/test_hf_checkpoint_utils.py @@ -15,6 +15,7 @@ """Tests for modelopt/torch/export/plugins/hf_checkpoint_utils.py""" +from types import SimpleNamespace from unittest.mock import patch import pytest @@ -23,7 +24,7 @@ hf_hub_errors = pytest.importorskip("huggingface_hub.errors") LocalEntryNotFoundError = hf_hub_errors.LocalEntryNotFoundError -from modelopt.torch.export import copy_hf_ckpt_remote_code +from modelopt.torch.export import copy_hf_ckpt_remote_code, sanitize_hf_config_for_deployment def test_copy_hf_ckpt_remote_code_local_dir(tmp_path): @@ -118,3 +119,78 @@ def test_copy_hf_ckpt_remote_code_hub_id_offline_missing_cache_raises(tmp_path, pytest.raises(RuntimeError, match="HF_HUB_OFFLINE"), ): copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", tmp_path / "dst") + + +def test_sanitize_hf_config_for_deployment_trims_nextn_layer_types(): + """Drop MTP/next-token-prediction layer types from exported config.json.""" + hidden_layer_types = ["full_attention"] * 45 + nextn_layer_types = ["nextn_predict"] * 3 + config_data = { + "num_hidden_layers": 45, + "num_nextn_predict_layers": 3, + "layer_types": hidden_layer_types + nextn_layer_types, + } + + with pytest.warns(UserWarning, match="Trimming config.layer_types"): + sanitize_hf_config_for_deployment(config_data, model=SimpleNamespace()) + + assert config_data["layer_types"] == hidden_layer_types + + +def test_sanitize_hf_config_for_deployment_uses_model_config_nextn_count(): + """Handle exports where save_pretrained omits num_nextn_predict_layers.""" + config_data = { + "num_hidden_layers": 2, + "layer_types": ["full_attention", "linear_attention", "nextn_predict"], + } + model = SimpleNamespace(config=SimpleNamespace(num_nextn_predict_layers=1)) + + with pytest.warns(UserWarning, match="Trimming config.layer_types"): + sanitize_hf_config_for_deployment(config_data, model=model) + + assert config_data["layer_types"] == ["full_attention", "linear_attention"] + + +def test_sanitize_hf_config_for_deployment_counts_mtp_layer_prefixes(): + """Do not count broad MTP exclude prefixes as prediction layers.""" + config_data = { + "num_hidden_layers": 2, + "layer_types": ["full_attention", "linear_attention", "nextn_predict"], + } + model = SimpleNamespace(_mtp_layer_prefixes=["mtp", "mtp.layers.0"]) + + with pytest.warns(UserWarning, match="Trimming config.layer_types"): + sanitize_hf_config_for_deployment(config_data, model=model) + + assert config_data["layer_types"] == ["full_attention", "linear_attention"] + + +def test_sanitize_hf_config_for_deployment_ignores_broad_mtp_prefix_only(): + """Do not infer prediction-layer count from a broad exclude prefix alone.""" + config_data = { + "num_hidden_layers": 2, + "layer_types": ["full_attention", "linear_attention", "nextn_predict"], + } + model = SimpleNamespace(_mtp_layer_prefixes=["mtp"]) + + sanitize_hf_config_for_deployment(config_data, model=model) + + assert config_data["layer_types"] == ["full_attention", "linear_attention", "nextn_predict"] + + +def test_sanitize_hf_config_for_deployment_keeps_unexplained_layer_type_mismatch(): + """Do not rewrite config when extra layer types are not explained by nextn metadata.""" + config_data = { + "num_hidden_layers": 2, + "num_nextn_predict_layers": 1, + "layer_types": ["full_attention", "linear_attention", "extra_a", "extra_b"], + } + + sanitize_hf_config_for_deployment(config_data, model=SimpleNamespace()) + + assert config_data["layer_types"] == [ + "full_attention", + "linear_attention", + "extra_a", + "extra_b", + ]