Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion tests/models/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
* Architecture-hint fields on ``cls.fast_llm_config_class`` are all consumed by some declaration.
* OptionalConfigConverter sentinels match the resolved field default. Otherwise an exported value equal
to the sentinel becomes absent on disk and re-imports as a different default, silently breaking round-trip.

Plus an end-to-end weight-coverage walker (:func:`test_format_weight_coverage`) — for each test
fixture with a checkpoint format, materialise the Fast-LLM model and assert every parameter is consumed
by some leaf :class:`WeightConverter`. Catches the "silent drop" failure mode where a model param has
no converter and ``_convert_state_dict`` skips it on export.
"""

import typing
Expand All @@ -24,9 +29,11 @@
_safe_set_nested_dict_value,
)
from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.layers.attention.config import AttentionConfig
from fast_llm.layers.block.config import PatternBlockSequenceConfig
from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig
from tests.utils.model_configs import MODEL_CONFIGS

# Configs that don't default-construct cleanly need a minimal-valid factory.
_DEFAULT_FACTORIES: dict[type, typing.Callable[[], typing.Any]] = {
Expand Down Expand Up @@ -156,6 +163,51 @@ def test_safe_set_nested_dict_value_collision() -> None:
_safe_set_nested_dict_value(out, ("nested", "key"), "other")


@pytest.mark.parametrize(
    "fixture_name",
    [
        config_name
        for config_name, testing_config in MODEL_CONFIGS.items()
        if testing_config.checkpoint_format is not None
    ],
)
def test_format_weight_coverage(fixture_name: str) -> None:
    """Check that weight converters and model parameters cover each other exactly.

    Runs once per entry of ``MODEL_CONFIGS`` that declares a checkpoint format. The fixture's base
    model is built from its ``base_model`` config dict, and the names from ``named_parameters()`` are
    compared with the ``fast_llm_name`` entries of every converter returned by
    ``base_model_converter_class.get_converters(config)``.

    Runtime-tied parameters (``get_tied_parameters``) are treated as covered as soon as one member of
    their group has a converter, matching the export-time behaviour where a single shared weight is
    serialised once.

    The test fails if a model parameter has no converter, or if a converter names a parameter that the
    model does not have.
    """
    testing_config = MODEL_CONFIGS[fixture_name]
    handler = testing_config.checkpoint_format.get_handler_class()
    config = testing_config.base_model_config_class.from_dict(testing_config.config_dict["model"]["base_model"])
    model = config.base_model_class(config, DistributedConfig())

    # Index parameter names by object identity so tied groups (given as parameter objects) can be
    # translated back into names.
    names_by_id: dict[int, str] = {}
    for name, parameter in model.named_parameters():
        names_by_id[id(parameter)] = name
    all_names = set(names_by_id.values())

    tied_name_groups = []
    for tied_parameters in model.get_tied_parameters().values():
        tied_name_groups.append(frozenset(names_by_id[id(tied)] for tied in tied_parameters))

    # Every name explicitly claimed by a converter.
    converters = handler.base_model_converter_class.get_converters(config)
    consumed: set[str] = {name for converter in converters for name in converter.fast_llm_name}

    # Tied closure: a group sharing at least one name with ``consumed`` is covered in full.
    covered = consumed.union(*(group for group in tied_name_groups if not group.isdisjoint(consumed)))

    missing = sorted(all_names - covered)
    phantom = sorted(consumed - all_names)
    assert not (missing or phantom), (
        f"{handler.__name__}: weight coverage mismatch — "
        f"Fast-LLM params with no converter: {missing}; "
        f"converters with no matching param: {phantom}"
    )


def test_llama_export_rejects_mismatched_block_and_head_norm_epsilon() -> None:
"""End-to-end regression: a Llama config with mismatched block/head normalization epsilon must fail to
export. Both the decoder Custom and the head Nested write ``rms_norm_eps`` into the same HF dict; a
Expand All @@ -164,7 +216,6 @@ def test_llama_export_rejects_mismatched_block_and_head_norm_epsilon() -> None:

from fast_llm.models.gpt.config import GPTBaseModelConfig
from fast_llm.models.gpt.conversion.llama import LlamaBaseModelConverter
from tests.utils.model_configs import MODEL_CONFIGS

cfg = copy.deepcopy(MODEL_CONFIGS["llama"].config_dict["model"]["base_model"])
# Default head normalization inherits the block default (1e-5); pin head to a different value.
Expand Down
13 changes: 9 additions & 4 deletions tests/utils/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,22 +1051,27 @@ def update_and_add_testing_config(
("model", "base_model", "decoder"): {
"type": "pattern",
"blocks": {
# Sub-dicts in ``_gemma4_block_overrides`` / ``_gemma4_mixer_overrides`` are deepcopied
# per block. Without this the two blocks' nested override dicts would alias, and
# ``Config._from_dict`` (which mutates its input via ``pop``) would consume the
# shared sub-dicts when processing the first block, leaving the second to silently
# fall back to type defaults (LayerNorm / output_scale.enabled=None).
"sliding_attention": {
**copy.deepcopy(_llama_block),
**_gemma4_block_overrides,
**copy.deepcopy(_gemma4_block_overrides),
"mixer": {
**copy.deepcopy(_llama_block["mixer"]),
**_gemma4_mixer_overrides,
**copy.deepcopy(_gemma4_mixer_overrides),
"window_size": 128,
},
"mlp": copy.deepcopy(_gemma4_moe_mlp),
},
"full_attention": {
**copy.deepcopy(_llama_block),
**_gemma4_block_overrides,
**copy.deepcopy(_gemma4_block_overrides),
"mixer": {
**copy.deepcopy(_llama_block["mixer"]),
**_gemma4_mixer_overrides,
**copy.deepcopy(_gemma4_mixer_overrides),
"rotary": {"type": "proportional", "partial_rotary_factor": 0.25},
},
"mlp": copy.deepcopy(_gemma4_moe_mlp),
Expand Down
Loading