Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions modelopt/torch/export/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,18 @@
from .megatron_importer import *

from .hf_spec_export import *

with import_plugin("hf_checkpoint_utils"):
from .hf_checkpoint_utils import *

if "sanitize_hf_config_for_deployment" not in globals():
    # The guarded star-import above did not bind the real implementation, so
    # provide a stand-in with the same call signature to keep importers working.

    def sanitize_hf_config_for_deployment(config_data, model):
        """Do nothing; stand-in used when the Hugging Face checkpoint utilities did not load.

        Args:
            config_data: Exported config dictionary; left untouched.
            model: Model being exported; ignored.

        Returns:
            Always ``None``.
        """
        return


from .vllm_fakequant_hf import *

with import_plugin("vllm_fakequant_megatron"):
from .vllm_fakequant_megatron import *

with import_plugin("hf_checkpoint_utils"):
from .hf_checkpoint_utils import *
78 changes: 78 additions & 0 deletions modelopt/torch/export/plugins/hf_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import json
import os
import shutil
import warnings
from pathlib import Path
from typing import Any

import torch
from huggingface_hub import snapshot_download
Expand All @@ -29,6 +31,82 @@
_HF_HUB_OFFLINE_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}


def _as_nonnegative_int(value: Any) -> int | None:
"""Return ``value`` as an int when it is a non-negative integer."""
if isinstance(value, bool):
return None
if isinstance(value, int) and value >= 0:
return value
return None


def _count_mtp_layer_prefixes(prefixes: list[Any] | tuple[Any, ...]) -> int | None:
"""Count actual MTP layer prefixes, excluding broad prefixes like ``mtp``."""
layer_prefixes = {
prefix
for prefix in prefixes
if isinstance(prefix, str)
and (parts := prefix.split("."))
and len(parts) >= 2
and parts[-2] == "layers"
and parts[-1].isdigit()
}
return len(layer_prefixes) or None


def _get_num_nextn_predict_layers(config_data: dict[str, Any], model: Any) -> int | None:
    """Resolve the next-token-prediction layer count from the available metadata.

    Sources are consulted in order and the first usable one wins:

    1. ``config_data["num_nextn_predict_layers"]``
    2. ``model.config.num_nextn_predict_layers``
    3. the number of per-layer entries in ``model._mtp_layer_prefixes``
       (only when that attribute is a list or tuple)

    Args:
        config_data: Exported config dictionary.
        model: Model object; every attribute read from it is optional.

    Returns:
        The resolved count, or ``None`` when no source provides one.
    """
    from_export = _as_nonnegative_int(config_data.get("num_nextn_predict_layers"))
    if from_export is not None:
        return from_export

    hf_config = getattr(model, "config", None)
    from_model = (
        None
        if hf_config is None
        else _as_nonnegative_int(getattr(hf_config, "num_nextn_predict_layers", None))
    )
    if from_model is not None:
        return from_model

    # Last resort: infer the count from the recorded MTP layer prefixes.
    recorded_prefixes = getattr(model, "_mtp_layer_prefixes", None)
    if not isinstance(recorded_prefixes, (list, tuple)):
        return None
    return _count_mtp_layer_prefixes(recorded_prefixes)


def sanitize_hf_config_for_deployment(config_data: dict[str, Any], model: Any) -> None:
    """Trim surplus ``layer_types`` entries so the exported config loads in deployment.

    Transformers 5.x validates that ``len(layer_types) == num_hidden_layers``.
    Some checkpoints append MTP/next-token-prediction entries after the main
    decoder entries; those are not counted in ``num_hidden_layers`` and make
    deployment stacks fail while loading the config.

    The list is shortened only when the surplus equals the resolved
    next-token-prediction layer count. Any other mismatch is left untouched.

    Args:
        config_data: Exported config dictionary; ``layer_types`` may be replaced in place.
        model: Model object used as a fallback source for the prediction-layer count.

    Warns:
        UserWarning: When ``layer_types`` is trimmed.
    """
    num_hidden_layers = _as_nonnegative_int(config_data.get("num_hidden_layers"))
    layer_types = config_data.get("layer_types")
    if num_hidden_layers is None:
        return
    if not isinstance(layer_types, list):
        return

    num_layer_types = len(layer_types)
    if num_layer_types == num_hidden_layers:
        return

    nextn_count = _get_num_nextn_predict_layers(config_data, model)
    surplus = num_layer_types - num_hidden_layers
    # Only a surplus fully explained by the prediction-layer count is safe to drop.
    if surplus <= 0 or nextn_count != surplus:
        return

    warnings.warn(
        "Trimming config.layer_types from "
        f"{num_layer_types} to {num_hidden_layers} entries so it matches "
        "num_hidden_layers; the removed entries correspond to "
        "num_nextn_predict_layers.",
        stacklevel=2,
    )
    config_data["layer_types"] = layer_types[:num_hidden_layers]


def _is_hf_hub_offline() -> bool:
    """Tell whether the ``HF_HUB_OFFLINE`` environment variable holds an enabling value.

    The value is stripped of surrounding whitespace and upper-cased before it is
    looked up in ``_HF_HUB_OFFLINE_TRUE_VALUES``; an unset variable counts as empty.
    """
    raw_flag = os.environ.get("HF_HUB_OFFLINE", "")
    normalized_flag = raw_flag.strip().upper()
    return normalized_flag in _HF_HUB_OFFLINE_TRUE_VALUES

Expand Down
4 changes: 3 additions & 1 deletion modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
)
from .model_utils import get_language_model_from_vl, is_multimodal_model
from .moe_utils import _export_fused_experts
from .plugins import SpeculativeDecodingExporter, has_spec_opt
from .plugins import SpeculativeDecodingExporter, has_spec_opt, sanitize_hf_config_for_deployment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hard import of a symbol from an optionally-guarded plugin.

This now hard-imports sanitize_hf_config_for_deployment, which lives in hf_checkpoint_utils. In plugins/__init__.py that module is imported inside import_plugin(...) — a guard whose whole purpose is to tolerate the module failing to import. By contrast, SpeculativeDecodingExporter / has_spec_opt come from hf_spec_export, which is imported unguarded.

So if hf_checkpoint_utils ever fails under the guard, the symbol silently won't exist and this line raises ImportError, breaking all HF export — not just the Step path. In practice the deps (huggingface_hub, safetensors, tqdm) are always present for export, so the risk is low, but the import_plugin guard is partially defeated by taking a hard dependency on one of its symbols.

Note the import reorder in plugins/__init__.py doesn't close this gap — it only changes binding order; it doesn't make the symbol exist if the guarded import itself fails. Suggested fixes:

  • guard this import defensively (try/except ImportError → no-op fallback), or
  • if hf_checkpoint_utils is genuinely required for export, drop the import_plugin guard around it so failures are loud rather than silently amputating a symbol.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in d83dcb8. I removed the import_plugin guard around hf_checkpoint_utils in plugins/__init__.py, so sanitize_hf_config_for_deployment is no longer a hard import from an optionally amputated plugin symbol. This matches the existing direct hf_checkpoint_utils imports in export code: failures are now loud instead of silently dropping the symbol. Validated with ruff-format, ruff-check, py_compile, and the focused hf_checkpoint_utils unit tests.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in 43a9abc. I changed the resolution from making hf_checkpoint_utils unguarded to preserving the import_plugin guard and adding a no-op fallback for sanitize_hf_config_for_deployment when the guarded import does not provide it. That keeps modelopt.torch.export import paths resilient to missing optional deps while keeping the direct unified_export_hf import stable. Validated with ruff-format, ruff-check, py_compile, and the focused hf_checkpoint_utils unit tests.

from .quant_utils import (
fuse_prequant_layernorm,
fuse_prequant_to_linear,
Expand Down Expand Up @@ -1381,6 +1381,8 @@ def export_hf_checkpoint(
with open(original_config) as file:
config_data = json.load(file)

sanitize_hf_config_for_deployment(config_data, model)

if hf_quant_config is not None:
config_data["quantization_config"] = hf_quant_config

Expand Down
78 changes: 77 additions & 1 deletion tests/unit/torch/export/test_hf_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""Tests for modelopt/torch/export/plugins/hf_checkpoint_utils.py"""

from types import SimpleNamespace
from unittest.mock import patch

import pytest
Expand All @@ -23,7 +24,7 @@
hf_hub_errors = pytest.importorskip("huggingface_hub.errors")
LocalEntryNotFoundError = hf_hub_errors.LocalEntryNotFoundError

from modelopt.torch.export import copy_hf_ckpt_remote_code
from modelopt.torch.export import copy_hf_ckpt_remote_code, sanitize_hf_config_for_deployment


def test_copy_hf_ckpt_remote_code_local_dir(tmp_path):
Expand Down Expand Up @@ -118,3 +119,78 @@ def test_copy_hf_ckpt_remote_code_hub_id_offline_missing_cache_raises(tmp_path,
pytest.raises(RuntimeError, match="HF_HUB_OFFLINE"),
):
copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", tmp_path / "dst")


def test_sanitize_hf_config_for_deployment_trims_nextn_layer_types():
    """Trailing next-token-prediction entries are removed from ``layer_types``."""
    decoder_layer_types = ["full_attention" for _ in range(45)]
    config_data = {
        "num_hidden_layers": 45,
        "num_nextn_predict_layers": 3,
        "layer_types": [*decoder_layer_types, "nextn_predict", "nextn_predict", "nextn_predict"],
    }

    # The count comes from config_data itself, so an attribute-less model suffices.
    with pytest.warns(UserWarning, match="Trimming config.layer_types"):
        sanitize_hf_config_for_deployment(config_data, model=SimpleNamespace())

    assert config_data["layer_types"] == decoder_layer_types


def test_sanitize_hf_config_for_deployment_uses_model_config_nextn_count():
    """Fall back to ``model.config`` when the exported config lacks the nextn count."""
    kept_layer_types = ["full_attention", "linear_attention"]
    config_data = {
        "num_hidden_layers": 2,
        "layer_types": [*kept_layer_types, "nextn_predict"],
    }
    model_config = SimpleNamespace(num_nextn_predict_layers=1)
    model = SimpleNamespace(config=model_config)

    with pytest.warns(UserWarning, match="Trimming config.layer_types"):
        sanitize_hf_config_for_deployment(config_data, model=model)

    assert config_data["layer_types"] == kept_layer_types


def test_sanitize_hf_config_for_deployment_counts_mtp_layer_prefixes():
    """Only per-layer MTP prefixes count; the broad ``mtp`` prefix is ignored."""
    kept_layer_types = ["full_attention", "linear_attention"]
    config_data = {
        "num_hidden_layers": 2,
        "layer_types": [*kept_layer_types, "nextn_predict"],
    }
    # One broad prefix plus one per-layer prefix yields a count of exactly one.
    model = SimpleNamespace(_mtp_layer_prefixes=["mtp", "mtp.layers.0"])

    with pytest.warns(UserWarning, match="Trimming config.layer_types"):
        sanitize_hf_config_for_deployment(config_data, model=model)

    assert config_data["layer_types"] == kept_layer_types


def test_sanitize_hf_config_for_deployment_ignores_broad_mtp_prefix_only():
    """A lone broad ``mtp`` prefix gives no layer count, so nothing is trimmed."""
    untouched_layer_types = ["full_attention", "linear_attention", "nextn_predict"]
    config_data = {
        "num_hidden_layers": 2,
        "layer_types": list(untouched_layer_types),
    }
    model = SimpleNamespace(_mtp_layer_prefixes=["mtp"])

    sanitize_hf_config_for_deployment(config_data, model=model)

    assert config_data["layer_types"] == untouched_layer_types


def test_sanitize_hf_config_for_deployment_keeps_unexplained_layer_type_mismatch():
    """Leave ``layer_types`` alone when the surplus does not equal the nextn count."""
    untouched_layer_types = ["full_attention", "linear_attention", "extra_a", "extra_b"]
    config_data = {
        "num_hidden_layers": 2,
        "num_nextn_predict_layers": 1,
        # Two surplus entries against a declared nextn count of one.
        "layer_types": list(untouched_layer_types),
    }

    sanitize_hf_config_for_deployment(config_data, model=SimpleNamespace())

    assert config_data["layer_types"] == untouched_layer_types
Loading