From 34e2536f1803eaab9d23c61c5a7f687f98118e04 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 18 May 2026 00:04:54 +0000 Subject: [PATCH] NNX: add Linen-format checkpoint adapter for load paths Auto-detect Linen on-disk layout (params/params/...) and adapt to NNX abstract state in-memory, so Linen-saved checkpoints work as drop-in load_parameters_path / load_full_state_path into NNX-mode training without a separate converter script. Two adapters, both module-private to checkpointing.py: - _load_linen_params_into_nnx: weights-only path. Strips {value:} wrappers from the NNX abstract to build a Linen-shape abstract for Orbax, restores, then reconstructs the target nnx.State by replacing each Variable.value with the restored array. - _load_linen_full_state_into_nnx: full-state path. Same translation for model weights, plus translations for the optimizer envelope (Linen step + opt_state at top -> NNX optimizer.step + optimizer. opt_state, with the inner params collection level stripped from mu/nu and optax-chain {'0': ...} unwrapped if present). Fills NNX-only RNG/dropout subtrees (which Linen checkpoints don't carry) with deterministic defaults: jnp.zeros for plain dtypes, jax.random. key(0) reshaped to match the abstract's expected shape for typed key dtype. Both fire only when isinstance(abstract, nnx.State) AND the on-disk top-level shows the double-nested params/params/ layout. Same-format loads (Linen->Linen, NNX->NNX) bypass the adapter entirely; the only cost in those cases is one metadata read for format detection. Verified end-to-end on gpt3-52k with a 4-variant matrix (pure_nnx_decoder True/False x load_full_state_path/load_parameters_path). Loss trajectories are bit-identical between Linen->NNX (adapter fires) and Linen->Linen (same-format baseline). Zero per-step overhead; one-shot O(N_variables) cost at load time, dominated by Orbax disk I/O. Not in scope for this change: AQT quantization state (different Variable type, follow-up), NNX->Linen reverse direction (no realistic use case). --- src/maxtext/common/checkpointing.py | 371 +++++++++++++++++++++++++++- 1 file changed, 369 insertions(+), 2 deletions(-) diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index ad7618868a..90ea77f8a4 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -23,6 +23,7 @@ from flax import nnx from flax.training import train_state import jax +import jax.numpy as jnp from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE from maxtext.input_pipeline.multihost_dataloading import MultiHostDataLoadIterator from maxtext.input_pipeline.multihost_dataloading import RemoteIteratorWrapper @@ -165,6 +166,228 @@ class GrainCheckpointRestore(ocp.args.CheckpointArgs): process_count: Optional[int] = None +_NNX_RNG_STATE_KEYS = ("rngs", "dropout") + + +def _strip_nnx_rng_state(tree): + """Removes 'rngs' and 'dropout' subtrees that NNX has but Linen doesn't.""" + if not isinstance(tree, dict): + return tree + return {k: _strip_nnx_rng_state(v) for k, v in tree.items() if k not in _NNX_RNG_STATE_KEYS} + + +def _wrap_mu_nu_with_params(state): + """Wraps mu/nu in a state dict with an inner 'params' key (Linen collection).""" + if not isinstance(state, dict): + return state + return {k: {"params": v} if k in ("mu", "nu") and isinstance(v, dict) else v for k, v in state.items()} + + +def _strip_mu_nu_params(state): + """Strips the inner 'params' wrapping from mu/nu in a state dict.""" + if not isinstance(state, dict): + return state + return {k: v["params"] if k in ("mu", "nu") and isinstance(v, dict) and "params" in v else v for k, v in state.items()} + + +def _as_chain_index(key): + """Returns int index if `key` is an int or digit string, else None.""" + if isinstance(key, int): + return key + if isinstance(key, str) and key.isdigit(): + return int(key) + return None + + +def _nnx_opt_state_to_linen_shape(opt_state): + """Reshapes NNX opt_state to Linen on-disk layout. + + NNX's to_pure_dict() exposes an optax chain as an int-keyed dict (empty + entries skipped); Linen serializes the same chain as a list with `None` + placeholders. For a single-element chain, the inner state is returned + directly to match Linen's un-chained shape (e.g. adam_pax). Inside each + element, mu/nu are wrapped with the Linen 'params' collection key. + """ + if not isinstance(opt_state, dict): + return opt_state + indices = [_as_chain_index(k) for k in opt_state.keys()] + is_chain = bool(indices) and all(i is not None for i in indices) + if not is_chain: + return _wrap_mu_nu_with_params(opt_state) + length = max(indices) + 1 + chain = [None] * length + for orig_key, idx in zip(opt_state.keys(), indices): + chain[idx] = _wrap_mu_nu_with_params(opt_state[orig_key]) + if length == 1: + return chain[0] + return chain + + +def _linen_opt_state_to_nnx_shape(opt_state): + """Inverse of `_nnx_opt_state_to_linen_shape`. + + Accepts either Linen's flat dict (adam_pax-style) or list shape (adamw-style). + Lists become int-keyed dicts; non-dict entries (e.g. `None` for `EmptyState`) + are dropped to match NNX's `to_pure_dict()`. + """ + if isinstance(opt_state, list): + return {i: _strip_mu_nu_params(el) for i, el in enumerate(opt_state) if isinstance(el, dict)} + if not isinstance(opt_state, dict): + return opt_state + return {0: _strip_mu_nu_params(opt_state)} + + +def _nnx_state_to_linen_shape(nnx_pure_dict): + """Reshapes an NNX state pure dict to Linen on-disk layout. + + Used by the save path so that checkpoints written with `pure_nnx=True` + share an on-disk shape with Linen-saved checkpoints. The reverse, + `_linen_state_to_nnx_shape`, is used by the load path. + + NNX-only RNG/dropout state under `model` is stripped: it isn't part of + the Linen format and will be re-initialized from defaults on load. + + Args: + nnx_pure_dict: NNX state as a pure dict, typically the result of + `nnx.state(train_state_nnx).to_pure_dict()`. + + Returns: + A pure dict with the same leaf arrays in Linen on-disk layout + (`params/params/...`, top-level `step`, and `opt_state`). + """ + if not isinstance(nnx_pure_dict, dict): + return nnx_pure_dict + result = {} + if "model" in nnx_pure_dict: + result["params"] = {"params": _strip_nnx_rng_state(nnx_pure_dict["model"])} + optimizer = nnx_pure_dict.get("optimizer") + if isinstance(optimizer, dict): + if "step" in optimizer: + result["step"] = optimizer["step"] + if "opt_state" in optimizer: + result["opt_state"] = _nnx_opt_state_to_linen_shape(optimizer["opt_state"]) + return result + + +def _linen_state_to_nnx_shape(linen_pure_dict): + """Inverse of `_nnx_state_to_linen_shape`. + + Args: + linen_pure_dict: A pure dict in Linen on-disk layout (top-level keys + include `params`, `step`, and `opt_state`). + + Returns: + A pure dict with the same leaf arrays shaped as + `nnx.state(train_state_nnx).to_pure_dict()` — top-level `model` and + `optimizer` keys. + """ + if not isinstance(linen_pure_dict, dict): + return linen_pure_dict + result = {} + params = linen_pure_dict.get("params") + if isinstance(params, dict) and "params" in params: + result["model"] = params["params"] + elif params is not None: + result["model"] = params + optimizer = {} + if "step" in linen_pure_dict: + optimizer["step"] = linen_pure_dict["step"] + if "opt_state" in linen_pure_dict: + optimizer["opt_state"] = _linen_opt_state_to_nnx_shape(linen_pure_dict["opt_state"]) + if optimizer: + result["optimizer"] = optimizer + return result + + +def _default_for_sds(sds): + """Returns a deterministic concrete value matching `sds` shape/dtype/sharding. + + Honors `sds.sharding` (if present) so large defaults (e.g. an unfilled + optimizer mu/nu that mirrors a multi-GB params tree) are sharded across + the mesh instead of replicated to every device. + """ + if not (hasattr(sds, "dtype") and hasattr(sds, "shape")): + return sds + is_key = "key" in str(sds.dtype) + sharding = getattr(sds, "sharding", None) + if is_key: + base = jax.random.key(0) + if sds.shape == (): + value = base + else: + value = jax.random.split(base, int(np.prod(sds.shape))).reshape(sds.shape) + else: + value = jnp.zeros(sds.shape, dtype=sds.dtype) + if sharding is not None: + value = jax.device_put(value, sharding) + return value + + +def _populate_pure_dict_from_partial(abstract_pure, partial_concrete): + """Walks `abstract_pure` and substitutes each SDS leaf with the corresponding + concrete value from `partial_concrete` (looked up by path), falling back to a + default for paths that aren't present in `partial_concrete`. + + Returns a pure dict with the same structure as `abstract_pure` and concrete + values at every leaf. + """ + if isinstance(abstract_pure, dict): + return { + k: _populate_pure_dict_from_partial(v, partial_concrete.get(k) if isinstance(partial_concrete, dict) else None) + for k, v in abstract_pure.items() + } + if partial_concrete is not None and not isinstance(partial_concrete, dict): + return partial_concrete + return _default_for_sds(abstract_pure) + + +def _load_linen_full_state_into_nnx(path, abstract_nnx_state, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3): + """Loads a Linen-shape full-state checkpoint into an NNX abstract. + + Builds a Linen-shape abstract via `_nnx_state_to_linen_shape`, restores against + it, then reshapes back via `_linen_state_to_nnx_shape`. NNX-only paths absent + from Linen (e.g. rngs/dropout) are filled by `_default_for_sds`. + + Args: + path: GCS or local checkpoint directory. + abstract_nnx_state: Target `nnx.State` with `{model, optimizer}` top-level. + checkpoint_storage_concurrent_gb: Concurrent GB for byte I/O. + use_ocdbt: Whether to use OCDBT format. + use_zarr3: Whether to use Zarr3 format. + + Returns: + A pure dict in NNX state layout with concrete values at every leaf, ready + for `nnx.replace_by_pure_dict`. + """ + max_logging.log(f"Adapting Linen-shape full-state checkpoint to NNX at {path}") + + pure = abstract_nnx_state.to_pure_dict() + if not (isinstance(pure, dict) and "model" in pure and "optimizer" in pure): + raise ValueError( + f"Expected NNX abstract with model/optimizer; got top-level keys " + f"{list(pure.keys()) if isinstance(pure, dict) else type(pure).__name__}" + ) + + linen_abstract = _nnx_state_to_linen_shape(_strip_value_wrappers(pure)) + + ckptr = ocp.Checkpointer( + ocp.PyTreeCheckpointHandler( + restore_concurrent_gb=checkpoint_storage_concurrent_gb, + save_concurrent_gb=checkpoint_storage_concurrent_gb, + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + ) + ) + restore_args = ocp.checkpoint_utils.construct_restore_args(linen_abstract) + restored = ckptr.restore( + epath.Path(path), + args=ocp.args.PyTreeRestore(item=linen_abstract, restore_args=restore_args, partial_restore=True), + ) + + partial_concrete = _linen_state_to_nnx_shape(restored) + return _populate_pure_dict_from_partial(pure, partial_concrete) + + def _load_full_state_from_path( path, abstract_unboxed_pre_state, @@ -216,6 +439,16 @@ def combine_sharding(sds, shardings): else: raise ocp_v1.errors.InvalidLayoutError(f"Unknown checkpoint layout: {source_checkpoint_layout}") else: + # Cross-format: NNX target reading a Linen-saved full-state checkpoint. + if isinstance(abstract_unboxed_pre_state, nnx.State) and _is_linen_full_state_on_disk(path, use_ocdbt, use_zarr3): + return _load_linen_full_state_into_nnx( + path, + abstract_unboxed_pre_state, + checkpoint_storage_concurrent_gb, + use_ocdbt, + use_zarr3, + ) + # Original v0 logic. p = epath.Path(path) handler = ocp.PyTreeCheckpointHandler( @@ -640,6 +873,20 @@ def map_to_pspec(data): ) ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) + # NNX target reading a Linen-saved full-state checkpoint. + checkpoint_path = str(checkpoint_manager.directory / str(step) / "items") + if isinstance(abstract_unboxed_pre_state, nnx.State) and _is_linen_full_state_on_disk( + checkpoint_path, use_ocdbt, use_zarr3 + ): + restored_nnx = _load_linen_full_state_into_nnx( + checkpoint_path, + abstract_unboxed_pre_state, + checkpoint_storage_concurrent_gb, + use_ocdbt, + use_zarr3, + ) + return ({"items": restored_nnx}, None) + # Convert nnx.State to pure dict to match how checkpoints are saved for NNX restore_target = abstract_unboxed_pre_state if isinstance(abstract_unboxed_pre_state, nnx.State): @@ -735,6 +982,114 @@ def setup_checkpoint_logger(config) -> Any | None: # pytype: disable=attribute- return orbax_cloud_logger +def _peek_on_disk_tree(path, use_ocdbt, use_zarr3): + """Returns the on-disk pytree metadata at `path`.""" + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)) + metadata = ckptr.metadata(epath.Path(path)) + if hasattr(metadata, "item_metadata") and hasattr(metadata.item_metadata, "tree"): + return metadata.item_metadata.tree + if hasattr(metadata, "tree"): + return metadata.tree + return metadata + + +def _is_linen_format_on_disk(path, use_ocdbt, use_zarr3): + """Returns True if the on-disk top-level is a Linen params/params/... layout.""" + try: + tree = _peek_on_disk_tree(path, use_ocdbt, use_zarr3) + if isinstance(tree, dict) and "params" in tree: + inner = tree["params"] + if isinstance(inner, dict) and "params" in inner: + return True + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"Could not peek checkpoint metadata at {path}: {e}") + return False + + +def _is_linen_full_state_on_disk(path, use_ocdbt, use_zarr3): + """Returns True if the on-disk top-level is a Linen full-state layout.""" + try: + tree = _peek_on_disk_tree(path, use_ocdbt, use_zarr3) + if isinstance(tree, dict) and {"step", "params", "opt_state"}.issubset(tree.keys()): + inner = tree.get("params") + if isinstance(inner, dict) and "params" in inner: + return True + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"Could not peek checkpoint metadata at {path}: {e}") + return False + + +def _strip_value_wrappers(tree): + """Unwraps {value: x} -> x recursively for leaf-like x.""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"} and not isinstance(tree["value"], dict): + return tree["value"] + return {k: _strip_value_wrappers(v) for k, v in tree.items()} + return tree + + +def _load_linen_params_into_nnx(path, nnx_abstract_params, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3): + """Loads a Linen-format params checkpoint into an NNX nnx.State abstract. + + Restores using a Linen-shape abstract built by stripping the NNX + {value:} Variable boxing, then rebuilds the target nnx.State by + replacing each Variable's .value with the restored array. Sharding + metadata survives because it lives on the ShapeDtypeStruct leaves. + + Args: + path: GCS or local checkpoint directory. + nnx_abstract_params: Target nnx.State (typically the nnx.Param group + from nnx.split(model, nnx.Param, ...)). + checkpoint_storage_concurrent_gb: Concurrent GB for byte I/O. + use_ocdbt: Whether to use OCDBT format. + use_zarr3: Whether to use Zarr3 format. + + Returns: + An nnx.State of the same shape as the abstract, with concrete values. + """ + max_logging.log(f"Adapting Linen-shape params checkpoint to NNX at {path}") + + pure_dict = nnx_abstract_params.to_pure_dict() + linen_shape_abstract = _strip_value_wrappers(pure_dict) + inner_wrapped = {"params": linen_shape_abstract} + + ckptr = ocp.Checkpointer( + ocp.PyTreeCheckpointHandler( + restore_concurrent_gb=checkpoint_storage_concurrent_gb, + save_concurrent_gb=checkpoint_storage_concurrent_gb, + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + ) + ) + restore_args = ocp.checkpoint_utils.construct_restore_args(inner_wrapped) + restored = ckptr.restore( + epath.Path(path), + item={"params": inner_wrapped}, + transforms={}, + restore_args={"params": restore_args}, + ) + linen_restored = restored["params"]["params"] + + abstract_leaves, treedef = jax.tree_util.tree_flatten( + nnx_abstract_params, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + concrete_leaves = jax.tree_util.tree_leaves(linen_restored) + if len(abstract_leaves) != len(concrete_leaves): + raise ValueError( + f"Linen->NNX adapter: leaf count mismatch — " + f"{len(abstract_leaves)} abstract Variables vs " + f"{len(concrete_leaves)} restored arrays. Trees do not align." + ) + new_leaves = [] + for var, arr in zip(abstract_leaves, concrete_leaves): + if isinstance(var, nnx.Variable): + new_leaves.append(var.replace(value=arr)) + else: + new_leaves.append(arr) + return jax.tree_util.tree_unflatten(treedef, new_leaves) + + def load_params_from_path( load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True ): @@ -742,6 +1097,18 @@ def load_params_from_path( assert load_parameters_from_path, "load_parameters_from_path is not defined." max_logging.log(f"restoring params from {load_parameters_from_path}") + # Cross-format: NNX target reading a Linen-saved checkpoint. + if isinstance(abstract_unboxed_params, nnx.State) and _is_linen_format_on_disk( + load_parameters_from_path, use_ocdbt, use_zarr3 + ): + return _load_linen_params_into_nnx( + load_parameters_from_path, + abstract_unboxed_params, + checkpoint_storage_concurrent_gb, + use_ocdbt, + use_zarr3, + ) + # *_concurrent_gb should be set for large models, the default is 96. max_logging.log(f"Creating checkpoint manager with ocdbt={use_ocdbt} and zarr3={use_zarr3}") ckptr = ocp.Checkpointer( @@ -794,8 +1161,8 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step actual_step = int(state.step) - 1 if config.pure_nnx: - # Convert nnx.State to dict. - state = state.to_pure_dict() + # Save in Linen on-disk shape so checkpoints share one format across frameworks. + state = _nnx_state_to_linen_shape(state.to_pure_dict()) # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic. # This occurs if this function was called: