NNX: auto-adapt Linen checkpoints on load (params + full-state) #3929
Draft
ecnal-cienet wants to merge 1 commit into
Auto-detect the Linen on-disk layout (params/params/...) and adapt it to
the NNX abstract state in memory, so Linen-saved checkpoints work as
drop-in load_parameters_path / load_full_state_path inputs for NNX-mode
training, without a separate converter script.
Two adapters, both module-private to checkpointing.py:
- _load_linen_params_into_nnx: weights-only path. Strips {value:}
  wrappers from the NNX abstract to build a Linen-shape abstract for
  Orbax, restores, then reconstructs the target nnx.State by replacing
  each Variable.value with the restored array.
- _load_linen_full_state_into_nnx: full-state path. Same translation
  for model weights, plus translations for the optimizer envelope
  (Linen step + opt_state at top -> NNX optimizer.step +
  optimizer.opt_state, with the inner params collection level stripped
  from mu/nu and optax-chain {'0': ...} unwrapped if present). Fills
  NNX-only RNG/dropout subtrees (which Linen checkpoints don't carry)
  with deterministic defaults: jnp.zeros for plain dtypes,
  jax.random.key(0) reshaped to match the abstract's expected shape for
  the typed key<urbg> dtype.
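
The two translations above can be sketched on plain nested dicts. This is an illustration only, not the code in checkpointing.py: the function names are hypothetical, dicts stand in for nnx.State and the Orbax restore tree, and the exact key layout of the Linen opt_state is an assumption based on the description above.

```python
# Sketch only: plain nested dicts stand in for nnx.State / Orbax trees.
# All function names are hypothetical, not the ones in checkpointing.py.

def strip_value_wrappers(tree):
    """NNX shape -> Linen shape: collapse each {'value': x} wrapper to x."""
    if isinstance(tree, dict):
        if set(tree) == {"value"}:
            return tree["value"]
        return {k: strip_value_wrappers(v) for k, v in tree.items()}
    return tree


def rewrap_values(abstract, restored):
    """Rebuild the NNX-shape tree, taking each leaf value from `restored`."""
    if isinstance(abstract, dict):
        if set(abstract) == {"value"}:
            return {"value": restored}
        return {k: rewrap_values(v, restored[k]) for k, v in abstract.items()}
    return restored


def adapt_optimizer_envelope(linen_state):
    """Linen top-level step/opt_state -> NNX optimizer.step/optimizer.opt_state."""
    opt_state = linen_state["opt_state"]
    # Unwrap an optax-chain {'0': ...} level if present (assumed layout).
    if isinstance(opt_state, dict) and set(opt_state) == {"0"}:
        opt_state = opt_state["0"]
    adapted = {}
    for name, sub in opt_state.items():
        # mu/nu carry an inner 'params' collection level in Linen; strip it.
        if name in ("mu", "nu") and isinstance(sub, dict) and set(sub) == {"params"}:
            sub = sub["params"]
        adapted[name] = sub
    return {"optimizer": {"step": linen_state["step"], "opt_state": adapted}}


nnx_abstract = {"decoder": {"kernel": {"value": None}, "bias": {"value": None}}}
linen_abstract = strip_value_wrappers(nnx_abstract)
restored = {"decoder": {"kernel": [1.0, 2.0], "bias": [0.5]}}
nnx_state = rewrap_values(nnx_abstract, restored)

linen_full = {
    "step": 5,
    "opt_state": {"0": {"count": 5,
                        "mu": {"params": {"w": 1}},
                        "nu": {"params": {"w": 2}}}},
}
nnx_optimizer = adapt_optimizer_envelope(linen_full)
```

The sketch leaves out the RNG/dropout default fill, which needs real JAX arrays and key dtypes.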
Both fire only when isinstance(abstract, nnx.State) AND the on-disk
top-level shows the double-nested params/params/ layout. Same-format
loads (Linen->Linen, NNX->NNX) bypass the adapter entirely; the only
cost in those cases is one metadata read for format detection.
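
A minimal sketch of that gating condition, assuming the checkpoint metadata has already been read into a nested dict. The helper names and the boolean standing in for isinstance(abstract, nnx.State) are hypothetical; the actual detection code may differ.

```python
# Sketch only: `metadata_tree` is a nested dict mirroring the on-disk
# top-level structure; helper names are hypothetical.

def looks_like_linen_layout(metadata_tree):
    """True when the top level shows the double-nested params/params/ layout."""
    outer = metadata_tree.get("params")
    return isinstance(outer, dict) and isinstance(outer.get("params"), dict)


def should_adapt(abstract_is_nnx_state, metadata_tree):
    """The adapter fires only when both conditions hold."""
    return abstract_is_nnx_state and looks_like_linen_layout(metadata_tree)


linen_meta = {"params": {"params": {"decoder": {}}}, "step": {}, "opt_state": {}}
other_meta = {"model": {"decoder": {}}, "optimizer": {}}
```

With these inputs, should_adapt(True, linen_meta) is the only combination that routes through the adapter; the other three fall through to the normal restore.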
Verified end-to-end on gpt3-52k with a 4-variant matrix
(pure_nnx_decoder True/False x load_full_state_path/load_parameters_path).
Loss trajectories are bit-identical between Linen->NNX (adapter fires)
and Linen->Linen (same-format baseline). Zero per-step overhead; one-shot
O(N_variables) cost at load time, dominated by Orbax disk I/O.
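
For reference, a bit-identical check is stricter than approximate equality. One way to compare two logged loss trajectories, assuming they are available as lists of Python floats (the helper name is hypothetical and is not part of this change):

```python
import struct

# Sketch only: compares the raw IEEE-754 bit patterns of two loss lists,
# so 0.0 vs -0.0 or a 1-ulp difference counts as a mismatch.

def bit_identical(losses_a, losses_b):
    if len(losses_a) != len(losses_b):
        return False
    return all(
        struct.pack("<d", a) == struct.pack("<d", b)
        for a, b in zip(losses_a, losses_b)
    )
```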
Not in scope for this change: AQT quantization state (different
Variable type, follow-up), NNX->Linen reverse direction (no realistic
use case).
Summary
- Two adapters in src/maxtext/common/checkpointing.py that auto-detect Linen-format checkpoints on disk and translate them to NNX state shape at load time.
- Linen-saved checkpoints work as drop-in load_parameters_path / load_full_state_path inputs for NNX-mode training (pure_nnx_decoder=True), without a separate converter script.

Test plan
Verified end-to-end on gpt3-52k (v6e-8, synthetic dataset, scan_layers=True) with a 4-variant matrix. Phase 1 trains 5 steps in pure Linen and saves a checkpoint. Phase 2 resumes under 4 configurations (V1 to V4), crossing pure_nnx_decoder with load_full_state_path / load_parameters_path.
V1 ↔ V3 and V2 ↔ V4 produce bit-identical loss trajectories; the adapter is functionally equivalent to the same-format restore.