From 31604148930a4c96e96aa06fedb13863262d5487 Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Mon, 18 May 2026 10:13:51 -0700 Subject: [PATCH] Rename `save/load_pytree` to `save/load`. Eliminate most user-facing "pytree" terminology in favor of "state" as a more specific term. Add `deprecations.py` for handling deprecated public functions. PiperOrigin-RevId: 917299781 --- src/maxtext/common/checkpointing.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index ad7618868a..f24dc841a8 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -198,11 +198,11 @@ def _load_full_state_from_path( if source_checkpoint_layout == "orbax": context = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.ORBAX) with context: - return ocp_v1.load_pytree(path, abstract_unboxed_pre_state) + return ocp_v1.load(path, abstract_state=abstract_unboxed_pre_state) elif source_checkpoint_layout == "safetensors": context = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS) with context: - metadata = ocp_v1.pytree_metadata(path) + metadata = ocp_v1.metadata(path) simple_abstract_state = metadata.metadata shardings = sharding_utils.construct_maximal_shardings(simple_abstract_state) @@ -210,7 +210,9 @@ def combine_sharding(sds, shardings): return jax.ShapeDtypeStruct(shape=sds.shape, dtype=sds.dtype, sharding=shardings) sharded_abstract_state = jax.tree.map(combine_sharding, simple_abstract_state, shardings) - pre_transformed_state = ocp_v1.load_pytree(path, sharded_abstract_state) + pre_transformed_state = ocp_v1.load( + path, abstract_state=sharded_abstract_state + ) state = checkpoint_conversion_fn(pre_transformed_state) return state else: