diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index ad7618868a..1cb80e2ae7 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -23,6 +23,7 @@ from flax import nnx from flax.training import train_state import jax +import jax.numpy as jnp from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE from maxtext.input_pipeline.multihost_dataloading import MultiHostDataLoadIterator from maxtext.input_pipeline.multihost_dataloading import RemoteIteratorWrapper @@ -165,6 +166,167 @@ class GrainCheckpointRestore(ocp.args.CheckpointArgs): process_count: Optional[int] = None +_NNX_RNG_STATE_KEYS = ("rngs", "dropout") + + +def _strip_nnx_rng_state(tree): + """Removes 'rngs' and 'dropout' subtrees that NNX has but Linen doesn't.""" + if not isinstance(tree, dict): + return tree + return {k: _strip_nnx_rng_state(v) for k, v in tree.items() if k not in _NNX_RNG_STATE_KEYS} + + +def _wrap_mu_nu_with_params(state): + """Wraps mu/nu in a state dict with an inner 'params' key (Linen collection).""" + if not isinstance(state, dict): + return state + return {k: {"params": v} if k in ("mu", "nu") and isinstance(v, dict) else v for k, v in state.items()} + + +def _strip_mu_nu_params(state): + """Strips the inner 'params' wrapping from mu/nu in a state dict.""" + if not isinstance(state, dict): + return state + return {k: v["params"] if k in ("mu", "nu") and isinstance(v, dict) and "params" in v else v for k, v in state.items()} + + +def _translate_nnx_opt_state_to_linen(nnx_opt_state): + """Reshapes NNX opt_state (int-keyed chain dict) to Linen on-disk layout (flat or list).""" + if not isinstance(nnx_opt_state, dict): + return nnx_opt_state + keys = list(nnx_opt_state.keys()) + is_chain = bool(keys) and all(isinstance(k, str) and k.isdigit() for k in keys) + if not is_chain: + return _wrap_mu_nu_with_params(nnx_opt_state) + sorted_keys = sorted(keys, key=int) + if len(sorted_keys) == 1: + return _wrap_mu_nu_with_params(nnx_opt_state[sorted_keys[0]]) + return [_wrap_mu_nu_with_params(nnx_opt_state[k]) for k in sorted_keys] + + +def _translate_linen_opt_state_to_nnx(linen_opt_state, nnx_abstract_opt_state): + """Reshapes restored Linen opt_state back to the NNX abstract's layout.""" + is_nnx_chain = ( + isinstance(nnx_abstract_opt_state, dict) + and bool(nnx_abstract_opt_state) + and all(isinstance(k, str) and k.isdigit() for k in nnx_abstract_opt_state.keys()) + ) + if isinstance(linen_opt_state, list): + return {str(i): _strip_mu_nu_params(el) for i, el in enumerate(linen_opt_state)} + stripped = _strip_mu_nu_params(linen_opt_state) + return {"0": stripped} if is_nnx_chain else stripped + + +def _replace_values_in_abstract_nnx_state(abstract_nnx_state, concrete_pure_dict): + """Rebuilds an nnx.State by injecting concrete values into abstract Variables, by path lookup. + + Paths missing from `concrete_pure_dict` get a default: jnp.zeros for + plain dtypes, jax.random.key(0) for typed RNG keys. + """ + + def _key_to_str(k): + for attr in ("key", "name", "idx"): + if hasattr(k, attr): + return getattr(k, attr) + return str(k) + + def _navigate(tree, keys): + cur = tree + for k in keys: + if not isinstance(cur, dict): + return None + if k in cur: + cur = cur[k] + elif str(k) in cur: + cur = cur[str(k)] + else: + return None + return cur + + def _default_for(sds): + if not (hasattr(sds, "dtype") and hasattr(sds, "shape")): + return sds + if "key" not in str(sds.dtype): + return jnp.zeros(sds.shape, dtype=sds.dtype) + base = jax.random.key(0) + if sds.shape == (): + return base + return jax.random.split(base, int(np.prod(sds.shape))).reshape(sds.shape) + + abs_paths_vars, treedef = jax.tree_util.tree_flatten_with_path( + abstract_nnx_state, is_leaf=lambda x: isinstance(x, nnx.Variable) + ) + new_leaves = [] + for path, var in abs_paths_vars: + concrete = _navigate(concrete_pure_dict, [_key_to_str(k) for k in path]) + if isinstance(var, nnx.Variable): + new_leaves.append(var.replace(value=concrete if concrete is not None else _default_for(var.value))) + else: + new_leaves.append(concrete if concrete is not None else var) + return jax.tree_util.tree_unflatten(treedef, new_leaves) + + +def _load_linen_full_state_into_nnx(path, abstract_nnx_state, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3): + """Loads a Linen full-state checkpoint into an NNX nnx.State abstract. + + Translates between the two layouts: Linen has params/params, top-level + step, and opt_state with inner 'params' wrapping on mu/nu; NNX has model, + optimizer.step, and optimizer.opt_state without that wrapping. NNX-only + RNG/dropout state (not in Linen) is filled with deterministic defaults + (jnp.zeros for plain dtypes, jax.random.key(0) for typed RNG keys). + + Args: + path: GCS or local checkpoint directory. + abstract_nnx_state: Target nnx.State, whose to_pure_dict() yields + {model: ..., optimizer: {step, opt_state}}. + checkpoint_storage_concurrent_gb: Concurrent GB for byte I/O. + use_ocdbt: Whether to use OCDBT format. + use_zarr3: Whether to use Zarr3 format. + + Returns: + An nnx.State of the same shape as the abstract, with concrete values. + """ + max_logging.log(f"Detected Linen full-state on disk; adapting to NNX abstract at {path}") + + pure = abstract_nnx_state.to_pure_dict() + if not (isinstance(pure, dict) and "model" in pure and "optimizer" in pure): + raise ValueError( + f"Expected NNX abstract with model/optimizer; got top-level keys " + f"{list(pure.keys()) if isinstance(pure, dict) else type(pure).__name__}" + ) + + nnx_opt_state_unboxed = _strip_value_wrappers(pure["optimizer"]["opt_state"]) + linen_abstract = { + "params": {"params": _strip_nnx_rng_state(_strip_value_wrappers(pure["model"]))}, + "step": _strip_value_wrappers(pure["optimizer"]["step"]), + "opt_state": _translate_nnx_opt_state_to_linen(nnx_opt_state_unboxed), + } + + ckptr = ocp.Checkpointer( + ocp.PyTreeCheckpointHandler( + restore_concurrent_gb=checkpoint_storage_concurrent_gb, + save_concurrent_gb=checkpoint_storage_concurrent_gb, + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + ) + ) + restore_args = ocp.checkpoint_utils.construct_restore_args(linen_abstract) + restored = ckptr.restore( + epath.Path(path), + item=linen_abstract, + restore_args=restore_args, + ) + + nnx_pure_concrete = { + "model": restored["params"]["params"], + "optimizer": { + "step": restored["step"], + "opt_state": _translate_linen_opt_state_to_nnx(restored["opt_state"], nnx_opt_state_unboxed), + }, + } + return _replace_values_in_abstract_nnx_state(abstract_nnx_state, nnx_pure_concrete) + + def _load_full_state_from_path( path, abstract_unboxed_pre_state, @@ -216,6 +378,16 @@ def combine_sharding(sds, shardings): else: raise ocp_v1.errors.InvalidLayoutError(f"Unknown checkpoint layout: {source_checkpoint_layout}") else: + # Cross-format: NNX target reading a Linen-saved full-state checkpoint. + if isinstance(abstract_unboxed_pre_state, nnx.State) and _is_linen_full_state_on_disk(path, use_ocdbt, use_zarr3): + return _load_linen_full_state_into_nnx( + path, + abstract_unboxed_pre_state, + checkpoint_storage_concurrent_gb, + use_ocdbt, + use_zarr3, + ) + # Original v0 logic. p = epath.Path(path) handler = ocp.PyTreeCheckpointHandler( @@ -735,6 +907,123 @@ def setup_checkpoint_logger(config) -> Any | None: # pytype: disable=attribute- return orbax_cloud_logger +def _peek_on_disk_tree(path, use_ocdbt, use_zarr3): + """Returns the on-disk pytree metadata at `path`.""" + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)) + metadata = ckptr.metadata(epath.Path(path)) + if hasattr(metadata, "item_metadata") and hasattr(metadata.item_metadata, "tree"): + return metadata.item_metadata.tree + if hasattr(metadata, "tree"): + return metadata.tree + return metadata + + +def _is_linen_format_on_disk(path, use_ocdbt, use_zarr3): + """Returns True if the on-disk top-level is a Linen params/params/... layout.""" + try: + tree = _peek_on_disk_tree(path, use_ocdbt, use_zarr3) + if isinstance(tree, dict) and "params" in tree: + inner = tree["params"] + if isinstance(inner, dict) and "params" in inner: + max_logging.log(f"format-detect: Linen layout detected at {path}") + return True + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"format-detect: could not peek metadata at {path}: {e}") + return False + + +def _is_linen_full_state_on_disk(path, use_ocdbt, use_zarr3): + """Returns True if the on-disk top-level is a Linen full-state layout.""" + try: + tree = _peek_on_disk_tree(path, use_ocdbt, use_zarr3) + if isinstance(tree, dict) and {"step", "params", "opt_state"}.issubset(tree.keys()): + inner = tree.get("params") + if isinstance(inner, dict) and "params" in inner: + max_logging.log(f"format-detect: Linen full-state layout detected at {path}") + return True + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"format-detect: could not peek full-state metadata at {path}: {e}") + return False + + +def _strip_value_wrappers(tree): + """Unwraps {value: x} -> x recursively for leaf-like x.""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"} and not isinstance(tree["value"], dict): + return tree["value"] + return {k: _strip_value_wrappers(v) for k, v in tree.items()} + return tree + + +def _wrap_with_value(tree): + """Wraps each leaf as {value: leaf} recursively.""" + if isinstance(tree, dict): + return {k: _wrap_with_value(v) for k, v in tree.items()} + return {"value": tree} + + +def _load_linen_params_into_nnx(path, nnx_abstract_params, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3): + """Loads a Linen-format params checkpoint into an NNX nnx.State abstract. + + Restores using a Linen-shape abstract built by stripping the NNX + {value:} Variable boxing, then rebuilds the target nnx.State by + replacing each Variable's .value with the restored array. Sharding + metadata survives because it lives on the ShapeDtypeStruct leaves. + + Args: + path: GCS or local checkpoint directory. + nnx_abstract_params: Target nnx.State (typically the nnx.Param group + from nnx.split(model, nnx.Param, ...)). + checkpoint_storage_concurrent_gb: Concurrent GB for byte I/O. + use_ocdbt: Whether to use OCDBT format. + use_zarr3: Whether to use Zarr3 format. + + Returns: + An nnx.State of the same shape as the abstract, with concrete values. + """ + max_logging.log(f"Detected Linen format on disk; adapting to NNX abstract at {path}") + + pure_dict = nnx_abstract_params.to_pure_dict() + linen_shape_abstract = _strip_value_wrappers(pure_dict) + inner_wrapped = {"params": linen_shape_abstract} + + ckptr = ocp.Checkpointer( + ocp.PyTreeCheckpointHandler( + restore_concurrent_gb=checkpoint_storage_concurrent_gb, + save_concurrent_gb=checkpoint_storage_concurrent_gb, + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + ) + ) + restore_args = ocp.checkpoint_utils.construct_restore_args(inner_wrapped) + restored = ckptr.restore( + epath.Path(path), + item={"params": inner_wrapped}, + transforms={}, + restore_args={"params": restore_args}, + ) + linen_restored = restored["params"]["params"] + + abstract_leaves, treedef = jax.tree_util.tree_flatten( + nnx_abstract_params, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + concrete_leaves = jax.tree_util.tree_leaves(linen_restored) + if len(abstract_leaves) != len(concrete_leaves): + raise ValueError( + f"Linen->NNX adapter: leaf count mismatch — " + f"{len(abstract_leaves)} abstract Variables vs " + f"{len(concrete_leaves)} restored arrays. Trees do not align." + ) + new_leaves = [] + for var, arr in zip(abstract_leaves, concrete_leaves): + if isinstance(var, nnx.Variable): + new_leaves.append(var.replace(value=arr)) + else: + new_leaves.append(arr) + return jax.tree_util.tree_unflatten(treedef, new_leaves) + + def load_params_from_path( load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True ): @@ -742,6 +1031,18 @@ def load_params_from_path( assert load_parameters_from_path, "load_parameters_from_path is not defined." max_logging.log(f"restoring params from {load_parameters_from_path}") + # Cross-format: NNX target reading a Linen-saved checkpoint. + if isinstance(abstract_unboxed_params, nnx.State) and _is_linen_format_on_disk( + load_parameters_from_path, use_ocdbt, use_zarr3 + ): + return _load_linen_params_into_nnx( + load_parameters_from_path, + abstract_unboxed_params, + checkpoint_storage_concurrent_gb, + use_ocdbt, + use_zarr3, + ) + # *_concurrent_gb should be set for large models, the default is 96. max_logging.log(f"Creating checkpoint manager with ocdbt={use_ocdbt} and zarr3={use_zarr3}") ckptr = ocp.Checkpointer(