Skip to content

fix(checkpoint): exclude TE _extra_state keys from load-time mismatch warning#2247

Open
adil-a wants to merge 1 commit into
mainfrom
adil-a/fix/extra-state-key-diff-warning
Open

fix(checkpoint): exclude TE _extra_state keys from load-time mismatch warning#2247
adil-a wants to merge 1 commit into
mainfrom
adil-a/fix/extra-state-key-diff-warning

Conversation

@adil-a
Copy link
Copy Markdown
Collaborator

@adil-a adil-a commented May 15, 2026

What

When loading an HF safetensors checkpoint into a model that contains
TransformerEngine modules, the checkpoint loader emits a noisy
Checkpoint key mismatch warning. The "missing" keys are all
_extra_state entries that TE attaches to its own modules for internal
bookkeeping (FP8 amax history, etc.) and that are not part of the HF
safetensors checkpoint. The framework already neutralizes these via the
set_extra_state shim in components/checkpoint/stateful_wrappers.py,
so reporting them as missing weights is misleading.

This PR filters *_extra_state keys out of both sides of the
mismatch summary so the warning only fires when real weights are
actually missing or unexpected.

Observed warning (8-rank FSDP2 + EP load of NVIDIA-Nemotron-3-Nano-30B-A3B-BF16)

WARNING:root:Checkpoint key mismatch for FSDPNemotronHForCausalLM: missing=77 unexpected=0 (missing examples=['lm_head._extra_state', 'model.layers.1.mixer.shared_experts.down_proj._extra_state', 'model.layers.1.mixer.shared_experts.up_proj._extra_state', 'model.layers.10.mixer.shared_experts.down_proj._extra_state', 'model.layers.10.mixer.shared_experts.up_proj._extra_state', 'model.layers.12.mixer.attn_module._extra_state', 'model.layers.12.mixer.k_proj._extra_state', 'model.layers.12.mixer.o_proj._extra_state', 'model.layers.12.mixer.q_proj._extra_state', 'model.layers.12.mixer.v_proj._extra_state'], unexpected examples=[])

All 77 missing keys end in ._extra_state.

Changelog

  • Exclude *_extra_state keys from the load-time
    Checkpoint key mismatch warning so it focuses on real weight
    mismatches.

Pre-checks

  • Linting: no formatting/style changes required
  • DCO sign-off

… warning

TransformerEngine modules attach `_extra_state` entries to their state_dict
for internal bookkeeping that is not present in HF safetensors checkpoints.
These were being reported as missing keys in the load diagnostic, producing
a noisy warning with up to dozens of `_extra_state` examples on every load.
The `set_extra_state` shim already tolerates their absence, so filtering
them out of the mismatch summary keeps the warning focused on real weight
mismatches.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: adil-a <adil.asif2000@hotmail.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


state_dict = _maybe_adapt_state_dict_from_hf(model_state.model[0], state_dict, moe_mesh=self.moe_mesh)
key_diff = _summarize_state_dict_key_diff(expected_keys, set(state_dict.keys()))
expected_keys_for_diff = {k for k in expected_keys if not k.endswith("_extra_state")}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @adil-a , i think for bf16 this should be fine, but thinking ahead, I'm wondering if that would break any fp8 workflows? Please let me know what you think.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about we put a check to ensure the model is in bf16 as well?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, i would remove _extra_state if it's empty on the model side

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

response from CC:

Good question — I dug into this and it turns out FP8 workflows aren't affected, because the codebase already treats _extra_state as ephemeral on every save/load
path. Quick summary of why:

TE-based FP8. TE's FP8 amax history lives in _extra_state, but it's never round-tripped through our checkpointer:

  • _maybe_adapt_state_dict_to_hf in checkpointing.py calls adapter.to_hf(..., exclude_key_regex=r"._extra_state.", ...) unconditionally on save (line 1590).
  • stateful_wrappers.py monkey-patches TransformerEngineBaseModule.set_extra_state and BasicOperation.set_extra_state to no-op on DCP's _EXTRA_STATE sentinel (lines
    31–49), gated only on HAS_TE — not on FP8 being enabled.
  • te_attention.py stashes the TE attention module via object.setattr specifically so attn_module._extra_state never enters the state_dict (lines 548–553, with
    a comment to that effect).
  • On load, when a TE module has a custom get_extra_state, an empty torch.tensor([], dtype=uint8) placeholder is injected so DCP doesn't complain (lines 1437–1443).

Net result: TE FP8 amax history is rebuilt from observed activations after load. Filtering _extra_state from the mismatch warning only hides keys that the rest of
the framework is already silently dropping.

torchao Float8Linear. Doesn't use _extra_state at all — Float8Linear.{get,set}_extra_state is nn.Module.{get,set}_extra_state, its state_dict() is ['weight',
'bias'], and the weight stays as a bf16/fp32 nn.Parameter (FP8 conversion happens dynamically in forward).

So unconditionally filtering _extra_state from the warning is consistent with how all four other code paths handle it, and won't mask any real mismatch for FP8
workflows.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, i would remove _extra_state if it's empty on the model side

It comes with the TE nn.Modules as a parameter. We'd have to do unnecessary plumbing for this. TBH checkpointing is already in a finicky spot I'd rather we avoid changing things where possible in the current state.

@adil-a
Copy link
Copy Markdown
Collaborator Author

adil-a commented May 18, 2026

/ok to test c7cfaee

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants