Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 301 additions & 0 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flax import nnx
from flax.training import train_state
import jax
import jax.numpy as jnp
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
from maxtext.input_pipeline.multihost_dataloading import MultiHostDataLoadIterator
from maxtext.input_pipeline.multihost_dataloading import RemoteIteratorWrapper
Expand Down Expand Up @@ -165,6 +166,167 @@ class GrainCheckpointRestore(ocp.args.CheckpointArgs):
process_count: Optional[int] = None


_NNX_RNG_STATE_KEYS = ("rngs", "dropout")


def _strip_nnx_rng_state(tree):
"""Removes 'rngs' and 'dropout' subtrees that NNX has but Linen doesn't."""
if not isinstance(tree, dict):
return tree
return {k: _strip_nnx_rng_state(v) for k, v in tree.items() if k not in _NNX_RNG_STATE_KEYS}


def _wrap_mu_nu_with_params(state):
"""Wraps mu/nu in a state dict with an inner 'params' key (Linen collection)."""
if not isinstance(state, dict):
return state
return {k: {"params": v} if k in ("mu", "nu") and isinstance(v, dict) else v for k, v in state.items()}


def _strip_mu_nu_params(state):
"""Strips the inner 'params' wrapping from mu/nu in a state dict."""
if not isinstance(state, dict):
return state
return {k: v["params"] if k in ("mu", "nu") and isinstance(v, dict) and "params" in v else v for k, v in state.items()}


def _translate_nnx_opt_state_to_linen(nnx_opt_state):
"""Reshapes NNX opt_state (int-keyed chain dict) to Linen on-disk layout (flat or list)."""
if not isinstance(nnx_opt_state, dict):
return nnx_opt_state
keys = list(nnx_opt_state.keys())
is_chain = bool(keys) and all(isinstance(k, str) and k.isdigit() for k in keys)
if not is_chain:
return _wrap_mu_nu_with_params(nnx_opt_state)
sorted_keys = sorted(keys, key=int)
if len(sorted_keys) == 1:
return _wrap_mu_nu_with_params(nnx_opt_state[sorted_keys[0]])
return [_wrap_mu_nu_with_params(nnx_opt_state[k]) for k in sorted_keys]


def _translate_linen_opt_state_to_nnx(linen_opt_state, nnx_abstract_opt_state):
"""Reshapes restored Linen opt_state back to the NNX abstract's layout."""
is_nnx_chain = (
isinstance(nnx_abstract_opt_state, dict)
and bool(nnx_abstract_opt_state)
and all(isinstance(k, str) and k.isdigit() for k in nnx_abstract_opt_state.keys())
)
if isinstance(linen_opt_state, list):
return {str(i): _strip_mu_nu_params(el) for i, el in enumerate(linen_opt_state)}
stripped = _strip_mu_nu_params(linen_opt_state)
return {"0": stripped} if is_nnx_chain else stripped


def _replace_values_in_abstract_nnx_state(abstract_nnx_state, concrete_pure_dict):
  """Rebuilds `abstract_nnx_state` with concrete values looked up by tree path.

  The abstract tree is flattened with nnx.Variable instances as leaves. For each
  leaf, the matching entry of `concrete_pure_dict` is located by walking the
  leaf's path. Leaves with no matching entry fall back to a default:
  jnp.zeros for plain dtypes, keys derived from jax.random.key(0) for typed RNG
  key dtypes, or the original leaf when it is not a Variable.

  Args:
    abstract_nnx_state: Pytree (nnx.State) providing structure and Variables.
    concrete_pure_dict: Nested dict of concrete values.

  Returns:
    A pytree with the same structure as `abstract_nnx_state`.
  """

  def _path_entry_id(entry):
    # Path entries expose their identifier under different attribute names;
    # take the first one present, else fall back to the string form.
    for attr_name in ("key", "name", "idx"):
      if hasattr(entry, attr_name):
        return getattr(entry, attr_name)
    return str(entry)

  def _lookup(tree, path_ids):
    # Walks nested dicts; each step tries the raw id first, then its str() form.
    node = tree
    for path_id in path_ids:
      if not isinstance(node, dict):
        return None
      if path_id in node:
        node = node[path_id]
        continue
      as_text = str(path_id)
      if as_text not in node:
        return None
      node = node[as_text]
    return node

  def _placeholder(abstract_value):
    has_array_metadata = hasattr(abstract_value, "dtype") and hasattr(abstract_value, "shape")
    if not has_array_metadata:
      return abstract_value
    # Typed PRNG key dtypes are recognised by "key" appearing in the dtype's string form.
    if "key" not in str(abstract_value.dtype):
      return jnp.zeros(abstract_value.shape, dtype=abstract_value.dtype)
    seed_key = jax.random.key(0)
    if abstract_value.shape == ():
      return seed_key
    num_keys = int(np.prod(abstract_value.shape))
    return jax.random.split(seed_key, num_keys).reshape(abstract_value.shape)

  def _is_variable(node):
    return isinstance(node, nnx.Variable)

  path_leaf_pairs, structure = jax.tree_util.tree_flatten_with_path(abstract_nnx_state, is_leaf=_is_variable)
  rebuilt = []
  for leaf_path, leaf in path_leaf_pairs:
    found = _lookup(concrete_pure_dict, [_path_entry_id(entry) for entry in leaf_path])
    if not _is_variable(leaf):
      rebuilt.append(leaf if found is None else found)
      continue
    rebuilt.append(leaf.replace(value=_placeholder(leaf.value) if found is None else found))
  return jax.tree_util.tree_unflatten(structure, rebuilt)


def _load_linen_full_state_into_nnx(path, abstract_nnx_state, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3):
  """Restores a Linen full-state checkpoint and returns it shaped like the NNX abstract.

  The NNX abstract is first converted to a Linen-shaped restore target:
    * model             -> params/params (with 'rngs'/'dropout' subtrees removed)
    * optimizer.step    -> step
    * optimizer.opt_state -> opt_state (chain flattened/listed, mu/nu wrapped in 'params')
  After restoring, the values are mapped back and injected into the abstract by
  path. NNX entries that have no Linen counterpart (e.g. the removed RNG/dropout
  state) receive the deterministic defaults supplied by
  `_replace_values_in_abstract_nnx_state`.

  Args:
    path: GCS or local checkpoint directory.
    abstract_nnx_state: Target nnx.State, whose to_pure_dict() yields
      {model: ..., optimizer: {step, opt_state}}.
    checkpoint_storage_concurrent_gb: Concurrent GB for byte I/O.
    use_ocdbt: Whether to use OCDBT format.
    use_zarr3: Whether to use Zarr3 format.

  Returns:
    An nnx.State of the same shape as the abstract, with concrete values.

  Raises:
    ValueError: If the abstract's pure dict lacks top-level 'model'/'optimizer'.
  """
  max_logging.log(f"Detected Linen full-state on disk; adapting to NNX abstract at {path}")

  nnx_pure = abstract_nnx_state.to_pure_dict()
  has_expected_layout = isinstance(nnx_pure, dict) and "model" in nnx_pure and "optimizer" in nnx_pure
  if not has_expected_layout:
    found = list(nnx_pure.keys()) if isinstance(nnx_pure, dict) else type(nnx_pure).__name__
    raise ValueError(f"Expected NNX abstract with model/optimizer; got top-level keys {found}")

  # Build the Linen-shaped abstract that matches what is on disk.
  abstract_opt_state = _strip_value_wrappers(nnx_pure["optimizer"]["opt_state"])
  linen_target = {
      "params": {"params": _strip_nnx_rng_state(_strip_value_wrappers(nnx_pure["model"]))},
      "step": _strip_value_wrappers(nnx_pure["optimizer"]["step"]),
      "opt_state": _translate_nnx_opt_state_to_linen(abstract_opt_state),
  }

  handler = ocp.PyTreeCheckpointHandler(
      restore_concurrent_gb=checkpoint_storage_concurrent_gb,
      save_concurrent_gb=checkpoint_storage_concurrent_gb,
      use_ocdbt=use_ocdbt,
      use_zarr3=use_zarr3,
  )
  checkpointer = ocp.Checkpointer(handler)
  linen_restore_args = ocp.checkpoint_utils.construct_restore_args(linen_target)
  linen_state = checkpointer.restore(
      epath.Path(path),
      item=linen_target,
      restore_args=linen_restore_args,
  )

  # Map the restored Linen layout back onto NNX naming before injecting values.
  restored_optimizer = {
      "step": linen_state["step"],
      "opt_state": _translate_linen_opt_state_to_nnx(linen_state["opt_state"], abstract_opt_state),
  }
  nnx_concrete = {"model": linen_state["params"]["params"], "optimizer": restored_optimizer}
  return _replace_values_in_abstract_nnx_state(abstract_nnx_state, nnx_concrete)


def _load_full_state_from_path(
path,
abstract_unboxed_pre_state,
Expand Down Expand Up @@ -216,6 +378,16 @@ def combine_sharding(sds, shardings):
else:
raise ocp_v1.errors.InvalidLayoutError(f"Unknown checkpoint layout: {source_checkpoint_layout}")
else:
# Cross-format: NNX target reading a Linen-saved full-state checkpoint.
if isinstance(abstract_unboxed_pre_state, nnx.State) and _is_linen_full_state_on_disk(path, use_ocdbt, use_zarr3):
return _load_linen_full_state_into_nnx(
path,
abstract_unboxed_pre_state,
checkpoint_storage_concurrent_gb,
use_ocdbt,
use_zarr3,
)

# Original v0 logic.
p = epath.Path(path)
handler = ocp.PyTreeCheckpointHandler(
Expand Down Expand Up @@ -735,13 +907,142 @@ def setup_checkpoint_logger(config) -> Any | None: # pytype: disable=attribute-
return orbax_cloud_logger


def _peek_on_disk_tree(path, use_ocdbt, use_zarr3):
  """Reads checkpoint metadata at `path` and returns its pytree description.

  The metadata object is probed in order for `item_metadata.tree`, then `tree`;
  if neither attribute exists the metadata object itself is returned.
  """
  handler = ocp.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)
  metadata = ocp.Checkpointer(handler).metadata(epath.Path(path))
  if hasattr(metadata, "item_metadata") and hasattr(metadata.item_metadata, "tree"):
    return metadata.item_metadata.tree
  return metadata.tree if hasattr(metadata, "tree") else metadata


def _is_linen_format_on_disk(path, use_ocdbt, use_zarr3):
  """Reports whether the checkpoint at `path` has a top-level params/params nesting.

  Detection is best-effort: any exception raised while reading the metadata is
  logged and treated as "not Linen" (returns False).
  """
  try:
    on_disk = _peek_on_disk_tree(path, use_ocdbt, use_zarr3)
    has_outer = isinstance(on_disk, dict) and "params" in on_disk
    if has_outer and isinstance(on_disk["params"], dict) and "params" in on_disk["params"]:
      max_logging.log(f"format-detect: Linen layout detected at {path}")
      return True
  except Exception as err:  # pylint: disable=broad-except
    max_logging.log(f"format-detect: could not peek metadata at {path}: {err}")
  return False


def _is_linen_full_state_on_disk(path, use_ocdbt, use_zarr3):
  """Reports whether the checkpoint at `path` looks like a Linen full train state.

  The on-disk tree must be a dict containing 'step', 'params' and 'opt_state',
  with 'params' itself a dict holding an inner 'params' key. Detection is
  best-effort: any exception while reading metadata is logged and yields False.
  """
  required = ("step", "params", "opt_state")
  try:
    on_disk = _peek_on_disk_tree(path, use_ocdbt, use_zarr3)
    if isinstance(on_disk, dict) and all(name in on_disk for name in required):
      params_entry = on_disk.get("params")
      if isinstance(params_entry, dict) and "params" in params_entry:
        max_logging.log(f"format-detect: Linen full-state layout detected at {path}")
        return True
  except Exception as err:  # pylint: disable=broad-except
    max_logging.log(f"format-detect: could not peek full-state metadata at {path}: {err}")
  return False


def _strip_value_wrappers(tree):
"""Unwraps {value: x} -> x recursively for leaf-like x."""
if isinstance(tree, dict):
if set(tree.keys()) == {"value"} and not isinstance(tree["value"], dict):
return tree["value"]
return {k: _strip_value_wrappers(v) for k, v in tree.items()}
return tree


def _wrap_with_value(tree):
"""Wraps each leaf as {value: leaf} recursively."""
if isinstance(tree, dict):
return {k: _wrap_with_value(v) for k, v in tree.items()}
return {"value": tree}


def _load_linen_params_into_nnx(path, nnx_abstract_params, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3):
  """Restores a Linen-format params checkpoint into an NNX nnx.State abstract.

  The restore target is built by removing the NNX {value: ...} boxing from the
  abstract's pure dict and nesting the result under params/params. The restored
  arrays are then paired positionally with the abstract's flattened leaves
  (nnx.Variable instances count as leaves), and each Variable's value is
  replaced by the corresponding array.

  NOTE(review): the positional pairing assumes both flattenings enumerate
  leaves in the same order; only the leaf counts are validated here.

  Args:
    path: GCS or local checkpoint directory.
    nnx_abstract_params: Target nnx.State (typically the nnx.Param group
      from nnx.split(model, nnx.Param, ...)).
    checkpoint_storage_concurrent_gb: Concurrent GB for byte I/O.
    use_ocdbt: Whether to use OCDBT format.
    use_zarr3: Whether to use Zarr3 format.

  Returns:
    An nnx.State of the same shape as the abstract, with concrete values.

  Raises:
    ValueError: If the number of abstract leaves differs from the number of
      restored arrays.
  """
  max_logging.log(f"Detected Linen format on disk; adapting to NNX abstract at {path}")

  # Linen stores params under params/params; mirror that nesting in the target.
  unboxed_abstract = _strip_value_wrappers(nnx_abstract_params.to_pure_dict())
  inner_target = {"params": unboxed_abstract}

  handler = ocp.PyTreeCheckpointHandler(
      restore_concurrent_gb=checkpoint_storage_concurrent_gb,
      save_concurrent_gb=checkpoint_storage_concurrent_gb,
      use_ocdbt=use_ocdbt,
      use_zarr3=use_zarr3,
  )
  checkpointer = ocp.Checkpointer(handler)
  inner_restore_args = ocp.checkpoint_utils.construct_restore_args(inner_target)
  checkpoint = checkpointer.restore(
      epath.Path(path),
      item={"params": inner_target},
      transforms={},
      restore_args={"params": inner_restore_args},
  )
  restored_params = checkpoint["params"]["params"]

  def _is_variable(node):
    return isinstance(node, nnx.Variable)

  target_leaves, structure = jax.tree_util.tree_flatten(nnx_abstract_params, is_leaf=_is_variable)
  restored_leaves = jax.tree_util.tree_leaves(restored_params)
  num_targets, num_restored = len(target_leaves), len(restored_leaves)
  if num_targets != num_restored:
    raise ValueError(
        f"Linen->NNX adapter: leaf count mismatch — "
        f"{num_targets} abstract Variables vs "
        f"{num_restored} restored arrays. Trees do not align."
    )

  filled = [
      target.replace(value=array) if _is_variable(target) else array
      for target, array in zip(target_leaves, restored_leaves)
  ]
  return jax.tree_util.tree_unflatten(structure, filled)


def load_params_from_path(
load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True
):
"""Load decode params from checkpoint at specified path."""
assert load_parameters_from_path, "load_parameters_from_path is not defined."
max_logging.log(f"restoring params from {load_parameters_from_path}")

# Cross-format: NNX target reading a Linen-saved checkpoint.
if isinstance(abstract_unboxed_params, nnx.State) and _is_linen_format_on_disk(
load_parameters_from_path, use_ocdbt, use_zarr3
):
return _load_linen_params_into_nnx(
load_parameters_from_path,
abstract_unboxed_params,
checkpoint_storage_concurrent_gb,
use_ocdbt,
use_zarr3,
)

# *_concurrent_gb should be set for large models, the default is 96.
max_logging.log(f"Creating checkpoint manager with ocdbt={use_ocdbt} and zarr3={use_zarr3}")
ckptr = ocp.Checkpointer(
Expand Down
Loading