Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 82 additions & 46 deletions tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh
Original file line number Diff line number Diff line change
@@ -1,58 +1,94 @@
#!/bin/bash

# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Gemma3-4b.
# Validates the Gemma3-4B pre-training pipeline using a pre-converted MaxText checkpoint.

# The flow of this file is as follows:
# 1. Convert the checkpoint downloaded from Kaggle to make it compatible with MaxText
# 2. Run decoding, finetuning of Gemma3-4b with the converted checkpoint. Also, run pretraining of Gemma3-4b
# 3. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.
# The flow of this script is as follows:
# 1. Run inference on the pre-converted checkpoint.
# 2. Run pre-training starting from the pre-converted checkpoint.
# 3. Run inference on the checkpoint produced by the pre-training run.
# 4. Convert the checkpoint produced by the pre-training run back to HuggingFace format.

# Usage:
# export HF_TOKEN=<your Hugging Face access token>
# export RUN_ID=$(date +%Y-%m-%d-%H-%M)
# bash test_gemma3_to_mt.sh $RUN_ID
# bash test_gemma3.sh $RUN_ID


set -ex
idx=$(date +%Y-%m-%d-%H-%M)
export MODEL_VARIATION='4b'
export MODEL_NAME=gemma3-${MODEL_VARIATION}

# Installing torch for deps in forward_pass_logit_checker
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
run_id=${1:-$(date +%Y-%m-%d-%H-%M)}
MODEL_NAME='gemma3-4b'

# To run the multimodal flow, set `USE_MULTIMODAL` to true (it is forwarded to the commands below as `use_multimodal`)
USE_MULTIMODAL=false

# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored
BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items

# After downloading the checkpoints, copy them to the GCS bucket at $CHKPT_BUCKET.
# Non-Googlers please remember to use separate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing.
# You can use the Flax checkpoint available on Kaggle:
# https://www.kaggle.com/models/google/gemma-3/flax/
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
DATASET_PATH=gs://maxtext-dataset

# Step 1: Install torch
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

export CHKPT_BUCKET=gs://maxtext-gemma/gemma3/flax
export MODEL_BUCKET=gs://maxtext-gemma/gemma3
# Step 2: Run inference on the original checkpoint converted from Hugging Face.
# USE_MULTIMODAL selects between an image+text prompt and a text-only prompt.
# The value is quoted and compared with [[ ]] so that an empty or multi-word
# USE_MULTIMODAL evaluates cleanly to false (text-only branch) instead of
# making the test command itself fail with a usage error.
if [[ "${USE_MULTIMODAL}" == true ]]; then
  # Multimodal decode: the prompt carries an image placeholder and the image
  # is read from the repository test assets.
  python3 -m maxtext.inference.decode \
    model_name="${MODEL_NAME}" tokenizer_type="huggingface" \
    load_parameters_path="${UNSCANNED_CKPT_PATH}" \
    per_device_batch_size=1 run_name="${run_id}" \
    max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \
    scan_layers=false use_multimodal=true \
    checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
    prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
else
  # Text-only decode with a short prompt and a short target length.
  python3 -m maxtext.inference.decode \
    model_name="${MODEL_NAME}" tokenizer_type="huggingface" \
    load_parameters_path="${UNSCANNED_CKPT_PATH}" \
    per_device_batch_size=1 run_name="${run_id}" \
    max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \
    checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
    scan_layers=false prompt='I love to' attention=\'dot_product\'
fi

python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gemma3_chkpt --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION}
# Step 3: Run pre-training on the converted checkpoint.
# This run loads the unscanned checkpoint (scan_layers=false). Training can
# also be started from the scanned converted checkpoint, which helps with
# efficient training.
# Outputs are written under ${BASE_OUTPUT_DIRECTORY}/train, in a
# subdirectory named after run_name.
python3 -m maxtext.trainers.pre_train.train \
  base_output_directory="${BASE_OUTPUT_DIRECTORY}/train" \
  dataset_path="${DATASET_PATH}" tokenizer_type="huggingface" \
  load_parameters_path="${UNSCANNED_CKPT_PATH}" \
  per_device_batch_size=1 run_name="${run_id}" \
  max_target_length=8192 steps=5 async_checkpointing=false \
  checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
  model_name="${MODEL_NAME}" scan_layers=false use_multimodal="${USE_MULTIMODAL}"

# Step 4: Run inference on the checkpoint generated from the previous run.
# Both branches load the same checkpoint, so the path is built once.
# NOTE(review): checkpoint index 4 is presumably the last step of the
# steps=5 pre-training run (0-indexed) — update it if `steps` changes.
trained_ckpt_path="${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items"

# Quoted [[ ]] test: an empty or multi-word USE_MULTIMODAL evaluates cleanly
# to false (text-only branch) instead of making the test command fail.
if [[ "${USE_MULTIMODAL}" == true ]]; then
  # Multimodal decode: image placeholder in the prompt plus a test image.
  python3 -m maxtext.inference.decode \
    model_name="${MODEL_NAME}" tokenizer_type="huggingface" \
    load_parameters_path="${trained_ckpt_path}" \
    per_device_batch_size=1 run_name="${run_id}" \
    max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \
    scan_layers=false use_multimodal=true \
    checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
    prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
