diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index ad7618868a..f24dc841a8 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -198,11 +198,11 @@ def _load_full_state_from_path( if source_checkpoint_layout == "orbax": context = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.ORBAX) with context: - return ocp_v1.load_pytree(path, abstract_unboxed_pre_state) + return ocp_v1.load(path, abstract_state=abstract_unboxed_pre_state) elif source_checkpoint_layout == "safetensors": context = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS) with context: - metadata = ocp_v1.pytree_metadata(path) + metadata = ocp_v1.metadata(path) simple_abstract_state = metadata.metadata shardings = sharding_utils.construct_maximal_shardings(simple_abstract_state) @@ -210,7 +210,9 @@ def combine_sharding(sds, shardings): return jax.ShapeDtypeStruct(shape=sds.shape, dtype=sds.dtype, sharding=shardings) sharded_abstract_state = jax.tree.map(combine_sharding, simple_abstract_state, shardings) - pre_transformed_state = ocp_v1.load_pytree(path, sharded_abstract_state) + pre_transformed_state = ocp_v1.load( + path, abstract_state=sharded_abstract_state + ) state = checkpoint_conversion_fn(pre_transformed_state) return state else: