From c9f493eb410d3cf616ae314ee0762a4958888221 Mon Sep 17 00:00:00 2001 From: A9isha Date: Tue, 10 Mar 2026 23:16:40 +0000 Subject: [PATCH] Make tokenizer_path to be non-mandatory --- .../convert_checkpoint.md | 2 +- docs/tutorials/posttraining/rl.md | 15 ++-- .../posttraining/rl_on_multi_host.md | 6 +- docs/tutorials/posttraining/sft.md | 12 +-- .../posttraining/sft_on_multi_host.md | 5 +- .../checkpoint_conversion/compare_hf_ckpt.py | 3 +- .../checkpoint_conversion/to_huggingface.py | 5 +- .../checkpoint_conversion/to_maxtext.py | 77 +++++++++++++++---- .../checkpoint_conversion/utils/utils.py | 35 --------- src/maxtext/configs/base.yml | 2 +- src/maxtext/configs/pyconfig.py | 42 ++++++++-- src/maxtext/configs/types.py | 6 +- src/maxtext/utils/globals.py | 35 +++++++++ 13 files changed, 161 insertions(+), 84 deletions(-) diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md index 6c606fb813..224f5c9a2f 100644 --- a/docs/guides/checkpointing_solutions/convert_checkpoint.md +++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md @@ -221,7 +221,7 @@ To extend conversion support to a new model architecture, you must define its sp - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer. 2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion. -3. 
**Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py), add the new model key in `HF_IDS`. +3. **Register model key**: In [`utils/globals.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`. 4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in [`src/maxtext/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture. Here is an example [PR to add support for gemma3 multi-modal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983) diff --git a/docs/tutorials/posttraining/rl.md b/docs/tutorials/posttraining/rl.md index 7f2c366c26..d1284e2eb5 100644 --- a/docs/tutorials/posttraining/rl.md +++ b/docs/tutorials/posttraining/rl.md @@ -86,14 +86,17 @@ install_maxtext_tpu_post_train_extra_deps ## Setup environment variables +Follow the instructions [here](https://huggingface.co/docs/huggingface_hub/v0.21.2/guides/cli) to log in to Hugging Face with your access token by running: + +```bash +huggingface-cli login +``` + Setup following environment variables before running GRPO/GSPO: ```bash # -- Model configuration -- -export HF_MODEL= # e.g. 'llama3.1-8b-Instruct' -export MODEL= # e.g. 'llama3.1-8b' -export TOKENIZER= # e.g. 'meta-llama/Llama-3.1-8B-Instruct' -export HF_TOKEN= +export MODEL= # e.g. 
'llama3.1-8b-Instruct' # -- MaxText configuration -- export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-output-directory @@ -135,11 +138,9 @@ Run the following command for GRPO: ``` python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \ model_name=${MODEL?} \ - tokenizer_path=${TOKENIZER?} \ load_parameters_path=${MAXTEXT_CKPT_PATH?} \ run_name=${RUN_NAME?} \ base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ - hf_access_token=${HF_TOKEN?} \ chips_per_vm=${CHIPS_PER_VM?} ``` @@ -159,11 +160,9 @@ Run the following command for GSPO: ``` python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \ model_name=${MODEL?} \ - tokenizer_path=${TOKENIZER?} \ load_parameters_path=${MAXTEXT_CKPT_PATH?} \ run_name=${RUN_NAME?} \ base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ - hf_access_token=${HF_TOKEN?} \ loss_algo=gspo-token \ chips_per_vm=${CHIPS_PER_VM?} ``` diff --git a/docs/tutorials/posttraining/rl_on_multi_host.md b/docs/tutorials/posttraining/rl_on_multi_host.md index e22407b861..29009e1ef6 100644 --- a/docs/tutorials/posttraining/rl_on_multi_host.md +++ b/docs/tutorials/posttraining/rl_on_multi_host.md @@ -69,9 +69,7 @@ actual values. ```bash # -- Model configuration -- -export HF_MODEL= # e.g. 'llama3.1-70b-Instruct' -export MODEL= # e.g. 'llama3.1-70b' -export TOKENIZER= # e.g. 'meta-llama/Llama-3.1-70B-Instruct' +export MODEL= # e.g. 
'llama3.1-70b-Instruct' export HF_TOKEN= # -- MaxText configuration -- @@ -200,7 +198,6 @@ xpk workload create-pathways --workload ${WORKLOAD?} \ --command "HF_TOKEN=${HF_TOKEN?} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \ python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \ model_name=${MODEL?} \ - tokenizer_path=${TOKENIZER?} \ load_parameters_path=${MAXTEXT_CKPT_PATH?} \ run_name=${WORKLOAD?} \ base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ @@ -218,7 +215,6 @@ xpk workload create-pathways --workload ${WORKLOAD?} \ --command "HF_TOKEN=${HF_TOKEN?} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \ python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \ model_name=${MODEL?} \ - tokenizer_path=${TOKENIZER?} \ load_parameters_path=${MAXTEXT_CKPT_PATH?} \ run_name=${WORKLOAD?} \ base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ diff --git a/docs/tutorials/posttraining/sft.md b/docs/tutorials/posttraining/sft.md index cb3ff85baf..bd7cdd00ac 100644 --- a/docs/tutorials/posttraining/sft.md +++ b/docs/tutorials/posttraining/sft.md @@ -43,13 +43,17 @@ install_maxtext_tpu_post_train_extra_deps ## Setup environment variables +Follow the instructions [here](https://huggingface.co/docs/huggingface_hub/v0.21.2/guides/cli) to log in to Hugging Face with your access token by running: + +```bash +huggingface-cli login +``` + Set the following environment variables before running SFT. 
```sh # -- Model configuration -- -export PRE_TRAINED_MODEL= # e.g., 'llama3.1-8b' -export PRE_TRAINED_MODEL_TOKENIZER= # e.g., 'meta-llama/Llama-3.1-8B-Instruct' -export HF_TOKEN= +export PRE_TRAINED_MODEL= # e.g., 'llama3.1-8b-Instruct' # -- MaxText configuration -- export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-output-directory @@ -93,8 +97,6 @@ python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_tr base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ model_name=${PRE_TRAINED_MODEL?} \ load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH?} \ - hf_access_token=${HF_TOKEN?} \ - tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER?} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} \ steps=${STEPS?} \ hf_path=${DATASET_NAME?} \ diff --git a/docs/tutorials/posttraining/sft_on_multi_host.md b/docs/tutorials/posttraining/sft_on_multi_host.md index 55273d5475..aa2bb432a0 100644 --- a/docs/tutorials/posttraining/sft_on_multi_host.md +++ b/docs/tutorials/posttraining/sft_on_multi_host.md @@ -95,7 +95,6 @@ export HF_TOKEN= # -- Model Configuration -- export MODEL_NAME= # e.g., deepseek3-671b -export TOKENIZER_PATH= # e.g., deepseek-ai/DeepSeek-V3 # -- Dataset configuration -- export DATASET_NAME= # e.g., HuggingFaceH4/ultrachat_200k @@ -143,7 +142,7 @@ xpk workload create \ --workload=${WORKLOAD_NAME?} \ --tpu-type=${TPU_TYPE?} \ --num-slices=${TPU_SLICE?} \ ---command "python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}" +--command "python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} 
base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}" ``` Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`. @@ -159,7 +158,7 @@ xpk workload create-pathways \ --workload=${WORKLOAD_NAME?} \ --tpu-type=${TPU_TYPE?} \ --num-slices=${TPU_SLICE?} \ ---command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True" +--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True" ``` Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`. 
diff --git a/src/maxtext/checkpoint_conversion/compare_hf_ckpt.py b/src/maxtext/checkpoint_conversion/compare_hf_ckpt.py index 71e8f430e6..44b011343a 100644 --- a/src/maxtext/checkpoint_conversion/compare_hf_ckpt.py +++ b/src/maxtext/checkpoint_conversion/compare_hf_ckpt.py @@ -48,8 +48,9 @@ from safetensors import safe_open from maxtext.configs import pyconfig -from maxtext.checkpoint_conversion.utils.utils import HF_IDS, print_ram_usage, get_hf_model +from maxtext.checkpoint_conversion.utils.utils import print_ram_usage, get_hf_model from maxtext.utils import max_logging +from maxtext.utils.globals import HF_IDS jax.config.update("jax_platform_name", "cpu") diff --git a/src/maxtext/checkpoint_conversion/to_huggingface.py b/src/maxtext/checkpoint_conversion/to_huggingface.py index 6c6c18023e..4aa5429deb 100644 --- a/src/maxtext/checkpoint_conversion/to_huggingface.py +++ b/src/maxtext/checkpoint_conversion/to_huggingface.py @@ -19,7 +19,7 @@ Key Parameters (to be set in the config file or as command-line overrides): model_name: (Required) The name of the model to convert (e.g., "gemma2-2b"). - Must be a key in `maxtext.checkpoint_conversion.utils.utils.HF_IDS`. + Must be a key in `maxtext.utils.globals.HF_IDS`. load_parameters_path: (Required) Path to the MaxText checkpoint directory containing the parameter-only checkpoint. 
base_output_directory: (Optional) The directory where the converted HuggingFace @@ -79,12 +79,13 @@ save_model_files, load_orbax_checkpoint, detect_and_extract_checkpoint, - HF_IDS, MemoryMonitorTqdm, print_peak_memory, ) from maxtext.utils import max_logging from maxtext.utils import max_utils +from maxtext.utils.globals import HF_IDS + flags.DEFINE_bool( "override_model_architecture", diff --git a/src/maxtext/checkpoint_conversion/to_maxtext.py b/src/maxtext/checkpoint_conversion/to_maxtext.py index 06f95ed85a..26dbeb214b 100644 --- a/src/maxtext/checkpoint_conversion/to_maxtext.py +++ b/src/maxtext/checkpoint_conversion/to_maxtext.py @@ -18,7 +18,7 @@ Key Parameters (to be set in the config file or as command-line overrides): model_name: (Required) The name of the model to convert (e.g., "gemma2-2b"). - Must be a key in `maxtext.checkpoint_conversion.utils.utils.HF_IDS`. + Must be a key in `maxtext.utils.globals.HF_IDS`. base_output_directory: (Optional) The directory where the converted HuggingFace checkpoint will be saved. Can be a local path, a GCS path (gs://...), or a HuggingFace Hub repo ID (hf://...). @@ -30,7 +30,7 @@ Defaults to False. --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights. If unspecified, we use the default Hugging Face repository ID - (e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `utils/ckpt_conversion/utils`). + (e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `maxtext.utils.globals`). This is necessary for locally dequantized models like GPT-OSS or DeepSeek. 
Environment Variables: @@ -74,11 +74,12 @@ from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.checkpoint_conversion.standalone_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING -from maxtext.checkpoint_conversion.utils.utils import HF_IDS, MemoryMonitorTqdm, apply_hook_fns, get_hf_model, print_peak_memory, print_ram_usage, validate_and_filter_param_map_keys +from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, get_hf_model, print_peak_memory, print_ram_usage, validate_and_filter_param_map_keys from maxtext.inference.inference_utils import str2bool from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import max_logging, max_utils, maxtext_utils +from maxtext.utils.globals import HF_IDS import numpy as np from orbax.checkpoint import type_handlers from safetensors import safe_open @@ -155,7 +156,12 @@ def _initialize_index(self): if self.is_local: index_path = os.path.join(self.model_id, index_file) else: - index_path = hf_hub_download(repo_id=self.model_id, filename=index_file, token=self.token, revision=self.revision) + index_path = hf_hub_download( + repo_id=self.model_id, + filename=index_file, + token=self.token, + revision=self.revision, + ) with open(index_path, "r", encoding="utf-8") as f: index_data = json.load(f) self.shard_map = index_data["weight_map"] @@ -185,7 +191,12 @@ def get_tensor(self, key: str) -> np.ndarray: else: # STEP 1: Download outside the lock. # multiple threads can download different shards at the same time. - local_path = hf_hub_download(repo_id=self.model_id, filename=shard_name, token=self.token, revision=self.revision) + local_path = hf_hub_download( + repo_id=self.model_id, + filename=shard_name, + token=self.token, + revision=self.revision, + ) # STEP 2: Lock ONLY the reading into RAM. 
# This prevents multiple threads from simultaneously allocating large chunks of RAM. @@ -200,7 +211,13 @@ class LazyTensor: and transformation until __array__ is called (e.g., by Orbax during save). """ - def __init__(self, load_fn: Callable[[], np.ndarray], shape: tuple, dtype, name: str = "unknown"): + def __init__( + self, + load_fn: Callable[[], np.ndarray], + shape: tuple, + dtype, + name: str = "unknown", + ): self._load_fn = load_fn self.shape = shape self.dtype = np.dtype(dtype) @@ -421,7 +438,13 @@ def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_t def _loader(getter, key, shape, hook): return apply_hook_fns(getter(key), shape, hook) - load_fn = partial(_loader, tensor_getter, hf_source_keys_or_key, mt_target_shape_or_shapes, hook_fn) + load_fn = partial( + _loader, + tensor_getter, + hf_source_keys_or_key, + mt_target_shape_or_shapes, + hook_fn, + ) # Stacked mapping elif not isinstance(hf_source_keys_or_key[0], list): # Case 2 or 3: Single-Axis Stacked hf keys (un-nested list) @@ -516,7 +539,12 @@ def _get_maxtext_weight( # to load the tensor later (the `load_fn`, shape, dtype). # The actual data will only be loaded when Orbax calls `__array__` # on this object during the saving process. - final_mt_tensor_numpy = LazyTensor(load_fn, mt_target_shape_or_shapes, config.weight_dtype, name=mt_param_key_or_keys) + final_mt_tensor_numpy = LazyTensor( + load_fn, + mt_target_shape_or_shapes, + config.weight_dtype, + name=mt_param_key_or_keys, + ) if not is_composite_mt_key: # Case 2.1: Lazy mode, `atomic_mt_key` final_mt_weights[mt_target_idx_or_indices] = final_mt_tensor_numpy @@ -562,7 +590,10 @@ def main( # check the supported model ids if model_name_original not in HF_IDS: - raise ValueError(f"Unsupported model name: {model_name_original}. 
Supported models are: {list(HF_IDS.keys())}") + raise ValueError( + f"Unsupported model name: {model_name_original}.\ + Supported models are: {list(HF_IDS.keys())}" + ) model_id = hf_model_path or HF_IDS[model_name_original] @@ -633,7 +664,11 @@ def _eager_getter(key): filtered_map_keys = validate_and_filter_param_map_keys(param_map_mt_to_hf.keys(), maxtext_abstract_dict.keys()) for mt_param_key_or_keys in MemoryMonitorTqdm( - filtered_map_keys, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True + filtered_map_keys, + desc="Transforming weights", + unit="param", + leave=True, + dynamic_ncols=True, ): if not lazy_load_tensors: max_logging.log(f"maxtext param: {mt_param_key_or_keys}") @@ -651,7 +686,13 @@ def _eager_getter(key): # Step 2: Determine the loading function for hf key # based on hf_key form (unscanned, scanned, unscanned with expert stacking, or scanned with expert stacking) - load_fn = _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config) + load_fn = _get_hf_loading_function( + hf_source_keys_or_key, + tensor_getter, + hook_fn, + mt_target_shape_or_shapes, + config, + ) # Step 3: Load hf keys and convert to maxtext keys # based on tensor load mode (lazy, eager) and MaxText key form (`atomic_mt_key` or `composite_mt_key`) @@ -710,9 +751,13 @@ def _eager_getter(key): default=False, help="Whether to use lazy loading of HF tensors.", ) - # If not specified, default to maxtext.checkpoint_conversion.utils.utils.HF_IDS[model_name] + # If not specified, default to maxtext.utils.globals.HF_IDS[model_name] parser.add_argument( - "--hf_model_path", type=str, required=False, default="", help="local path to hf model, or custom remote hf repo" + "--hf_model_path", + type=str, + required=False, + default="", + help="local path to hf model, or custom remote hf repo", ) # Determines the logical sharding of the output checkpoint by partitioning # weights across virtual XLA devices. 
@@ -730,7 +775,11 @@ def _eager_getter(key): parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16) parser.add_argument( - "--revision", type=str, required=False, default=None, help="Specific Hugging Face revision (branch/tag/commit)" + "--revision", + type=str, + required=False, + default=None, + help="Specific Hugging Face revision (branch/tag/commit)", ) # Parse local arguments diff --git a/src/maxtext/checkpoint_conversion/utils/utils.py b/src/maxtext/checkpoint_conversion/utils/utils.py index 31607883f8..bb40812ea3 100644 --- a/src/maxtext/checkpoint_conversion/utils/utils.py +++ b/src/maxtext/checkpoint_conversion/utils/utils.py @@ -56,41 +56,6 @@ DEFAULT_MAX_SHARD_SIZE = 1024 * 1024 * 1024 * 3 # 3GB default -# Mapping from MaxText model key to Hugging Face tokenizer identifiers -HF_IDS = { - "gemma2-2b": "google/gemma-2-2b", - "gemma2-9b": "google/gemma-2-9b", - "gemma2-27b": "google/gemma-2-27b", - "gemma3-4b": "google/gemma-3-4b-it", # hf multi-modal should also support the pure-text - "gemma3-12b": "google/gemma-3-12b-it", - "gemma3-27b": "google/gemma-3-27b-it", - "qwen3-0.6b": "Qwen/Qwen3-0.6B", - "qwen3-4b": "Qwen/Qwen3-4B", - "qwen3-4b-thinking-2507": "Qwen/Qwen3-4B-Thinking-2507", - "qwen3-8b": "Qwen/Qwen3-8B", - "qwen3-14b": "Qwen/Qwen3-14B", - "qwen3-32b": "Qwen/Qwen3-32B", - "llama3.1-8b": "meta-llama/Llama-3.1-8B", - "llama3.1-8b-Instruct": "meta-llama/Llama-3.1-8B-Instruct", - "llama3.1-70b-Instruct": "meta-llama/Llama-3.1-70B-Instruct", - "llama3.1-70b": "meta-llama/Llama-3.1-70B", - "llama3.1-405b": "meta-llama/Llama-3.1-405B", - "qwen3-30b-a3b": "Qwen/Qwen3-30B-A3B-Thinking-2507", - "qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B-Thinking-2507", - "qwen3-480b-a35b": "Qwen/Qwen3-Coder-480B-A35B-Instruct", - "deepseek3-671b": "deepseek-ai/DeepSeek-V3", - "gpt-oss-20b": "openai/gpt-oss-20b", - "gpt-oss-120b": "openai/gpt-oss-120b", - "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct", - "qwen3-next-80b-a3b": 
"Qwen/Qwen3-Next-80B-A3B-Instruct", - "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1", - "mixtral-8x22b": "mistralai/Mixtral-8x22B-Instruct-v0.1", - "olmo3-7b": "allenai/Olmo-3-7B-Instruct", - "olmo3-7b-pt": "allenai/Olmo-3-1025-7B", - "olmo3-32b": "allenai/Olmo-3-32B-Think", -} - - def _get_local_directory(output_dir: str) -> str: """Determines the local directory for saving files.""" if output_dir.startswith("gs://") or output_dir.startswith("hf://"): diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 6772678dfc..e524cce9d5 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -573,7 +573,7 @@ num_vocab_tiling: 1 # Tokenizer vocab_size: 32_000 # powers of 2 for sharding -tokenizer_path: "src/maxtext/assets/tokenizers/tokenizer.llama2" +tokenizer_path: "" # tfds pipeline supports tokenizer_type: sentencepiece, huggingface, tiktoken # grain pipeline supports tokenizer_type: sentencepiece, huggingface # hf pipeline only supports huggingface type, and will ignore tokenizer_type flag diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py index 663b76633a..54d9da447b 100644 --- a/src/maxtext/configs/pyconfig.py +++ b/src/maxtext/configs/pyconfig.py @@ -29,12 +29,13 @@ import omegaconf from maxtext.configs import pyconfig_deprecated -from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR +from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT, HF_IDS from maxtext.common.common_types import DecoderBlockType, ShardMode from maxtext.configs import types from maxtext.configs.types import MaxTextConfig from maxtext.inference.inference_utils import str2bool from maxtext.utils import max_utils +from maxtext.utils import max_logging logger = logging.getLogger(__name__) logger.setLevel(os.environ.get("LOGLEVEL", "INFO")) @@ -62,7 +63,7 @@ def resolve_config_path(param: str) -> str: # For pip-installed packages, strip the src prefix and resolve against # the installed configs 
directory (MAXTEXT_CONFIGS_DIR). if param.startswith("src/maxtext/configs/"): - candidate = os.path.join(MAXTEXT_CONFIGS_DIR, param[len("src/maxtext/configs/"):]) + candidate = os.path.join(MAXTEXT_CONFIGS_DIR, param[len("src/maxtext/configs/") :]) if os.path.isfile(candidate): return candidate return os.path.join("src", param) @@ -126,7 +127,7 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]: for key, value in raw_keys.items(): if key not in valid_fields: logger.warning("Ignoring invalid/unsupported field from YAML/CLI: %s", repr(key)) - raise ValueError(f"{key!r} not in {", ".join(map(repr, valid_fields))}.") + raise ValueError(f"{key!r} not in {', '.join(map(repr, valid_fields))}.") new_value = value if isinstance(new_value, str) and new_value.lower() == "none": @@ -142,7 +143,16 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]: # An empty value provided in the configuration is treated as None if ( - key in ("hf_train_files", "hf_eval_files", "hf_access_token", "hf_name", "hf_data_dir", "hf_eval_split") + key + in ( + "hf_train_files", + "hf_eval_files", + "hf_access_token", + "hf_name", + "hf_data_dir", + "hf_eval_split", + "tokenizer_path", + ) and new_value == "" ): new_value = None @@ -150,6 +160,17 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]: if key == "run_name" and new_value is None: new_value = "" + if key == "tokenizer_path" and new_value is None: + try: + new_value = HF_IDS[raw_keys["model_name"]] + except KeyError: + new_value = os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers/tokenizer.llama2") + max_logging.warning( + "tokenizer_path not found in HF_IDS in maxtext/src/maxtext/utils/globals.py. \ + Using the default src/maxtext/assets/tokenizers/tokenizer.llama2 instead. \ + Please pass tokenizer_path in your command if this is not intended." 
+ ) + # Preprocess muon_consistent_rms to be None or float if key == "muon_consistent_rms": if value in ["None", "none"]: @@ -238,16 +259,23 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig: # 3. Handle model-specific config temp_cfg = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg) model_name = temp_cfg.get("model_name", "default") + # The architecture for -Instruct v/s base models are the same, so for identifying the + # architecture we replace "-Instruct" from the model_name and get the base model name + model_name = model_name.replace("-Instruct", "") if "-Instruct" in model_name else model_name model_cfg = {} if model_name != "default": # First try relative to base config path model_config_path = os.path.join(os.path.dirname(config_path), "models", f"{model_name}.yml") # Try looking for "models" under "src/maxtext/configs/" if not os.path.isfile(model_config_path): - model_config_path = os.path.join(os.path.dirname(os.path.dirname(config_path)), "models", f"{model_name}.yml") + model_config_path = os.path.join( + os.path.dirname(os.path.dirname(config_path)), + "models", + f"{model_name}.yml", + ) if not os.path.isfile(model_config_path): - # Fallback to default location within package + # Fallback to the default location within package dir_path = os.path.dirname(os.path.realpath(__file__)) model_config_path = os.path.join(dir_path, "configs", "models", f"{model_name}.yml") @@ -261,7 +289,7 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig: else: logger.warning("Model config for '%s' not found at %s", model_name, model_config_path) - # 4. Final merge (base, model, then overrides) + # Finally merge (base, model, then overrides) model_cfg_oc = omegaconf.OmegaConf.create(model_cfg) # 4. Manually merge logical_axis_rules to avoid OmegaConf's list replacement behavior. 
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index c9e9475883..335fd9a7cd 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -205,7 +205,9 @@ class ProfilerType(str, Enum): "llama2-13b", "llama2-70b", "llama3-8b", + "llama3.1-8b-Instruct", "llama3-70b", + "llama3.1-70b-Instruct", "llama3.1-8b", "llama3.1-70b", "llama3.1-405b", @@ -926,8 +928,8 @@ class Tokenizer(BaseModel): """Configuration for the tokenizer.""" vocab_size: int = Field(32_000, description="The size of the vocabulary.") - tokenizer_path: PathStr = Field( - os.path.join("assets", "tokenizers", "tokenizer.llama2"), + tokenizer_path: None | PathStr = Field( + None, description="Path to the tokenizer model file.", ) tokenizer_type: TokenizerType = Field(TokenizerType.SENTENCEPIECE, description="The type of tokenizer.") diff --git a/src/maxtext/utils/globals.py b/src/maxtext/utils/globals.py index 40985c5060..9ddae77f97 100644 --- a/src/maxtext/utils/globals.py +++ b/src/maxtext/utils/globals.py @@ -42,6 +42,40 @@ EPS = 1e-8 # Epsilon to calculate loss DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE = 2 * 1024**3 # Default checkpoint file size +# Mapping from MaxText model key to Hugging Face tokenizer identifiers +HF_IDS = { + "gemma2-2b": "google/gemma-2-2b", + "gemma2-9b": "google/gemma-2-9b", + "gemma2-27b": "google/gemma-2-27b", + "gemma3-4b": "google/gemma-3-4b-it", # hf multi-modal should also support the pure-text + "gemma3-12b": "google/gemma-3-12b-it", + "gemma3-27b": "google/gemma-3-27b-it", + "qwen3-0.6b": "Qwen/Qwen3-0.6B", + "qwen3-4b": "Qwen/Qwen3-4B", + "qwen3-4b-thinking-2507": "Qwen/Qwen3-4B-Thinking-2507", + "qwen3-8b": "Qwen/Qwen3-8B", + "qwen3-14b": "Qwen/Qwen3-14B", + "qwen3-32b": "Qwen/Qwen3-32B", + "llama3.1-8b": "meta-llama/Llama-3.1-8B", + "llama3.1-8b-Instruct": "meta-llama/Llama-3.1-8B-Instruct", + "llama3.1-70b-Instruct": "meta-llama/Llama-3.1-70B-Instruct", + "llama3.1-70b": "meta-llama/Llama-3.1-70B", + "llama3.1-405b": 
"meta-llama/Llama-3.1-405B", + "qwen3-30b-a3b": "Qwen/Qwen3-30B-A3B-Thinking-2507", + "qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B-Thinking-2507", + "qwen3-480b-a35b": "Qwen/Qwen3-Coder-480B-A35B-Instruct", + "deepseek3-671b": "deepseek-ai/DeepSeek-V3", + "gpt-oss-20b": "openai/gpt-oss-20b", + "gpt-oss-120b": "openai/gpt-oss-120b", + "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "qwen3-next-80b-a3b": "Qwen/Qwen3-Next-80B-A3B-Instruct", + "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "mixtral-8x22b": "mistralai/Mixtral-8x22B-Instruct-v0.1", + "olmo3-7b": "allenai/Olmo-3-7B-Instruct", + "olmo3-7b-pt": "allenai/Olmo-3-1025-7B", + "olmo3-32b": "allenai/Olmo-3-32B-Think", +} + __all__ = [ "DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE", "EPS", @@ -50,4 +84,5 @@ "MAXTEXT_PKG_DIR", "MAXTEXT_REPO_ROOT", "MAXTEXT_TEST_ASSETS_ROOT", + "HF_IDS", ]