else
  # Text-only decode with a short prompt and a short target length.
  python3 -m maxtext.inference.decode \
    model_name="${MODEL_NAME}" tokenizer_type="huggingface" \
    load_parameters_path="${trained_ckpt_path}" \
    per_device_batch_size=1 run_name="${run_id}" \
    max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \
    checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
    scan_layers=false prompt='I love to' attention=\'dot_product\'
fi

# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
export DATASET_PATH=gs://maxtext-dataset
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/gemma3-4b
# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train` and `decode` commands
export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items
export RUN_NAME=unscanned_chkpt_${idx}
# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
# We can do this by running `maxtext.utils.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
JAX_PLATFORMS=cpu python3 -m maxtext.utils.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true

export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items

# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`.
python3 -m maxtext.inference.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to"

# To get higher precision (e.g. float32), run on CPU with `JAX_PLATFORMS=cpu`.
python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0

# Finetune using the scanned converted checkpoint by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. For Googlers, uncomment the line below if you want to use the pre-converted checkpoint.
# export CONVERTED_CHECKPOINT=gs://maxtext-model-checkpoints/gemma3-4b/2025-03-18-19-03/scanned/0/items
FINETUNE_RUN_NAME=runner_finetune_${idx}
python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03

# We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load_parameters_path
PRETRAIN_RUN_NAME=runner_pretrain_${idx}
python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03
# Step 5: Convert the checkpoint from MaxText format to Hugging Face format.
# Reads the checkpoint written by the pre-training run (index 4 under
# ${BASE_OUTPUT_DIRECTORY}/train/${run_id}) and writes the converted model
# under ${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id}.
# Expansions are quoted so that a path or run id containing whitespace or
# glob characters is passed through as a single argument.
python3 -m maxtext.checkpoint_conversion.to_huggingface \
  model_name="${MODEL_NAME}" tokenizer_type="huggingface" \
  load_parameters_path="${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items" \
  base_output_directory="${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id}" \
  use_multimodal="${USE_MULTIMODAL}" \
  scan_layers=false
67 changes: 67 additions & 0 deletions tests/end_to_end/tpu/gemma3/4b/test_gemma3_rl.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/bin/bash

# Validates the Gemma3-4B RL pipeline using a pre-converted MaxText checkpoint.

# The flow of this script is as follows:
# 1. Install torch (CPU wheel).
# 2. Run inference on the pre-converted (unscanned) checkpoint.
# 3. Run RL starting from the pre-converted (scanned) checkpoint.
# 4. Run inference on the checkpoint produced by the RL run.
# 5. Convert the checkpoint produced by the RL run back to HuggingFace format.

# Usage:
# export HF_TOKEN=<your Hugging Face access token>
# export RUN_ID=$(date +%Y-%m-%d-%H-%M)
# bash test_gemma3_to_mt.sh "$RUN_ID"
# bash test_gemma3_rl.sh "$RUN_ID"
#
# Arguments:
#   $1 - run id used to locate the converted checkpoints and to name the
#        outputs of this run; defaults to the current timestamp.


# -e: stop at the first failing step; -u: treat a misspelled/unset variable
# as an error instead of silently building a wrong GCS path; -x: trace the
# commands; pipefail: a failing pipeline stage fails the pipeline.
set -euxo pipefail

run_id=${1:-$(date +%Y-%m-%d-%H-%M)}
MODEL_NAME='gemma3-4b'

# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored
BASE_OUTPUT_DIRECTORY="gs://runner-maxtext-logs/${MODEL_NAME}"
UNSCANNED_CKPT_PATH="${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items"
SCANNED_CKPT_PATH="${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items"

# Checkpoint written by the RL run in Step 3; consumed by Steps 4 and 5.
# NOTE(review): index 5 presumably corresponds to num_batches=5 below —
# keep the two in sync if the batch count changes.
RL_CKPT_PATH="${BASE_OUTPUT_DIRECTORY}/rl/${run_id}/checkpoints/actor/5/model_params"

# Step 1: Install torch
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# Step 2: Run inference on the original checkpoint converted from Hugging Face
python3 -m maxtext.inference.vllm_decode \
  model_name="${MODEL_NAME}" \
  load_parameters_path="${UNSCANNED_CKPT_PATH}" \
  vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
  hbm_utilization_vllm=0.5 \
  prompt='Suggest some famous landmarks in London.' \
  use_chat_template=True scan_layers=false

# Step 3: Run RL on the converted checkpoint
# The model name inside vllm_additional_config is taken from MODEL_NAME so
# that it cannot drift from the model_name argument.
python3 -m maxtext.trainers.post_train.rl.train_rl \
  base_output_directory="${BASE_OUTPUT_DIRECTORY}/rl" \
  load_parameters_path="${SCANNED_CKPT_PATH}" \
  run_name="${run_id}" rl.loss_algo='grpo' scan_layers=true \
  num_batches=5 batch_size=1 num_test_batches=5 \
  model_name="${MODEL_NAME}" enable_single_controller=True \
  checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
  rollout_tensor_parallelism=1 \
  vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
  vllm_additional_config="{\"maxtext_config\": {\"model_name\": \"${MODEL_NAME}\", \"log_config\": \"false\"}}"


