diff --git a/src/maxtext/checkpoint_conversion/to_maxtext.py b/src/maxtext/checkpoint_conversion/to_maxtext.py
index 2f396257aa..0cf06da453 100644
--- a/src/maxtext/checkpoint_conversion/to_maxtext.py
+++ b/src/maxtext/checkpoint_conversion/to_maxtext.py
@@ -67,7 +67,8 @@
 from maxtext.common.common_types import MODEL_MODE_TRAIN
 from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
 from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
-from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
+from maxtext.checkpoint_conversion.utils.tensor_handling import apply_hook_fns
+from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
 from maxtext.inference.inference_utils import str2bool
 from maxtext.layers import quantizations
 from maxtext.models import models
diff --git a/src/maxtext/checkpoint_conversion/utils/load_dynamic.py b/src/maxtext/checkpoint_conversion/utils/load_dynamic.py
new file mode 100644
index 0000000000..22a119fd42
--- /dev/null
+++ b/src/maxtext/checkpoint_conversion/utils/load_dynamic.py
@@ -0,0 +1,232 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dynamic loading of HuggingFace checkpoints during training/eval workloads directly in the target format."""
+
+import time
+
+import jax
+from flax import traverse_util
+from flax import nnx
+from orbax.checkpoint import v1 as ocp_v1
+from orbax.checkpoint._src.arrays import sharding as sharding_utils
+
+from maxtext.utils import max_logging
+from maxtext.checkpoint_conversion.utils.tensor_handling import _get_hf_loading_function
+from maxtext.checkpoint_conversion.utils import param_mapping
+from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
+
+
+def get_hf_config_and_mappings(maxtext_config):
+  """Gets the MaxText-to-HF parameter mapping and hook functions for the configured model.
+
+  Args:
+    maxtext_config: MaxText config object. `model_name` and `scan_layers` are read here and
+      the object itself is forwarded to the mapping and hook factories.
+
+  Returns:
+    A `(param_map_mt_to_hf, hook_fn_map_mt)` tuple built by the `PARAM_MAPPING` and
+    `HOOK_FNS` factories registered for the model.
+
+  Raises:
+    KeyError: If the model name (with any "-Instruct" removed) is not registered.
+  """
+  # Registries are looked up with every "-Instruct" occurrence removed from the name;
+  # `str.replace` is a no-op when the substring is absent, so no membership test is needed.
+  model_key = maxtext_config.model_name.replace("-Instruct", "")
+  hf_config_dict = HF_MODEL_CONFIGS[model_key].to_dict()
+  scan_layers = maxtext_config.scan_layers
+
+  param_map_mt_to_hf = param_mapping.PARAM_MAPPING[model_key](
+      hf_config_dict, maxtext_config, scan_layers=scan_layers
+  )
+  hook_fn_map_mt = param_mapping.HOOK_FNS[model_key](
+      hf_config_dict, maxtext_config, scan_layers=scan_layers, saving_to_hf=False
+  )
+  return param_map_mt_to_hf, hook_fn_map_mt
+
+
+def load_sharded_hf_state(path):
+  """Loads a safetensors checkpoint as a pytree of arrays using Orbax's maximal shardings.
+
+  Args:
+    path: Location of the safetensors checkpoint; passed to Orbax unchanged.
+
+  Returns:
+    The pytree returned by `ocp_v1.load_pytree` for the checkpoint.
+  """
+  t0 = time.time()
+  context = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS)
+  with context:
+    simple_abstract_state = ocp_v1.pytree_metadata(path).metadata
+
+    # Distributed Sharded Download: Tell JAX to shard the HF Safetensors download
+    # across the TPU mesh as much as mathematically possible to avoid Host OOM and GCS throttling!
+    shardings = sharding_utils.construct_maximal_shardings(simple_abstract_state)
+
+    def combine_sharding(sds, single_sharding):
+      return jax.ShapeDtypeStruct(shape=sds.shape, dtype=sds.dtype, sharding=single_sharding)
+
+    sharded_abstract_state = jax.tree.map(combine_sharding, simple_abstract_state, shardings)
+
+    max_logging.log("Reading raw Safetensors into memory (Distributed Sharded GCS Download)...")
+    # The load stays inside the context so that it runs with the SAFETENSORS layout selected above.
+    hf_state = ocp_v1.load_pytree(path, sharded_abstract_state)
+  max_logging.log(f"load_sharded_hf_state took {time.time() - t0:.2f}s")
+  return hf_state
+
+
+def _resolve_target_key(mt_key, flat_target):
+  """Returns the `flat_target` key that `mt_key` refers to, or None when nothing matches."""
+  mt_name = mt_key.replace("params-", "").replace("-", ".")
+  # Candidates are tried in order: bare name, "params."-prefixed name, then the raw key with dots.
+  for candidate in (mt_name, "params." + mt_name, mt_key.replace("-", ".")):
+    if candidate in flat_target:
+      return candidate
+  return None
+
+
+def transform_hf_state_to_mt_state(
+    hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config
+):
+  """Transforms HF state into MaxText state by applying param mappings and hook functions.
+
+  Args:
+    hf_state: Mapping from HF tensor name to loaded tensor.
+    target_tree: Nested dict of target leaves; each mapped leaf must expose `.shape` and `.sharding`.
+    param_map_mt_to_hf: Maps a MaxText key ("-"-separated) to an HF key, a list of HF keys,
+      or a nested list of HF keys.
+    hook_fn_map_mt: Maps a MaxText key to its hook function(s); keys without hooks may be absent.
+    maxtext_config: MaxText config forwarded to the tensor-stacking helpers.
+
+  Returns:
+    `{"params": tree}` where mapped leaves are transformed arrays placed with the target
+    sharding and unmapped leaves are passed through from `target_tree` unchanged.
+
+  Raises:
+    KeyError: If the mapping references an HF tensor that `hf_state` does not contain.
+  """
+  t0 = time.time()
+
+  def tensor_getter(key):
+    try:
+      return hf_state[key]
+    except KeyError as err:
+      raise KeyError(f"HF tensor '{key}' required by the parameter mapping is not in the checkpoint.") from err
+
+  flat_target = traverse_util.flatten_dict(target_tree, sep=".")
+  flat_restored = flat_target.copy()
+
+  mapped_count = 0
+  restored_names = set()
+  keys_missed = []
+  max_logging.log("Starting fast in-memory Distributed Transformations...")
+
+  for mt_key, hf_source in param_map_mt_to_hf.items():
+    check_name = _resolve_target_key(mt_key, flat_target)
+    if check_name is None:
+      keys_missed.append(mt_key.replace("params-", "").replace("-", "."))
+      continue
+
+    target_leaf = flat_target[check_name]
+    load_fn = _get_hf_loading_function(
+        hf_source,
+        tensor_getter,
+        hook_fn_map_mt.get(mt_key),
+        target_leaf.shape,
+        maxtext_config,
+    )
+
+    # Execute transformation and assign to flat_restored
+    t_layer = time.time()
+    unsharded_array = load_fn()
+
+    # Ensure it's Sharded explicitly matching the JAX model expectations
+    flat_restored[check_name] = jax.device_put(unsharded_array, device=target_leaf.sharding)
+
+    max_logging.log(f"Transformed {check_name} from {hf_source} in {time.time() - t_layer:.4f}s")
+    mapped_count += 1
+    restored_names.add(check_name)
+
+  if mapped_count == 0:
+    max_logging.log(f"All transformations missed! Sample missed mt_names: {keys_missed[:5]}")
+    max_logging.log(f"Sample flat_target keys: {list(flat_target.keys())[:5]}")
+  elif keys_missed:
+    # Partial misses used to be silent; report them so a bad mapping is noticed.
+    max_logging.log(f"{len(keys_missed)} mapped parameters had no matching target key. Sample: {keys_missed[:5]}")
+
+  # Leaves the mapping never touched are returned exactly as they came in via target_tree.
+  unfilled = [name for name in flat_target if name not in restored_names]
+  if unfilled:
+    max_logging.log(f"{len(unfilled)} target parameters were not covered by the mapping. Sample: {unfilled[:5]}")
+
+  max_logging.log(f"Successfully mapped {mapped_count} parameters.")
+  restored_params = traverse_util.unflatten_dict(flat_restored, sep=".")
+
+  if "params" in restored_params:
+    restored_params = restored_params["params"]
+
+  max_logging.log(f"transform_hf_state_to_mt_state took {time.time() - t0:.2f}s")
+
+  return {"params": restored_params}
+
+
+def load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_config):
+  """Main entry point to dynamically build and load safetensors into MaxText format.
+
+  Splits execution into:
+  1. Deriving Mappings
+  2. Loading Sharded arrays directly to TPUs
+  3. Processing the transformations natively on TPUs
+
+  Args:
+    path: Location of the HF safetensors checkpoint.
+    abstract_unboxed_pre_state: Either an `nnx.State` or an object exposing `.params`; supplies
+      the target shapes and shardings.
+    maxtext_config: MaxText config used to derive mappings and stacking behaviour.
+
+  Returns:
+    `(None, restored_params)` where `restored_params` is `{"params": tree}`.
+
+  Raises:
+    ValueError: If `maxtext_config` is None or `path` is empty.
+  """
+  if maxtext_config is None:
+    raise ValueError("maxtext_config must be provided for safetensors_dynamic loading.")
+  if not path:
+    raise ValueError("A checkpoint path must be provided for safetensors_dynamic loading.")
+
+  t_total = time.time()
+  param_map_mt_to_hf, hook_fn_map_mt = get_hf_config_and_mappings(maxtext_config)
+  max_logging.log(f"[1/3] Mappings derived in {time.time() - t_total:.2f}s")
+
+  target_tree = (
+      abstract_unboxed_pre_state.to_pure_dict()
+      if isinstance(abstract_unboxed_pre_state, nnx.State)
+      else abstract_unboxed_pre_state.params
+  )
+
+  t1 = time.time()
+  hf_state = load_sharded_hf_state(path)
+  max_logging.log(f"[2/3] Distributed Sharded GCS load completed in {time.time() - t1:.2f}s")
+
+  t2 = time.time()
+  restored_params = transform_hf_state_to_mt_state(
+      hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config
+  )
+  # Stacking is done with jax.numpy (see tensor_handling.py), so this phase is not labelled as CPU work.
+  max_logging.log(f"[3/3] Transformations completed in {time.time() - t2:.2f}s")
+  max_logging.log(f"Total safetensors_dynamic duration: {time.time() - t_total:.2f}s")
+
+  return None, restored_params
diff --git a/src/maxtext/checkpoint_conversion/utils/tensor_handling.py b/src/maxtext/checkpoint_conversion/utils/tensor_handling.py
new file mode 100644
index 0000000000..5d7c1b2a8b
--- /dev/null
+++ b/src/maxtext/checkpoint_conversion/utils/tensor_handling.py
@@ -0,0 +1,164 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tensor handling utility functions for checkpoint conversion."""
+
+from functools import partial
+from typing import Any, Callable, List
+
+import jax.numpy as jnp
+
+
+def apply_hook_fns(weight, target_shape, hook_fns):
+  """Applies hook function(s) to a weight; essential for to_maxtext and to_huggingface.
+
+  Args:
+    weight: The tensor to transform.
+    target_shape: Shape passed to every hook as its second argument.
+    hook_fns: None (identity), a single callable, or a list of callables applied in list order.
+
+  Returns:
+    The weight after every hook has been applied.
+  """
+  # If hook is unspecified, use identity
+  if hook_fns is None:
+    return weight
+  if not isinstance(hook_fns, list):
+    hook_fns = [hook_fns]
+  # Apply a list of hooks, be careful of order
+  for hook_fn in hook_fns:
+    weight = hook_fn(weight, target_shape)
+  return weight
+
+
+def _build_multi_axis_stacked_tensor(
+    hf_source_keys: List[List[str]],
+    tensor_getter_fn: Callable[[str], jnp.ndarray],
+    hook_fns: Any,
+    target_shape: tuple,
+    config,
+) -> jnp.ndarray:
+  """Builds a MaxText tensor by stacking HF weights along two axes (experts and layers).
+
+  This function handles the complex case for scanned MoE layers, producing a tensor
+  with the shape (num_experts, num_layers, ...).
+
+  Args:
+    hf_source_keys: A nested (2D) list of Hugging Face parameter names.
+      Outer list iterates experts, inner list iterates layers.
+    tensor_getter_fn: A callable that takes a HF key and returns the tensor.
+    hook_fns: The hook function(s) to apply to each individual weight.
+    target_shape: The final shape of the target MaxText tensor.
+    config: The MaxText pyconfig object. Not read here; kept so all builders share one signature.
+
+  Returns:
+    The final, assembled array (stacked with jax.numpy) for the MaxText parameter.
+  """
+  # The hook function needs the shape of an individual slice, not the full stacked tensor.
+  # For multi-axis stacking (experts, layers, ...), the slice shape is target_shape[2:]
+  mt_slice_shape = target_shape[2:]
+
+  all_expert_tensors = []
+  # Outer loop iterates through experts
+  for layer_keys_for_expert in hf_source_keys:
+    # Inner comprehension iterates through layers for the current expert
+    layer_tensors_for_expert = [
+        apply_hook_fns(tensor_getter_fn(hf_key), mt_slice_shape, hook_fns) for hf_key in layer_keys_for_expert
+    ]
+    all_expert_tensors.append(jnp.stack(layer_tensors_for_expert, axis=0))
+  return jnp.stack(all_expert_tensors, axis=0)
+
+
+def _build_single_axis_stacked_tensor(
+    hf_source_keys: List[str],
+    tensor_getter_fn: Callable[[str], jnp.ndarray],
+    hook_fns: Any,
+    target_shape: tuple,
+    config,
+) -> jnp.ndarray:
+  """Builds a MaxText tensor by stacking HF weights along a single axis.
+
+  This function handles both standard scanned layers (e.g., attention) and
+  unscanned MoE layers (which are stacked along the expert axis).
+
+  Args:
+    hf_source_keys: A 1D list of Hugging Face parameter names.
+    tensor_getter_fn: A callable that takes a HF key and returns the tensor.
+    hook_fns: The hook function(s) to apply to each individual weight.
+    target_shape: The final shape of the target MaxText tensor.
+    config: The MaxText pyconfig object; `scan_layers` and `param_scan_axis` are read.
+
+  Returns:
+    The final, assembled array (stacked with jax.numpy) for the MaxText parameter.
+  """
+  # If it's a standard scanned layer, we use the configured param_scan_axis.
+  # Otherwise it is an unscanned MoE layer, and we stack along the expert axis (0).
+  axis_to_stack = config.param_scan_axis if config.scan_layers else 0
+
+  # The hook function needs the shape of an individual slice, not the full stacked tensor.
+  # We calculate it by removing the stacking dimension from the final target shape.
+  mt_slice_shape_list = list(target_shape)
+  del mt_slice_shape_list[axis_to_stack]
+  mt_slice_shape = tuple(mt_slice_shape_list)
+
+  tensors_to_stack = [
+      apply_hook_fns(tensor_getter_fn(hf_key), mt_slice_shape, hook_fns) for hf_key in hf_source_keys
+  ]
+
+  # Stack all processed tensors along the determined axis.
+  return jnp.stack(tensors_to_stack, axis=axis_to_stack)
+
+
+def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config):
+  """Determine the loading function for HF keys.
+
+  HF keys can take four forms:
+    Case 1: Unscanned (single string)
+    Case 2: Scanned (list of strings)
+    Case 3: Unscanned with expert stacking (list of strings)
+    Case 4: Scanned with expert stacking (nested list of strings)
+
+  Args:
+    hf_source_keys_or_key: A HF key, a list of HF keys, or a nested list of HF keys.
+    tensor_getter: A callable that takes a HF key and returns the tensor.
+    hook_fn: The hook function(s) to apply to each loaded weight (may be None).
+    mt_target_shape_or_shapes: Shape of the target MaxText tensor.
+    config: The MaxText pyconfig object, forwarded to the stacking builders.
+
+  Returns:
+    A zero-argument callable that loads, transforms and, for list inputs, stacks the tensor(s).
+
+  Raises:
+    ValueError: If `hf_source_keys_or_key` is an empty list.
+  """
+  if not isinstance(hf_source_keys_or_key, list):
+    # Case 1: Single hf key (str)
+    def _loader(getter, key, shape, hook):
+      return apply_hook_fns(getter(key), shape, hook)
+
+    return partial(_loader, tensor_getter, hf_source_keys_or_key, mt_target_shape_or_shapes, hook_fn)
+
+  # Stacked mapping
+  if not hf_source_keys_or_key:
+    # Without this guard the [0] lookup below fails with an uninformative IndexError.
+    raise ValueError("hf_source_keys_or_key must not be an empty list.")
+
+  if isinstance(hf_source_keys_or_key[0], list):
+    # Case 4: Multi-Axis Stacked hf keys (nested list)
+    builder = _build_multi_axis_stacked_tensor
+  else:
+    # Case 2 or 3: Single-Axis Stacked hf keys (un-nested list)
+    builder = _build_single_axis_stacked_tensor
+
+  return partial(builder, hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config)
diff --git a/src/maxtext/checkpoint_conversion/utils/utils.py b/src/maxtext/checkpoint_conversion/utils/utils.py
index 93253cffb0..04c8db7552 100644
--- a/src/maxtext/checkpoint_conversion/utils/utils.py
+++
b/src/maxtext/checkpoint_conversion/utils/utils.py
@@ -23,7 +23,8 @@
 import time
 import json
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any
+from functools import partial
+from typing import Any, Callable, List
 from tqdm import tqdm
 import resource
 import numpy as np
@@ -1165,3 +1166,99 @@ def save_weights_to_checkpoint(
     checkpoint_manager.wait_until_finished()
 
   max_logging.log(f"Elapse for checkpoint save: {(time.time() - start) / 60:.2f} min")
+
+
+# NOTE(review): the three helpers below are also added to utils/tensor_handling.py in this same
+# change (there `np` is jax.numpy, here it is NumPy) -- consider keeping a single copy.
+def _build_multi_axis_stacked_tensor(
+    hf_source_keys: List[List[str]],
+    tensor_getter_fn: Callable[[str], np.ndarray],
+    hook_fns: Any,
+    target_shape: tuple,
+    config,
+) -> np.ndarray:
+  """Assembles a scanned-MoE MaxText parameter from a 2D grid of HF weights.
+
+  Every HF weight is fetched, passed through the hook(s), stacked over layers for
+  each expert, and the per-expert results are stacked again, so the two leading
+  axes of the result are (experts, layers).
+
+  Args:
+    hf_source_keys: Nested list of HF parameter names; outer index is the expert,
+      inner index is the layer.
+    tensor_getter_fn: Callable returning the tensor for a given HF key.
+    hook_fns: Hook function(s) applied to each individual weight.
+    target_shape: Shape of the complete MaxText tensor.
+    config: MaxText pyconfig object (not read by this function).
+
+  Returns:
+    The stacked array for the MaxText parameter.
+  """
+  # Hooks operate on a single (expert, layer) slice, i.e. the shape without the two stacked axes.
+  slice_shape = target_shape[2:]
+
+  per_expert = []
+  for expert_keys in hf_source_keys:
+    per_layer = [apply_hook_fns(tensor_getter_fn(name), slice_shape, hook_fns) for name in expert_keys]
+    per_expert.append(np.stack(per_layer, axis=0))
+  return np.stack(per_expert, axis=0)
+
+
+def _build_single_axis_stacked_tensor(
+    hf_source_keys: List[str],
+    tensor_getter_fn: Callable[[str], np.ndarray],
+    hook_fns: Any,
+    target_shape: tuple,
+    config,
+) -> np.ndarray:
+  """Assembles a MaxText parameter by stacking HF weights along one axis.
+
+  Covers standard scanned layers (stacked on `config.param_scan_axis`) and
+  unscanned MoE layers (stacked on axis 0, the expert axis).
+
+  Args:
+    hf_source_keys: Flat list of HF parameter names.
+    tensor_getter_fn: Callable returning the tensor for a given HF key.
+    hook_fns: Hook function(s) applied to each individual weight.
+    target_shape: Shape of the complete MaxText tensor.
+    config: MaxText pyconfig object; `scan_layers` and `param_scan_axis` are read.
+
+  Returns:
+    The stacked array for the MaxText parameter.
+  """
+  stack_axis = config.param_scan_axis if config.scan_layers else 0
+
+  # Hooks operate on one slice, so drop the stacking dimension from the full shape.
+  slice_dims = list(target_shape)
+  del slice_dims[stack_axis]
+  slice_shape = tuple(slice_dims)
+
+  pieces = [apply_hook_fns(tensor_getter_fn(name), slice_shape, hook_fns) for name in hf_source_keys]
+  return np.stack(pieces, axis=stack_axis)
+
+
+def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config):
+  """Picks the deferred loader matching the structure of the HF key specification.
+
+  The specification comes in four forms:
+    Case 1: Unscanned (single string)
+    Case 2: Scanned (list of strings)
+    Case 3: Unscanned with expert stacking (list of strings)
+    Case 4: Scanned with expert stacking (nested list of strings)
+
+  Returns:
+    A zero-argument callable producing the tensor for the MaxText parameter.
+  """
+  if not isinstance(hf_source_keys_or_key, list):
+    # Case 1: a single key -- fetch it and run the hooks, no stacking.
+    def _load_single(getter, key, shape, hook):
+      return apply_hook_fns(getter(key), shape, hook)
+
+    return partial(_load_single, tensor_getter, hf_source_keys_or_key, mt_target_shape_or_shapes, hook_fn)
+
+  # Cases 2/3 use a flat list, case 4 a nested one; the first element tells them apart.
+  if isinstance(hf_source_keys_or_key[0], list):
+    stacker = _build_multi_axis_stacked_tensor
+  else:
+    stacker = _build_single_axis_stacked_tensor
+  return partial(stacker, hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config)
diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py
index 9fc3930a1c..780a60cb7a 100644
--- a/src/maxtext/common/checkpointing.py
+++ b/src/maxtext/common/checkpointing.py
@@ -582,6 +582,7 @@ def
load_state_if_possible(
     checkpoint_conversion_fn=None,
     source_checkpoint_layout="orbax",
     expansion_factor_real_data: int = -1,
+    maxtext_config: Any | None = None,
 ):
   """Loads TrainState as possible from the inputs.
 
@@ -684,7 +685,17 @@ def map_to_pspec(data):
         case _:
           return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)
 
-  if load_parameters_from_path != "":
+  if source_checkpoint_layout == "safetensors_dynamic":
+    path = load_parameters_from_path or load_full_state_from_path
+    if not path:
+      raise ValueError("safetensors_dynamic requires load_parameters_from_path or load_full_state_from_path.")
+    max_logging.log(f"Dynamic On-the-Fly Formatting: Loading SafeTensors from {path}")
+
+    # Function-scope import: presumably keeps the conversion utilities unloaded unless this layout is used.
+    from maxtext.checkpoint_conversion.utils.load_dynamic import load_safetensors_dynamic_state
+
+    return load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_config)
+  elif load_parameters_from_path != "":
     if isinstance(abstract_unboxed_pre_state, nnx.State):
       _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
     else:
@@ -696,6 +707,9 @@ def map_to_pspec(data):
         checkpoint_storage_concurrent_gb,
         use_ocdbt=use_ocdbt,
         use_zarr3=use_zarr3,
+        enable_orbax_v1=enable_orbax_v1,
+        source_checkpoint_layout=source_checkpoint_layout,
+        checkpoint_conversion_fn=checkpoint_conversion_fn,
     )
     return None, restored_params
   elif load_full_state_from_path != "":
@@ -736,35 +750,93 @@ def setup_checkpoint_logger(config) -> Any | None:  # pytype: disable=attribute-
 
 
 def load_params_from_path(
-    load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True
+    load_parameters_from_path,
+    abstract_unboxed_params,
+    checkpoint_storage_concurrent_gb,
+    use_ocdbt=True,
+    use_zarr3=True,
+    enable_orbax_v1=False,
+    source_checkpoint_layout="orbax",
+    checkpoint_conversion_fn=None,
 ):
-  """Load decode params from checkpoint at specified path."""
+  """Load decode params from checkpoint at specified path.
+
+  Args:
+    load_parameters_from_path: Location of the checkpoint to read the params from.
+    abstract_unboxed_params: Abstract params tree describing what to restore.
+    checkpoint_storage_concurrent_gb: Concurrency limit (GB) given to the legacy Orbax handler.
+    use_ocdbt: Legacy-handler option, forwarded to `PyTreeCheckpointHandler`.
+    use_zarr3: Legacy-handler option, forwarded to `PyTreeCheckpointHandler`.
+    enable_orbax_v1: If True, load through the Orbax v1 API and honour `source_checkpoint_layout`.
+    source_checkpoint_layout: "orbax" or "safetensors"; only consulted when `enable_orbax_v1` is True.
+    checkpoint_conversion_fn: Optional callable applied to the loaded safetensors state.
+
+  Returns:
+    The restored params. For the safetensors layout this is the "params" entry of the
+    (optionally converted) state when present, otherwise the whole state.
+
+  Raises:
+    ocp_v1.errors.InvalidLayoutError: If `enable_orbax_v1` is True and the layout is unknown.
+  """
   assert load_parameters_from_path, "load_parameters_from_path is not defined."
   max_logging.log(f"restoring params from {load_parameters_from_path}")
 
+  if enable_orbax_v1:
+    if source_checkpoint_layout == "orbax":
+      context = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.ORBAX)
+      with context:
+        restored = ocp_v1.load_pytree(load_parameters_from_path, {"params": abstract_unboxed_params})
+        return restored["params"]
+    if source_checkpoint_layout == "safetensors":
+      context = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS)
+      with context:
+        metadata = ocp_v1.pytree_metadata(load_parameters_from_path)
+        simple_abstract_state = metadata.metadata
+        shardings = sharding_utils.construct_maximal_shardings(simple_abstract_state)
+
+        # The second parameter is deliberately not named `shardings`, so it cannot shadow the tree above.
+        def combine_sharding(sds, single_sharding):
+          return jax.ShapeDtypeStruct(shape=sds.shape, dtype=sds.dtype, sharding=single_sharding)
+
+        sharded_abstract_state = jax.tree.map(combine_sharding, simple_abstract_state, shardings)
+        pre_transformed_state = ocp_v1.load_pytree(load_parameters_from_path, sharded_abstract_state)
+        if checkpoint_conversion_fn:
+          pre_transformed_state = checkpoint_conversion_fn(pre_transformed_state)
+        if "params" in pre_transformed_state:
+          return pre_transformed_state["params"]
+        return pre_transformed_state
+    raise ocp_v1.errors.InvalidLayoutError(f"Unknown checkpoint layout: {source_checkpoint_layout}")
+
+  if source_checkpoint_layout != "orbax" or checkpoint_conversion_fn:
+    # The legacy handler below takes neither a layout nor a conversion function;
+    # log that they are ignored rather than dropping them silently.
+    max_logging.log(
+        f"enable_orbax_v1 is False; ignoring source_checkpoint_layout={source_checkpoint_layout} "
+        "and checkpoint_conversion_fn when restoring params."
+    )
+
   # *_concurrent_gb should be set for large models, the default is 96.
   max_logging.log(f"Creating checkpoint manager with ocdbt={use_ocdbt} and zarr3={use_zarr3}")
   ckptr = ocp.Checkpointer(
       ocp.PyTreeCheckpointHandler(
           restore_concurrent_gb=checkpoint_storage_concurrent_gb,
           save_concurrent_gb=checkpoint_storage_concurrent_gb,
           use_ocdbt=use_ocdbt,
           use_zarr3=use_zarr3,
       )
   )
 
   # This is a memory optimization. We don't want to restore the entire checkpoint - only the params.
   # Rather than pass the entire abstract state, which could unnecessarily restore opt_state and such and waste
   # memory, we instead specify here that we are just restoring the params field of the checkpoint
   # (which itself may be a dictionary containing a key named 'params').
   restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_unboxed_params)
   restored = ckptr.restore(
       epath.Path(load_parameters_from_path),
       item={"params": abstract_unboxed_params},
       transforms={},
       restore_args={"params": restore_args},
   )
   return restored["params"]
 
 
 def save_params_to_path(checkpoint_dir, params, use_ocdbt=True, use_zarr3=True):
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 9e1e4b58cd..939a7ad94f 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -343,7 +343,7 @@ class Checkpointing(BaseModel):
   save_quantized_params_path: PathStr = Field("", description="Path to save params quantized on the fly.")
   enable_orbax_v1: bool = Field(False, description="Bool flag for enabling Orbax v1.")
   checkpoint_conversion_fn: None | str = Field(None, description="Function for processing loaded checkpoint dict.")
-  source_checkpoint_layout: Literal["orbax", "safetensors"] = Field(
+  source_checkpoint_layout: Literal["orbax", "safetensors", "safetensors_dynamic"] = Field(
       "orbax", description="The layout of the source checkpoint to load."
) save_checkpoint_on_completion: bool = Field( diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 6d62b981dd..0c726fc42e 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1492,6 +1492,7 @@ def setup_initial_state( checkpoint_conversion_fn=config.checkpoint_conversion_fn, source_checkpoint_layout=config.source_checkpoint_layout, expansion_factor_real_data=config.expansion_factor_real_data, + maxtext_config=config, ) if restored: diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index ca90550630..8c5b8e3cb1 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -338,6 +338,7 @@ def create_train_state_fn(): enable_orbax_v1=config.enable_orbax_v1, checkpoint_conversion_fn=config.checkpoint_conversion_fn, source_checkpoint_layout=config.source_checkpoint_layout, + maxtext_config=config, ) except FileNotFoundError: step0_restored = None