Merged
Changes from all commits
@@ -221,7 +221,7 @@ To extend conversion support to a new model architecture, you must define its sp
- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.

2. **Add Hugging Face weight shapes**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_shape.py), define the tensor shapes of the Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to verify that tensor shapes match after the to_huggingface conversion.
3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py), add the new model key in `HF_IDS`.
3. **Register model key**: In [`utils/globals.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`.
4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in [`src/maxtext/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture.

Here is an example [PR adding support for the gemma3 multi-modal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983).
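As a hedged illustration of steps 1 and 3 above, a new model's registration might look like the following sketch. The model name `mymodel-1b`, the mapping keys, and the repo id are all hypothetical; the real naming conventions and signatures live in `utils/param_mapping.py` and `maxtext/utils/globals.py`.

```python
# Hypothetical registration sketch for a model "mymodel-1b".
# The {MODEL}_* naming follows the convention described above; the exact
# signatures and key formats are defined in utils/param_mapping.py.

def MYMODEL_MAXTEXT_TO_HF_PARAM_MAPPING(config=None):
    """Step 1: map each MaxText parameter key to its Hugging Face name."""
    return {
        "params-decoder-layers-self_attention-query-kernel":
            "model.layers.*.self_attn.q_proj.weight",
    }

def MYMODEL_MAXTEXT_TO_HF_PARAM_HOOK_FN(key, target_shape):
    """Step 1: per-layer transformation (real hooks reshape/transpose)."""
    return lambda tensor: tensor  # identity here, purely illustrative

# Step 3: register the MaxText model key -> Hugging Face repo id
# (the table lives in maxtext/utils/globals.py after this PR).
HF_IDS = {
    "mymodel-1b": "my-org/mymodel-1b",  # hypothetical entry
}
```
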
15 changes: 7 additions & 8 deletions docs/tutorials/posttraining/rl.md
@@ -86,14 +86,17 @@ install_maxtext_tpu_post_train_extra_deps

## Setup environment variables

Follow the instructions [here](https://huggingface.co/docs/huggingface_hub/v0.21.2/guides/cli) to log in to Hugging Face with your access token:

```bash
huggingface-cli login
```

Set the following environment variables before running GRPO/GSPO:

```bash
# -- Model configuration --
export HF_MODEL=<Hugging Face Model> # e.g. 'llama3.1-8b-Instruct'
export MODEL=<MaxText Model> # e.g. 'llama3.1-8b'
export TOKENIZER=<Tokenizer> # e.g. 'meta-llama/Llama-3.1-8B-Instruct'
export HF_TOKEN=<Hugging Face access token>
export MODEL=<MaxText Model> # e.g. 'llama3.1-8b-Instruct'

# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
@@ -135,11 +138,9 @@ Run the following command for GRPO:
```
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
model_name=${MODEL?} \
tokenizer_path=${TOKENIZER?} \
load_parameters_path=${MAXTEXT_CKPT_PATH?} \
run_name=${RUN_NAME?} \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
hf_access_token=${HF_TOKEN?} \
chips_per_vm=${CHIPS_PER_VM?}
```

@@ -159,11 +160,9 @@ Run the following command for GSPO:
```
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
model_name=${MODEL?} \
tokenizer_path=${TOKENIZER?} \
load_parameters_path=${MAXTEXT_CKPT_PATH?} \
run_name=${RUN_NAME?} \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
hf_access_token=${HF_TOKEN?} \
loss_algo=gspo-token \
chips_per_vm=${CHIPS_PER_VM?}
```
6 changes: 1 addition & 5 deletions docs/tutorials/posttraining/rl_on_multi_host.md
@@ -69,9 +69,7 @@ actual values.

```bash
# -- Model configuration --
export HF_MODEL=<Hugging Face Model> # e.g. 'llama3.1-70b-Instruct'
export MODEL=<MaxText Model> # e.g. 'llama3.1-70b'
export TOKENIZER=<Tokenizer> # e.g. 'meta-llama/Llama-3.1-70B-Instruct'
export MODEL=<MaxText Model> # e.g. 'llama3.1-70b-Instruct'
export HF_TOKEN=<Hugging Face access token>

# -- MaxText configuration --
@@ -200,7 +198,6 @@ xpk workload create-pathways --workload ${WORKLOAD?} \
--command "HF_TOKEN=${HF_TOKEN?} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
model_name=${MODEL?} \
tokenizer_path=${TOKENIZER?} \
load_parameters_path=${MAXTEXT_CKPT_PATH?} \
run_name=${WORKLOAD?} \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
@@ -218,7 +215,6 @@ xpk workload create-pathways --workload ${WORKLOAD?} \
--command "HF_TOKEN=${HF_TOKEN?} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
model_name=${MODEL?} \
tokenizer_path=${TOKENIZER?} \
load_parameters_path=${MAXTEXT_CKPT_PATH?} \
run_name=${WORKLOAD?} \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
12 changes: 7 additions & 5 deletions docs/tutorials/posttraining/sft.md
@@ -43,13 +43,17 @@ install_maxtext_tpu_post_train_extra_deps

## Setup environment variables

Follow the instructions [here](https://huggingface.co/docs/huggingface_hub/v0.21.2/guides/cli) to log in to Hugging Face with your access token:

```bash
huggingface-cli login
```

Set the following environment variables before running SFT.

```sh
# -- Model configuration --
export PRE_TRAINED_MODEL=<model name> # e.g., 'llama3.1-8b'
export PRE_TRAINED_MODEL_TOKENIZER=<tokenizer path> # e.g., 'meta-llama/Llama-3.1-8B-Instruct'
export HF_TOKEN=<Hugging Face access token>
export PRE_TRAINED_MODEL=<model name> # e.g., 'llama3.1-8b-Instruct'

# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
@@ -93,8 +97,6 @@ python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_tr
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
model_name=${PRE_TRAINED_MODEL?} \
load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH?} \
hf_access_token=${HF_TOKEN?} \
tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER?} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} \
steps=${STEPS?} \
hf_path=${DATASET_NAME?} \
5 changes: 2 additions & 3 deletions docs/tutorials/posttraining/sft_on_multi_host.md
@@ -95,7 +95,6 @@ export HF_TOKEN=<Hugging Face Access Token>

# -- Model Configuration --
export MODEL_NAME=<Model Name> # e.g., deepseek3-671b
export TOKENIZER_PATH=<Model Tokenizer> # e.g., deepseek-ai/DeepSeek-V3

# -- Dataset configuration --
export DATASET_NAME=<Hugging Face Dataset Name> # e.g., HuggingFaceH4/ultrachat_200k
@@ -143,7 +142,7 @@ xpk workload create \
--workload=${WORKLOAD_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${TPU_SLICE?} \
--command "python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}"
--command "python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}"
```

Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
@@ -159,7 +158,7 @@ xpk workload create-pathways \
--workload=${WORKLOAD_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${TPU_SLICE?} \
--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True"
--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True"
```

Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
3 changes: 2 additions & 1 deletion src/maxtext/checkpoint_conversion/compare_hf_ckpt.py
@@ -48,8 +48,9 @@
from safetensors import safe_open

from maxtext.configs import pyconfig
from maxtext.checkpoint_conversion.utils.utils import HF_IDS, print_ram_usage, get_hf_model
from maxtext.checkpoint_conversion.utils.utils import print_ram_usage, get_hf_model
from maxtext.utils import max_logging
from maxtext.utils.globals import HF_IDS


jax.config.update("jax_platform_name", "cpu")
5 changes: 3 additions & 2 deletions src/maxtext/checkpoint_conversion/to_huggingface.py
@@ -19,7 +19,7 @@
Key Parameters (to be set in the config file or as command-line overrides):
model_name: (Required) The name of the model to convert (e.g., "gemma2-2b").
Must be a key in `maxtext.checkpoint_conversion.utils.utils.HF_IDS`.
Must be a key in `maxtext.utils.globals.HF_IDS`.
load_parameters_path: (Required) Path to the MaxText checkpoint directory
containing the parameter-only checkpoint.
base_output_directory: (Optional) The directory where the converted HuggingFace
@@ -79,12 +79,13 @@
save_model_files,
load_orbax_checkpoint,
detect_and_extract_checkpoint,
HF_IDS,
MemoryMonitorTqdm,
print_peak_memory,
)
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils.globals import HF_IDS


flags.DEFINE_bool(
"override_model_architecture",
77 changes: 63 additions & 14 deletions src/maxtext/checkpoint_conversion/to_maxtext.py
@@ -18,7 +18,7 @@
Key Parameters (to be set in the config file or as command-line overrides):
model_name: (Required) The name of the model to convert (e.g., "gemma2-2b").
Must be a key in `maxtext.checkpoint_conversion.utils.utils.HF_IDS`.
Must be a key in `maxtext.utils.globals.HF_IDS`.
base_output_directory: (Optional) The directory where the converted HuggingFace
checkpoint will be saved. Can be a local path, a GCS
path (gs://...), or a HuggingFace Hub repo ID (hf://...).
@@ -30,7 +30,7 @@
Defaults to False.
--hf_model_path: (Optional) Specifies a local or remote directory containing the model weights.
If unspecified, we use the default Hugging Face repository ID
(e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `utils/ckpt_conversion/utils`).
(e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `maxtext.utils.globals`).
This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
Environment Variables:
@@ -74,11 +74,12 @@
from maxtext.common.common_types import MODEL_MODE_TRAIN
from maxtext.checkpoint_conversion.standalone_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint
from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
from maxtext.checkpoint_conversion.utils.utils import HF_IDS, MemoryMonitorTqdm, apply_hook_fns, get_hf_model, print_peak_memory, print_ram_usage, validate_and_filter_param_map_keys
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, get_hf_model, print_peak_memory, print_ram_usage, validate_and_filter_param_map_keys
from maxtext.inference.inference_utils import str2bool
from maxtext.layers import quantizations
from maxtext.models import models
from maxtext.utils import max_logging, max_utils, maxtext_utils
from maxtext.utils.globals import HF_IDS
import numpy as np
from orbax.checkpoint import type_handlers
from safetensors import safe_open
@@ -155,7 +156,12 @@ def _initialize_index(self):
if self.is_local:
index_path = os.path.join(self.model_id, index_file)
else:
index_path = hf_hub_download(repo_id=self.model_id, filename=index_file, token=self.token, revision=self.revision)
index_path = hf_hub_download(
repo_id=self.model_id,
filename=index_file,
token=self.token,
revision=self.revision,
)
with open(index_path, "r", encoding="utf-8") as f:
index_data = json.load(f)
self.shard_map = index_data["weight_map"]
@@ -185,7 +191,12 @@ def get_tensor(self, key: str) -> np.ndarray:
else:
# STEP 1: Download outside the lock.
# multiple threads can download different shards at the same time.
local_path = hf_hub_download(repo_id=self.model_id, filename=shard_name, token=self.token, revision=self.revision)
local_path = hf_hub_download(
repo_id=self.model_id,
filename=shard_name,
token=self.token,
revision=self.revision,
)

# STEP 2: Lock ONLY the reading into RAM.
# This prevents multiple threads from simultaneously allocating large chunks of RAM.
@@ -200,7 +211,13 @@ class LazyTensor:
and transformation until __array__ is called (e.g., by Orbax during save).
"""

def __init__(self, load_fn: Callable[[], np.ndarray], shape: tuple, dtype, name: str = "unknown"):
def __init__(
self,
load_fn: Callable[[], np.ndarray],
shape: tuple,
dtype,
name: str = "unknown",
):
self._load_fn = load_fn
self.shape = shape
self.dtype = np.dtype(dtype)
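The `LazyTensor` change above only reflows arguments, but the class itself is worth illustrating: it defers the expensive shard read and transformation until `__array__` is invoked (e.g., by Orbax during save). A minimal self-contained sketch of that pattern, with illustrative names rather than the exact MaxText implementation:

```python
import numpy as np

class LazyTensor:
    """Defers loading/transforming a tensor until __array__ is called."""

    def __init__(self, load_fn, shape, dtype, name="unknown"):
        self._load_fn = load_fn
        self.shape = shape
        self.dtype = np.dtype(dtype)
        self.name = name

    def __array__(self, dtype=None, copy=None):
        # Only here does the real (expensive) load happen.
        data = np.asarray(self._load_fn(), dtype=self.dtype)
        return data if dtype is None else data.astype(dtype)

load_count = []

def expensive_load():
    load_count.append(1)  # track that the shard was actually read
    return np.ones((2, 3), dtype=np.float32)

t = LazyTensor(expensive_load, (2, 3), np.float32, name="w")
materialized = np.asarray(t)  # triggers exactly one load
```

Because the wrapper carries `shape` and `dtype`, a checkpoint saver can plan the write layout before any bytes are read, which is the point of the lazy-loading mode in this PR.
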
@@ -421,7 +438,13 @@ def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_t
def _loader(getter, key, shape, hook):
return apply_hook_fns(getter(key), shape, hook)

load_fn = partial(_loader, tensor_getter, hf_source_keys_or_key, mt_target_shape_or_shapes, hook_fn)
load_fn = partial(
_loader,
tensor_getter,
hf_source_keys_or_key,
mt_target_shape_or_shapes,
hook_fn,
)
# Stacked mapping
elif not isinstance(hf_source_keys_or_key[0], list):
# Case 2 or 3: Single-Axis Stacked hf keys (un-nested list)
@@ -516,7 +539,12 @@ def _get_maxtext_weight(
# to load the tensor later (the `load_fn`, shape, dtype).
# The actual data will only be loaded when Orbax calls `__array__`
# on this object during the saving process.
final_mt_tensor_numpy = LazyTensor(load_fn, mt_target_shape_or_shapes, config.weight_dtype, name=mt_param_key_or_keys)
final_mt_tensor_numpy = LazyTensor(
load_fn,
mt_target_shape_or_shapes,
config.weight_dtype,
name=mt_param_key_or_keys,
)
if not is_composite_mt_key:
# Case 2.1: Lazy mode, `atomic_mt_key`
final_mt_weights[mt_target_idx_or_indices] = final_mt_tensor_numpy
@@ -562,7 +590,10 @@ def main(

# check the supported model ids
if model_name_original not in HF_IDS:
raise ValueError(f"Unsupported model name: {model_name_original}. Supported models are: {list(HF_IDS.keys())}")
raise ValueError(
f"Unsupported model name: {model_name_original}. "
f"Supported models are: {list(HF_IDS.keys())}"
)

model_id = hf_model_path or HF_IDS[model_name_original]

@@ -633,7 +664,11 @@ def _eager_getter(key):
filtered_map_keys = validate_and_filter_param_map_keys(param_map_mt_to_hf.keys(), maxtext_abstract_dict.keys())

for mt_param_key_or_keys in MemoryMonitorTqdm(
filtered_map_keys, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True
filtered_map_keys,
desc="Transforming weights",
unit="param",
leave=True,
dynamic_ncols=True,
):
if not lazy_load_tensors:
max_logging.log(f"maxtext param: {mt_param_key_or_keys}")
@@ -651,7 +686,13 @@

# Step 2: Determine the loading function for hf key
# based on hf_key form (unscanned, scanned, unscanned with expert stacking, or scanned with expert stacking)
load_fn = _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config)
load_fn = _get_hf_loading_function(
hf_source_keys_or_key,
tensor_getter,
hook_fn,
mt_target_shape_or_shapes,
config,
)

# Step 3: Load hf keys and convert to maxtext keys
# based on tensor load mode (lazy, eager) and MaxText key form (`atomic_mt_key` or `composite_mt_key`)
@@ -710,9 +751,13 @@ def _eager_getter(key):
default=False,
help="Whether to use lazy loading of HF tensors.",
)
# If not specified, default to maxtext.checkpoint_conversion.utils.utils.HF_IDS[model_name]
# If not specified, default to maxtext.utils.globals.HF_IDS[model_name]
parser.add_argument(
"--hf_model_path", type=str, required=False, default="", help="local path to hf model, or custom remote hf repo"
"--hf_model_path",
type=str,
required=False,
default="",
help="local path to hf model, or custom remote hf repo",
)
# Determines the logical sharding of the output checkpoint by partitioning
# weights across virtual XLA devices.
@@ -730,7 +775,11 @@
parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16)

parser.add_argument(
"--revision", type=str, required=False, default=None, help="Specific Hugging Face revision (branch/tag/commit)"
"--revision",
type=str,
required=False,
default=None,
help="Specific Hugging Face revision (branch/tag/commit)",
)

# Parse local arguments
35 changes: 0 additions & 35 deletions src/maxtext/checkpoint_conversion/utils/utils.py
@@ -56,41 +56,6 @@
DEFAULT_MAX_SHARD_SIZE = 1024 * 1024 * 1024 * 3 # 3GB default


# Mapping from MaxText model key to Hugging Face tokenizer identifiers
HF_IDS = {
"gemma2-2b": "google/gemma-2-2b",
"gemma2-9b": "google/gemma-2-9b",
"gemma2-27b": "google/gemma-2-27b",
"gemma3-4b": "google/gemma-3-4b-it", # hf multi-modal should also support the pure-text
"gemma3-12b": "google/gemma-3-12b-it",
"gemma3-27b": "google/gemma-3-27b-it",
"qwen3-0.6b": "Qwen/Qwen3-0.6B",
"qwen3-4b": "Qwen/Qwen3-4B",
"qwen3-4b-thinking-2507": "Qwen/Qwen3-4B-Thinking-2507",
"qwen3-8b": "Qwen/Qwen3-8B",
"qwen3-14b": "Qwen/Qwen3-14B",
"qwen3-32b": "Qwen/Qwen3-32B",
"llama3.1-8b": "meta-llama/Llama-3.1-8B",
"llama3.1-8b-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
"llama3.1-70b-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
"llama3.1-70b": "meta-llama/Llama-3.1-70B",
"llama3.1-405b": "meta-llama/Llama-3.1-405B",
"qwen3-30b-a3b": "Qwen/Qwen3-30B-A3B-Thinking-2507",
"qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B-Thinking-2507",
"qwen3-480b-a35b": "Qwen/Qwen3-Coder-480B-A35B-Instruct",
"deepseek3-671b": "deepseek-ai/DeepSeek-V3",
"gpt-oss-20b": "openai/gpt-oss-20b",
"gpt-oss-120b": "openai/gpt-oss-120b",
"qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
"qwen3-next-80b-a3b": "Qwen/Qwen3-Next-80B-A3B-Instruct",
"mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"mixtral-8x22b": "mistralai/Mixtral-8x22B-Instruct-v0.1",
"olmo3-7b": "allenai/Olmo-3-7B-Instruct",
"olmo3-7b-pt": "allenai/Olmo-3-1025-7B",
"olmo3-32b": "allenai/Olmo-3-32B-Think",
}


def _get_local_directory(output_dir: str) -> str:
"""Determines the local directory for saving files."""
if output_dir.startswith("gs://") or output_dir.startswith("hf://"):
2 changes: 1 addition & 1 deletion src/maxtext/configs/base.yml
@@ -573,7 +573,7 @@ num_vocab_tiling: 1

# Tokenizer
vocab_size: 32_000 # powers of 2 for sharding
tokenizer_path: "src/maxtext/assets/tokenizers/tokenizer.llama2"
tokenizer_path: ""
# tfds pipeline supports tokenizer_type: sentencepiece, huggingface, tiktoken
# grain pipeline supports tokenizer_type: sentencepiece, huggingface
# hf pipeline only supports huggingface type, and will ignore tokenizer_type flag
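With `tokenizer_path` now defaulting to empty, the tokenizer is presumably resolved from the model's Hugging Face ID (the `HF_IDS` table that this PR moves to `maxtext/utils/globals.py`) rather than a bundled asset. A hedged sketch of such a fallback, using an abbreviated, illustrative copy of the table:

```python
# Abbreviated, illustrative copy of the HF_IDS table (see maxtext/utils/globals.py).
HF_IDS = {
    "llama3.1-8b-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
    "deepseek3-671b": "deepseek-ai/DeepSeek-V3",
}

def resolve_tokenizer_path(model_name: str, tokenizer_path: str = "") -> str:
    """Fall back to the model's HF repo id when tokenizer_path is empty."""
    if tokenizer_path:
        return tokenizer_path  # an explicit override still wins
    if model_name not in HF_IDS:
        raise ValueError(f"Unsupported model name: {model_name}")
    return HF_IDS[model_name]
```

This is a sketch of the apparent design, not the actual MaxText resolution code: users who need a custom tokenizer can still set `tokenizer_path` explicitly, while the common case needs only `model_name`.
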