# Step 4: Run inference on the checkpoint generated from the previous run
python3 -m maxtext.inference.vllm_decode \
  model_name="${MODEL_NAME}" \
  load_parameters_path="${RL_CKPT_PATH}" \
  vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
  hbm_utilization_vllm=0.5 \
  prompt='Suggest some famous landmarks in London.' \
  use_chat_template=True scan_layers=true

# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
# NOTE(review): the output directory is named "unscanned" although this
# conversion runs with scan_layers=true, and the pre-training and SFT test
# scripts write to the same to_huggingface/unscanned/<run_id> location —
# confirm that sharing one RUN_ID across them is intended.
python3 -m maxtext.checkpoint_conversion.to_huggingface \
  model_name="${MODEL_NAME}" \
  load_parameters_path="${RL_CKPT_PATH}" \
  base_output_directory="${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id}" \
  use_multimodal=false scan_layers=true
63 changes: 63 additions & 0 deletions tests/end_to_end/tpu/gemma3/4b/test_gemma3_sft.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/bin/bash

# Validates the Gemma3-4B SFT pipeline using a pre-converted MaxText checkpoint.

# The flow of this script is as follows:
# 1. Install torch (CPU wheel).
# 2. Run inference on the pre-converted (unscanned) checkpoint.
# 3. Run SFT starting from the pre-converted (scanned) checkpoint.
# 4. Run inference on the checkpoint produced by the SFT run.
# 5. Convert the checkpoint produced by the SFT run back to HuggingFace format.

# Usage:
# export HF_TOKEN=<your Hugging Face access token>
# export RUN_ID=$(date +%Y-%m-%d-%H-%M)
# bash test_gemma3_to_mt.sh "$RUN_ID"
# bash test_gemma3_sft.sh "$RUN_ID"
#
# Arguments:
#   $1 - run id used to locate the converted checkpoints and to name the
#        outputs of this run; defaults to the current timestamp.


# -e: stop at the first failing step; -u: treat a misspelled/unset variable
# as an error instead of silently building a wrong GCS path; -x: trace the
# commands; pipefail: a failing pipeline stage fails the pipeline.
set -euxo pipefail

run_id=${1:-$(date +%Y-%m-%d-%H-%M)}
MODEL_NAME='gemma3-4b'

# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored
BASE_OUTPUT_DIRECTORY="gs://runner-maxtext-logs/${MODEL_NAME}"
UNSCANNED_CKPT_PATH="${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items"
SCANNED_CKPT_PATH="${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items"

# Checkpoint written by the SFT run in Step 3; consumed by Steps 4 and 5.
# NOTE(review): index 5 presumably corresponds to steps=5 below — keep the
# two in sync if the step count changes.
SFT_CKPT_PATH="${BASE_OUTPUT_DIRECTORY}/sft/${run_id}/checkpoints/5/model_params"

# Step 1: Install torch
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# Step 2: Run inference on the original checkpoint converted from Hugging Face
python3 -m maxtext.inference.vllm_decode \
  model_name="${MODEL_NAME}" \
  load_parameters_path="${UNSCANNED_CKPT_PATH}" \
  vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
  hbm_utilization_vllm=0.5 \
  prompt="Suggest some famous landmarks in London." \
  use_chat_template=True scan_layers=false

# Step 3: Run SFT on the converted checkpoint
python3 -m maxtext.trainers.post_train.sft.train_sft \
  base_output_directory="${BASE_OUTPUT_DIRECTORY}/sft" \
  load_parameters_path="${SCANNED_CKPT_PATH}" \
  per_device_batch_size=1 run_name="${run_id}" \
  steps=5 scan_layers=true \
  model_name="${MODEL_NAME}" enable_single_controller=True \
  checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False

# Step 4: Run inference on the checkpoint generated from the previous run
python3 -m maxtext.inference.vllm_decode \
  model_name="${MODEL_NAME}" \
  load_parameters_path="${SFT_CKPT_PATH}" \
  vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
  hbm_utilization_vllm=0.5 \
  prompt="Suggest some famous landmarks in London." \
  use_chat_template=True scan_layers=true

# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
# NOTE(review): the output directory is named "unscanned" although this
# conversion runs with scan_layers=true, and the pre-training and RL test
# scripts write to the same to_huggingface/unscanned/<run_id> location —
# confirm that sharing one RUN_ID across them is intended.
python3 -m maxtext.checkpoint_conversion.to_huggingface \
  model_name="${MODEL_NAME}" \
  load_parameters_path="${SFT_CKPT_PATH}" \
  base_output_directory="${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id}" \
  use_multimodal=false scan_layers=true
58 changes: 0 additions & 58 deletions tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh

This file was deleted.

Loading
Loading