fix(checkpoint): exclude TE _extra_state keys from load-time mismatch warning#2247
fix(checkpoint): exclude TE _extra_state keys from load-time mismatch warning#2247adil-a wants to merge 1 commit into
Conversation
… warning TransformerEngine modules attach `_extra_state` entries to their state_dict for internal bookkeeping that is not present in HF safetensors checkpoints. These were being reported as missing keys in the load diagnostic, producing a noisy warning with up to dozens of `_extra_state` examples on every load. The `set_extra_state` shim already tolerates their absence, so filtering them out of the mismatch summary keeps the warning focused on real weight mismatches. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: adil-a <adil.asif2000@hotmail.com>
|
|
||
| state_dict = _maybe_adapt_state_dict_from_hf(model_state.model[0], state_dict, moe_mesh=self.moe_mesh) | ||
| key_diff = _summarize_state_dict_key_diff(expected_keys, set(state_dict.keys())) | ||
| expected_keys_for_diff = {k for k in expected_keys if not k.endswith("_extra_state")} |
There was a problem hiding this comment.
Thanks @adil-a , i think for bf16 this should be fine, but thinking ahead, I'm wondering if that would break any fp8 workflows? Please let me know what you think.
There was a problem hiding this comment.
how about we put a check to ensure the model is in bf16 as well?
There was a problem hiding this comment.
IMHO, i would remove _extra_state if it's empty on the model side
There was a problem hiding this comment.
response from CC:
Good question — I dug into this and it turns out FP8 workflows aren't affected, because the codebase already treats _extra_state as ephemeral on every save/load
path. Quick summary of why:
TE-based FP8. TE's FP8 amax history lives in _extra_state, but it's never round-tripped through our checkpointer:
- _maybe_adapt_state_dict_to_hf in checkpointing.py calls adapter.to_hf(..., exclude_key_regex=r"._extra_state.", ...) unconditionally on save (line 1590).
- stateful_wrappers.py monkey-patches TransformerEngineBaseModule.set_extra_state and BasicOperation.set_extra_state to no-op on DCP's _EXTRA_STATE sentinel (lines
31–49), gated only on HAS_TE — not on FP8 being enabled. - te_attention.py stashes the TE attention module via object.setattr specifically so attn_module._extra_state never enters the state_dict (lines 548–553, with
a comment to that effect). - On load, when a TE module has a custom get_extra_state, an empty torch.tensor([], dtype=uint8) placeholder is injected so DCP doesn't complain (lines 1437–1443).
Net result: TE FP8 amax history is rebuilt from observed activations after load. Filtering _extra_state from the mismatch warning only hides keys that the rest of
the framework is already silently dropping.
torchao Float8Linear. Doesn't use _extra_state at all — Float8Linear.{get,set}_extra_state is nn.Module.{get,set}_extra_state, its state_dict() is ['weight',
'bias'], and the weight stays as a bf16/fp32 nn.Parameter (FP8 conversion happens dynamically in forward).
So unconditionally filtering _extra_state from the warning is consistent with how all four other code paths handle it, and won't mask any real mismatch for FP8
workflows.
There was a problem hiding this comment.
IMHO, i would remove _extra_state if it's empty on the model side
It comes with the TE nn.Modules as a parameter. We'd have to do unnecessary plumbing for this. TBH checkpointing is already in a finicky spot I'd rather we avoid changing things where possible in the current state.
|
/ok to test c7cfaee |
What
When loading an HF safetensors checkpoint into a model that contains
TransformerEngine modules, the checkpoint loader emits a noisy
Checkpoint key mismatchwarning. The "missing" keys are all_extra_stateentries that TE attaches to its own modules for internalbookkeeping (FP8 amax history, etc.) and that are not part of the HF
safetensors checkpoint. The framework already neutralizes these via the
set_extra_stateshim incomponents/checkpoint/stateful_wrappers.py,so reporting them as missing weights is misleading.
This PR filters
*_extra_statekeys out of both sides of themismatch summary so the warning only fires when real weights are
actually missing or unexpected.
Observed warning (8-rank FSDP2 + EP load of NVIDIA-Nemotron-3-Nano-30B-A3B-BF16)
All 77 missing keys end in
._extra_state.Changelog
*_extra_statekeys from the load-timeCheckpoint key mismatchwarning so it focuses on real weightmismatches.
Pre-checks