From 95f0c9b2c38bab8d08bf274e9dff82e6af15b3a1 Mon Sep 17 00:00:00 2001 From: Surbhi Jain Date: Sat, 16 May 2026 00:24:06 +0000 Subject: [PATCH] Add gemma3-4b E2E test scripts for pre-training and post-training --- tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh | 128 +++++++++++------- .../tpu/gemma3/4b/test_gemma3_rl.sh | 67 +++++++++ .../tpu/gemma3/4b/test_gemma3_sft.sh | 63 +++++++++ .../tpu/gemma3/4b/test_gemma3_to_hf.sh | 58 -------- .../tpu/gemma3/4b/test_gemma3_to_mt.sh | 88 +++++------- 5 files changed, 246 insertions(+), 158 deletions(-) create mode 100644 tests/end_to_end/tpu/gemma3/4b/test_gemma3_rl.sh create mode 100644 tests/end_to_end/tpu/gemma3/4b/test_gemma3_sft.sh delete mode 100644 tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh diff --git a/tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh b/tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh index 20986e6cb9..287a17b3c7 100644 --- a/tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh +++ b/tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh @@ -1,58 +1,94 @@ #!/bin/bash -# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Gemma3-4b. +# Validates the Gemma3-4B pre-training pipeline using a pre-converted MaxText checkpoint. -# The flow of this file is as follows: -# 1. Convert the checkpoint downloaded from Kaggle to make it compatible with MaxText -# 2. Run decoding, finetuning of Gemma3-4b with the converted checkpoint. Also, run pretraining of Gemma3-4b -# 3. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding. +# The flow of this script is as follows: +# 1. Run inference on the pre-converted checkpoint. +# 2. Run pre-training starting from the pre-converted checkpoint. +# 3. Run inference on the checkpoint produced by the pre-training run. +# 4. Convert the checkpoint produced by the pre-training run back to HuggingFace format. + +# Usage: +# export HF_TOKEN= +# export RUN_ID=$(date +%Y-%m-%d-%H-%M) +# bash test_gemma3_to_mt.sh $RUN_ID +# bash test_gemma3.sh $RUN_ID set -ex -idx=$(date +%Y-%m-%d-%H-%M) -export MODEL_VARIATION='4b' -export MODEL_NAME=gemma3-${MODEL_VARIATION} -# Installing torch for deps in forward_pass_logit_checker -python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu +run_id=${1:-$(date +%Y-%m-%d-%H-%M)} +MODEL_NAME='gemma3-4b' + +# To convert the multimodal model, make sure the use_multimodal is set to be true +USE_MULTIMODAL=false + +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} +UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items -# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \ -# Non-Googlers please remember to use separate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET). -# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing. -# You can use the Flax checkpoint available on Kaggle: -# https://www.kaggle.com/models/google/gemma-3/flax/ +# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data +DATASET_PATH=gs://maxtext-dataset + +# Step 1: Install torch +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu -export CHKPT_BUCKET=gs://maxtext-gemma/gemma3/flax -export MODEL_BUCKET=gs://maxtext-gemma/gemma3 +# Step 2: Run inference on the original checkpoint converted from Hugging Face +if [ ${USE_MULTIMODAL} == true ]; then + python3 -m maxtext.inference.decode \ + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + per_device_batch_size=1 run_name=${run_id} \ + max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \ + scan_layers=false use_multimodal=true \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' +else + python3 -m maxtext.inference.decode \ + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + per_device_batch_size=1 run_name=${run_id} \ + max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + scan_layers=false prompt='I love to' attention=\'dot_product\' +fi -python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gemma3_chkpt --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION} +# Step 3: Run Pre-training on the converted checkpoint +# We can also run training by using the scanned converted checkpoint +# Note that scanned checkpoint helps with efficient training +python3 -m maxtext.trainers.pre_train.train \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/train \ + dataset_path=${DATASET_PATH} tokenizer_type="huggingface" \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + per_device_batch_size=1 run_name=${run_id} \ + max_target_length=8192 steps=5 async_checkpointing=false \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + model_name=${MODEL_NAME} scan_layers=false use_multimodal=${USE_MULTIMODAL} +# Step 4: Run inference on the checkpoint generated from the previous run +if [ ${USE_MULTIMODAL} == true ]; then + python3 -m maxtext.inference.decode \ + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ + per_device_batch_size=1 run_name=${run_id} \ + max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \ + scan_layers=false use_multimodal=true \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' +else + python3 -m maxtext.inference.decode \ + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ + per_device_batch_size=1 run_name=${run_id} \ + max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + scan_layers=false prompt='I love to' attention=\'dot_product\' +fi -# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data -export DATASET_PATH=gs://maxtext-dataset -# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run -export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/gemma3-4b -# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train` and `decode` commands -export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items -export RUN_NAME=unscanned_chkpt_${idx} -# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. -# We can do this by running `maxtext.utils.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. -JAX_PLATFORMS=cpu python3 -m maxtext.utils.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true - -export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items - -# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. -# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.inference.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" - -# # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0 - -# Finetune by using the scanned converted checkpoint by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. For Googlers, uncomment the line below if you want to use the pre-converted checkpoint. -# export CONVERTED_CHECKPOINT=gs://maxtext-model-checkpoints/gemma3-4b/2025-03-18-19-03/scanned/0/items -FINETUNE_RUN_NAME=runner_finetune_${idx} -python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03 - -# We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load_parameters_path -PRETRAIN_RUN_NAME=runner_pretrain_${idx} -python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03 +# Step 5: Convert the checkpoint from MaxText format to Hugging Face format +python3 -m maxtext.checkpoint_conversion.to_huggingface \ + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \ + use_multimodal=${USE_MULTIMODAL} \ + scan_layers=false diff --git a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_rl.sh b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_rl.sh new file mode 100644 index 0000000000..05ccbfe69a --- /dev/null +++ b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_rl.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +# Validates the Gemma3-4B RL pipeline using a pre-converted MaxText checkpoint. + +# The flow of this script is as follows: +# 1. Run inference on the pre-converted checkpoint. +# 2. Run RL starting from the pre-converted checkpoint. +# 3. Run inference on the checkpoint produced by the RL run. +# 4. Convert the checkpoint produced by the RL run back to HuggingFace format. + +# Usage: +# export HF_TOKEN= +# export RUN_ID=$(date +%Y-%m-%d-%H-%M) +# bash test_gemma3_to_mt.sh $RUN_ID +# bash test_gemma3_rl.sh $RUN_ID + + +set -ex + +run_id=${1:-$(date +%Y-%m-%d-%H-%M)} +MODEL_NAME='gemma3-4b' + +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} +UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items +SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items + +# Step 1: Install torch +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + +# Step 2: Run inference on the original checkpoint converted from Hugging Face +python3 -m maxtext.inference.vllm_decode \ + model_name=${MODEL_NAME} \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + hbm_utilization_vllm=0.5 \ + prompt='Suggest some famous landmarks in London.' \ + use_chat_template=True scan_layers=false + +# Step 3: Run RL on the converted checkpoint +python3 -m maxtext.trainers.post_train.rl.train_rl \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/rl \ + load_parameters_path=${SCANNED_CKPT_PATH} \ + run_name=${run_id} rl.loss_algo='grpo' scan_layers=true \ + num_batches=5 batch_size=1 num_test_batches=5 \ + model_name=${MODEL_NAME} enable_single_controller=True \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + rollout_tensor_parallelism=1 \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + vllm_additional_config='{"maxtext_config": {"model_name": "gemma3-4b", "log_config": "false"}}' + + +# Step 4: Run inference on the checkpoint generated from the previous run +python3 -m maxtext.inference.vllm_decode \ + model_name=${MODEL_NAME} \ + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/rl/${run_id}/checkpoints/actor/5/model_params \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + hbm_utilization_vllm=0.5 \ + prompt='Suggest some famous landmarks in London.' \ + use_chat_template=True scan_layers=true + +# Step 5: Convert the checkpoint from MaxText format to Hugging Face format +python3 -m maxtext.checkpoint_conversion.to_huggingface \ + model_name=${MODEL_NAME} \ + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/rl/${run_id}/checkpoints/actor/5/model_params \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \ + use_multimodal=false scan_layers=true diff --git a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_sft.sh b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_sft.sh new file mode 100644 index 0000000000..e6ebd5a319 --- /dev/null +++ b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_sft.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +# Validates the Gemma3-4B SFT pipeline using a pre-converted MaxText checkpoint. + +# The flow of this script is as follows: +# 1. Run inference on the pre-converted checkpoint. +# 2. Run SFT starting from the pre-converted checkpoint. +# 3. Run inference on the checkpoint produced by the SFT run. +# 4. Convert the checkpoint produced by the SFT run back to HuggingFace format. + +# Usage: +# export HF_TOKEN= +# export RUN_ID=$(date +%Y-%m-%d-%H-%M) +# bash test_gemma3_to_mt.sh $RUN_ID +# bash test_gemma3_sft.sh $RUN_ID + + +set -ex + +run_id=${1:-$(date +%Y-%m-%d-%H-%M)} +MODEL_NAME='gemma3-4b' + +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} +UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items +SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items + +# Step 1: Install torch +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + +# Step 2: Run inference on the original checkpoint converted from Hugging Face +python3 -m maxtext.inference.vllm_decode \ + model_name=${MODEL_NAME} \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + hbm_utilization_vllm=0.5 \ + prompt="Suggest some famous landmarks in London." \ + use_chat_template=True scan_layers=false + +# Step 3: Run SFT on the converted checkpoint +python3 -m maxtext.trainers.post_train.sft.train_sft \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/sft \ + load_parameters_path=${SCANNED_CKPT_PATH} \ + per_device_batch_size=1 run_name=${run_id} \ + steps=5 scan_layers=true \ + model_name=${MODEL_NAME} enable_single_controller=True \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False + +# Step 4: Run inference on the checkpoint generated from the previous run +python3 -m maxtext.inference.vllm_decode \ + model_name=${MODEL_NAME} \ + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/sft/${run_id}/checkpoints/5/model_params \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + hbm_utilization_vllm=0.5 \ + prompt="Suggest some famous landmarks in London." \ + use_chat_template=True scan_layers=true + +# Step 5: Convert the checkpoint from MaxText format to Hugging Face format +python3 -m maxtext.checkpoint_conversion.to_huggingface \ + model_name=${MODEL_NAME} \ + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/sft/${run_id}/checkpoints/5/model_params \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \ + use_multimodal=false scan_layers=true diff --git a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh deleted file mode 100644 index 1772c76b97..0000000000 --- a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh +++ /dev/null @@ -1,58 +0,0 @@ -#!/bin/bash - -# This script is both an end-to-end test that runs once a day on a v4-8 and documentation for how to get started with Gemma3-4B. - -# The flow of this script is as follows: -# 1. Convert a MaxText checkpoint to a Hugging Face model checkpoint. -# 2. Run a forward pass check to compare the logits and KL divergence between the converted ckpt and original golden HF model. - -# Pre-requisites: -# 1. Set HF_TOKEN environment variable to your Hugging Face access token with read permissions -# export HF_TOKEN= - -set -ex -idx=$(date +%Y-%m-%d-%H-%M) -MODEL_NAME='gemma3-4b' -export MODEL_VARIATION='4b' -TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"'/tokenizer.gemma3' -# To convert the multimodal model, make sure the use_multimodal is set to be true -USE_MULTIMODAL=false - -# Installing torch for deps in forward_pass_logit_checker.py -python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu - -# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \ -# Non-Googlers please remember to use separate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET). -# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing. -export MODEL_BUCKET=gs://maxtext-gemma/unified/gemma3/hf -# Here is an example of qwen3-4b maxtext checkpoint, converted from Qwen/Qwen3-4B -export CKPT_PATH=gs://maxtext-gemma/unified/gemma3/4b/unscanned/2025-08-05-18-18/0/items - -# You can upload to huggingface hub or GCS by uncommenting the HF_CKPT_PATH and using it as base_output_directory -# export HF_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/hf/${idx} -export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx} - -python3 -m maxtext.checkpoint_conversion.to_huggingface "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ - model_name=${MODEL_NAME} \ - hf_access_token=${HF_TOKEN} \ - load_parameters_path=${CKPT_PATH} \ - base_output_directory=${LOCAL_PATH} \ - use_multimodal=${USE_MULTIMODAL} \ - scan_layers=false - -# Alternatively, if uploaded the converted ckpt, HF requires local storage of model and please uncomment below -# mkdir -p "${LOCAL_PATH}" -# gcloud storage cp -r ${HF_CKPT_PATH}/** ${LOCAL_PATH} -# echo "Copied from ${HF_CKPT_PATH} to ${LOCAL_PATH}" - -# We also test whether the forward pass logits match the original HF model -# to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ - tokenizer_path=${TOKENIZER_PATH} \ - load_parameters_path=${CKPT_PATH} \ - model_name=${MODEL_NAME} \ - use_multimodal=${USE_MULTIMODAL} \ - scan_layers=false \ - --hf_model_path=${LOCAL_PATH} \ - --max_kl_div=0.015 \ - --run_hf_model=true diff --git a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh index 09301387be..4d1b4b898a 100644 --- a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh +++ b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh @@ -1,12 +1,11 @@ #!/bin/bash -# This script is both an end-to-end test that runs once a day on a v4-8 and documentation for how to get started with Gemma3-4B. +# Converts Gemma3-4B HuggingFace checkpoint to MaxText format and validates logit correctness. # The flow of this script is as follows: -# 1. Convert the checkpoint downloaded from Hugging Face to make it compatible with MaxText. -# 2. Run a forward pass logits check to compare with the original HF golden model. -# 3. Run decoding, finetuning of Gemma3-4B. with the converted checkpoint. -# 4. Run decoding from the finetuned checkpoint from step 3. +# 1. Install PyTorch (CPU) required for checkpoint conversion. +# 2. Convert the HuggingFace checkpoint to MaxText format in both unscanned and scanned formats. +# 3. Run a forward pass logits check to verify the converted checkpoint matches the original HF model. # Pre-requisites: # 1. Set HF_TOKEN environment variable to your Hugging Face access token with read permissions @@ -14,76 +13,57 @@ set -ex -idx=$(date +%Y-%m-%d-%H-%M) + +run_id=${1:-$(date +%Y-%m-%d-%H-%M)} MODEL_NAME='gemma3-4b' -export MODEL_VARIATION='4b' HF_GOLDEN_MODEL='google/gemma-3-4b-it' -TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"'/tokenizer.gemma3' + # To convert the multimodal model, make sure the use_multimodal is set to be true USE_MULTIMODAL=false -# Installing torch for deps in forward_pass_logit_checker.py +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you want to store scanned and unscanned checkpoints +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}/to_maxtext + +# Step 1: Install torch python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu -# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \ -# Non-Googlers please remember to use separate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET). -# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing. -export MODEL_BUCKET=gs://maxtext-gemma/unified/gemma3 +# Step 2: Convert the checkpoint from Hugging Face to make it compatible with MaxText -# To get unscanned ckpt: -python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ +# Step 2.a: Convert to unscanned checkpoint (for inference) +python3 -m maxtext.checkpoint_conversion.to_maxtext \ model_name=${MODEL_NAME} \ - hf_access_token=${HF_TOKEN} \ - base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx} \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id} \ use_multimodal=${USE_MULTIMODAL} \ - scan_layers=false + scan_layers=false \ + hardware=cpu skip_jax_distributed_system=True \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + --eager_load_method='transformers' -export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0/items +UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/unscanned/${run_id}/0/items +echo "Unscanned checkpoint path: ${UNSCANNED_CKPT_PATH}" -# To get scanned ckpt, flip the scan_layers. -python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ +# Step 2.b: Convert to scanned checkpoint (for training) +python3 -m maxtext.checkpoint_conversion.to_maxtext \ model_name=${MODEL_NAME} \ - hf_access_token=${HF_TOKEN} \ - base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/scanned/${idx} \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id} \ use_multimodal=${USE_MULTIMODAL} \ - scan_layers=true + scan_layers=true \ + hardware=cpu skip_jax_distributed_system=True \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ + --eager_load_method='transformers' -export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/scanned/${idx}/0/items +SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/scanned/${run_id}/0/items +echo "Scanned checkpoint path: ${SCANNED_CKPT_PATH}" -# We also test whether the forward pass logits match the original HF model +# Step 3: Test whether the forward pass logits match the original HF model # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` - # ToDo: improve forward_pass_logit_checker to test multi-modal prompt -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ - tokenizer_path=${TOKENIZER_PATH} \ +python3 -m tests.utils.forward_pass_logit_checker \ load_parameters_path=${UNSCANNED_CKPT_PATH} \ model_name=${MODEL_NAME} \ use_multimodal=${USE_MULTIMODAL} \ scan_layers=false \ --hf_model_path=${HF_GOLDEN_MODEL} \ --max_kl_div=0.03 \ - --run_hf_model=true - -# We can run decoding for unscanned checkpoints. -if [ ${USE_MULTIMODAL} == true ]; then - python3 -m maxtext.inference.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=false use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' -else - python3 -m maxtext.inference.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\' -fi - -# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data -export DATASET_PATH=gs://maxtext-dataset -# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run -export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/gemma3-4b - -# We can also run finetuning by using the scanned converted checkpoint. -# Note that scanned checkpoint helps with efficient finetuning -export FINETUNE_RUN_NAME=runner_finetune_${idx} -python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=${MODEL_NAME} checkpoint_period=5 scan_layers=false - -# Now, run decoding on the checkpoint generated from our finetune run. -if [ ${USE_MULTIMODAL} == true ]; then - python3 -m maxtext.inference.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=false use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' -else - python3 -m maxtext.inference.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\' -fi + --run_hf_model=true \ + hardware=cpu skip_jax_distributed_system=True