NNX: auto-adapt Linen checkpoints on load (params + full-state) #3929

Draft
ecnal-cienet wants to merge 1 commit into main from feat/checkpoint-linen-to-nnx-adapter

Conversation

@ecnal-cienet
Collaborator

Summary

  • Add two in-memory adapters in src/maxtext/common/checkpointing.py that auto-detect Linen-format checkpoints on disk and translate them to NNX state shape at load time.
  • Linen-saved checkpoints now work as drop-in load_parameters_path / load_full_state_path into NNX-mode training (pure_nnx_decoder=True), without a separate converter script.
  • Detection is via on-disk metadata; no config flag, no opt-in. Same-format loads bypass the adapter entirely.
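
A minimal sketch of the weights-only shape translation, assuming plain nested dicts as stand-ins for `nnx.State` and for the tree Orbax restores. The helper names `strip_value_wrappers` and `rewrap_like` are hypothetical and are not the functions added in `checkpointing.py`; the sketch only illustrates dropping the `{value: ...}` leaf wrappers to get a Linen-shape tree, then putting the restored leaves back into the NNX-shape tree.

```python
# Sketch only: plain dicts stand in for nnx.State and the Orbax restore result.
# Assumption: a dict whose only key is "value" is an NNX Variable wrapper.


def strip_value_wrappers(tree):
    """NNX shape -> Linen shape: collapse each {'value': x} wrapper to x."""
    if isinstance(tree, dict):
        if set(tree) == {"value"}:
            return tree["value"]
        return {k: strip_value_wrappers(v) for k, v in tree.items()}
    return tree


def rewrap_like(target, restored):
    """Rebuild the NNX-shape tree, taking every leaf from `restored`."""
    if isinstance(target, dict):
        if set(target) == {"value"}:
            return {"value": restored}
        return {k: rewrap_like(v, restored[k]) for k, v in target.items()}
    return restored


# Hypothetical two-parameter model.
nnx_abstract = {"dense": {"kernel": {"value": None}, "bias": {"value": None}}}
linen_abstract = strip_value_wrappers(nnx_abstract)

# Pretend this came back from a Linen-shape restore.
restored = {"dense": {"kernel": [1.0, 2.0], "bias": [0.5]}}
nnx_state = rewrap_like(nnx_abstract, restored)
```

A submodule or parameter literally named `value` would confuse this simple check; the real adapter works on `nnx.State` and `Variable` objects and does not need the heuristic.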

Test plan

Verified end-to-end on gpt3-52k (v6e-8, synthetic dataset, scan_layers=True) with a 4-variant matrix. Phase 1 trains 5 steps in pure Linen and saves a checkpoint. Phase 2 resumes from that checkpoint under 4 configurations:

| # | pure_nnx_decoder | Load path | Adapter | Pre-PR | Post-PR | Loss step 0/5 → step 9 |
|---|---|---|---|---|---|---|
| V1 | True | load_full_state_path | fires | ❌ FAIL | ✅ PASS | 14.378 → 14.377 (matches V3) |
| V2 | True | load_parameters_path | fires | ❌ FAIL | ✅ PASS | 14.378 → 14.373 (matches V4) |
| V3 | False | load_full_state_path | (no) | | | 14.378 → 14.377 |
| V4 | False | load_parameters_path | (no) | | | 14.378 → 14.373 |

V1 ↔ V3 and V2 ↔ V4 produce bit-identical loss trajectories, so the adapter is functionally equivalent to the same-format restore.
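
One way such a bit-level comparison could be made, sketched under the assumption that per-step losses are logged as float32 scalars. The function names and the sample values are hypothetical placeholders, not output from the runs above.

```python
import struct


def float32_bits(x):
    """Return the raw IEEE-754 float32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bit_identical(traj_a, traj_b):
    """True only if both trajectories have the same length and the same bits."""
    return len(traj_a) == len(traj_b) and all(
        float32_bits(a) == float32_bits(b) for a, b in zip(traj_a, traj_b)
    )


# Placeholder trajectories for illustration only.
adapter_run = [14.378, 14.377]
baseline_run = [14.378, 14.377]
same = bit_identical(adapter_run, baseline_run)
```

Comparing bit patterns rather than printed decimals avoids treating two losses that merely round to the same three digits as equal.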

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented May 18, 2026

Codecov Report

❌ Patch coverage is 10.71429% with 125 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/common/checkpointing.py | 10.71% | 123 Missing and 2 partials ⚠️ |


ecnal-cienet force-pushed the feat/checkpoint-linen-to-nnx-adapter branch 3 times, most recently from fb4d7af to f77e99c (May 18, 2026 00:38)

Auto-detect Linen on-disk layout (params/params/...) and adapt to NNX
abstract state in-memory, so Linen-saved checkpoints work as drop-in
load_parameters_path / load_full_state_path into NNX-mode training
without a separate converter script.

Two adapters, both module-private to checkpointing.py:
- _load_linen_params_into_nnx: weights-only path. Strips {value:}
  wrappers from the NNX abstract to build a Linen-shape abstract for
  Orbax, restores, then reconstructs the target nnx.State by replacing
  each Variable.value with the restored array.
- _load_linen_full_state_into_nnx: full-state path. Same translation
  for model weights, plus translations for the optimizer envelope
  (Linen step + opt_state at top -> NNX optimizer.step +
  optimizer.opt_state, with the inner params collection level stripped
  from mu/nu and optax-chain {'0': ...} unwrapped if present). Fills
  NNX-only RNG/dropout subtrees (which Linen checkpoints don't carry)
  with deterministic defaults: jnp.zeros for plain dtypes,
  jax.random.key(0) reshaped to match the abstract's expected shape
  for typed key<urbg> dtype.
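
A rough sketch of the optimizer-envelope translation described in the second bullet, using plain dicts in place of real checkpoint trees. The function name is hypothetical, and the sketch assumes the optax-chain wrapper appears as a dict whose only key is '0'; the actual adapter may handle more cases.

```python
# Sketch only: plain dicts stand in for the restored Linen full-state tree.


def translate_optimizer_envelope(linen_state):
    """Move Linen top-level step/opt_state under an NNX-style 'optimizer' key."""
    opt_state = linen_state["opt_state"]

    # Assumption: unwrap the optax chain only when '0' is the sole entry.
    if isinstance(opt_state, dict) and set(opt_state) == {"0"}:
        opt_state = opt_state["0"]

    translated = {}
    for name, sub in opt_state.items():
        # Strip the inner 'params' collection level from the moment trees.
        if name in ("mu", "nu") and isinstance(sub, dict) and set(sub) == {"params"}:
            sub = sub["params"]
        translated[name] = sub

    return {"optimizer": {"step": linen_state["step"], "opt_state": translated}}


# Hypothetical Adam-like state with one parameter.
linen_state = {
    "step": 5,
    "opt_state": {
        "0": {
            "count": 5,
            "mu": {"params": {"dense": {"kernel": 0.1}}},
            "nu": {"params": {"dense": {"kernel": 0.01}}},
        }
    },
}
nnx_optimizer = translate_optimizer_envelope(linen_state)
```

The RNG/dropout fill-in is left out of this sketch because it depends on JAX typed key arrays.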

Both fire only when isinstance(abstract, nnx.State) AND the on-disk
top-level shows the double-nested params/params/ layout. Same-format
loads (Linen->Linen, NNX->NNX) bypass the adapter entirely; the only
cost in those cases is one metadata read for format detection.
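
The firing condition can be pictured roughly as follows. This is a sketch, not the code in checkpointing.py: `FakeNNXState` is a hypothetical stand-in for `nnx.State`, the metadata tree is a plain dict standing in for whatever Orbax reports about the on-disk layout, and the `model` key in the NNX-style example is an assumption.

```python
class FakeNNXState(dict):
    """Hypothetical stand-in for flax.nnx.State in this sketch."""


def is_linen_layout(metadata_tree):
    """True if the on-disk top level shows the double-nested params/params."""
    params = metadata_tree.get("params")
    return isinstance(params, dict) and isinstance(params.get("params"), dict)


def should_adapt(abstract, metadata_tree):
    """Adapter fires only for an NNX target combined with a Linen checkpoint."""
    return isinstance(abstract, FakeNNXState) and is_linen_layout(metadata_tree)


# Placeholder metadata trees for illustration only.
linen_meta = {"params": {"params": {"decoder": {}}}, "step": {}, "opt_state": {}}
nnx_meta = {"model": {"decoder": {}}, "optimizer": {"step": {}, "opt_state": {}}}

linen_to_nnx = should_adapt(FakeNNXState(), linen_meta)   # adapter fires
linen_to_linen = should_adapt({}, linen_meta)             # same-format bypass
nnx_to_nnx = should_adapt(FakeNNXState(), nnx_meta)       # same-format bypass
```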

Verified end-to-end on gpt3-52k with a 4-variant matrix
(pure_nnx_decoder True/False x load_full_state_path/load_parameters_path).
Loss trajectories are bit-identical between Linen->NNX (adapter fires)
and Linen->Linen (same-format baseline). Zero per-step overhead; one-shot
O(N_variables) cost at load time, dominated by Orbax disk I/O.

Not in scope for this change: AQT quantization state (different
Variable type, follow-up), NNX->Linen reverse direction (no realistic
use case).

ecnal-cienet force-pushed the feat/checkpoint-linen-to-nnx-adapter branch from f77e99c to c035e07 (May 18, 2026 00:49)