From 7da3738f29e94687511b152daf53ac4b7b4f4c0e Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Fri, 6 Mar 2026 20:50:20 +0000 Subject: [PATCH 01/10] reshard code --- .../trainers/post_train/rl/reshard_debug.py | 400 ++++++++++++++++++ 1 file changed, 400 insertions(+) create mode 100644 src/maxtext/trainers/post_train/rl/reshard_debug.py diff --git a/src/maxtext/trainers/post_train/rl/reshard_debug.py b/src/maxtext/trainers/post_train/rl/reshard_debug.py new file mode 100644 index 0000000000..063a4fe79d --- /dev/null +++ b/src/maxtext/trainers/post_train/rl/reshard_debug.py @@ -0,0 +1,400 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +RL Trainer + +This module provides a unified `rl_train` function that consolidates the common +RL training logic. It handles model loading, reward function setup, dataset +processing, and training orchestration. By default, we run Group Relative Policy Optimization (GRPO) on +GSM8K math reasoning benchmark. The script is also flexible enough to run Group Sequence Policy Optimization (GSPO). + +Usage Examples: + +# GRPO on Llama3.1-8B-Instruct +python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \ + model_name=llama3.1-8b \ + tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ + load_parameters_path=gs://path/to/checkpoint/0/items \ + run_name=$WORKLOAD \ + base_output_directory=$OUTPUT_PATH \ + hf_access_token=$HF_TOKEN + +# GSPO on Llama3.1-70B-Instruct +python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \ + model_name=llama3.1-70b \ + tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \ + load_parameters_path=gs://path/to/checkpoint/0/items \ + run_name=$WORKLOAD \ + base_output_directory=$OUTPUT_PATH \ + hf_access_token=$HF_TOKEN \ + loss_algo=gspo-token + +""" + +from __future__ import annotations +from typing import Sequence + +import collections +import jax +import json +import logging +import os +import pathwaysutils + +from absl import app +from absl import logging as absl_logging +from flax import nnx +from jax.sharding import Mesh +from orbax import checkpoint as ocp +from transformers import AutoTokenizer +from tunix.rl import rl_cluster as rl_cluster_lib +from tunix.rl.rollout import base_rollout +from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner +from tunix.sft import metrics_logger, profiler +from tunix.sft.utils import show_hbm_usage + +# for vLLM we can skip JAX precompilation with this flag, it makes startup faster +os.environ["SKIP_JAX_PRECOMPILE"] = "1" + +from maxtext.configs import pyconfig +from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR +from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter +from maxtext.trainers.post_train.rl import utils_rl +from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils + + +def get_maxtext_model(config, devices=None): + """ + Load MaxText model with Tunix adapter. + # Note: pass the path to your scanned checkpoint for 'load_parameters_path'. + # To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if + # using Pathways, please set `checkpoint_storage_use_ocdbt=False checkpoint_storage_use_zarr3=False` + # python src/MaxText/checkpoint_conversion/to_maxtext.py \ + # --model_name="gemma2-2b" \ + # --base_output_directory="/path/to/your/output/directory" \ + # --scan_layers=True \ + # --checkpoint_storage_use_ocdbt=False\ + # checkpoint_storage_use_zarr3=False + # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e., + # load_parameters_path=/path/to/your/output/directory/0/items + """ + model, mesh = model_creation_utils.create_nnx_model(config, devices=devices) + with mesh: + use_no_op_mappings = "maxtext_config" in config.vllm_additional_config + tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings) + tunix_model.config = None + return tunix_model, mesh + + +def setup_configs_and_devices(argv: list[str]): + """Setup device allocation and configs for training and inference.""" + config = pyconfig.initialize_pydantic(argv) + devices = jax.devices() + if config.num_trainer_slices == -1 and config.num_samplers_slices == -1: + max_logging.log("Running RL on a single slice") + num_vms = len(devices) // config.chips_per_vm + trainer_devices = devices + sampler_devices = devices + if num_vms >= 2 and config.use_pathways: + # Multiple hosts with Pathways - potentially split devices for trainer and sampler + # based on trainer_devices_fraction and sampler_devices_fraction + max_logging.log(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.") + num_devices = len(devices) + num_trainer_devices = int(num_devices * config.trainer_devices_fraction) + num_sampler_devices = int(num_devices * config.sampler_devices_fraction) + trainer_devices = devices[:num_trainer_devices] + sampler_devices = devices[num_devices - num_sampler_devices :] + if config.trainer_devices_fraction != 1.0: + max_logging.log(f"Using first {len(trainer_devices)} devices as Trainer devices") + if config.sampler_devices_fraction != 1.0: + max_logging.log(f"Using last {len(sampler_devices)} devices as Sampler devices") + trainer_config = config + sampler_config = config + elif config.num_trainer_slices > 0 and config.num_samplers_slices > 0: + max_logging.log("Running RL with Multislice") + devices_by_slice = collections.defaultdict(list) + for d in devices: + devices_by_slice[d.slice_index].append(d) + slice_indices = sorted(devices_by_slice.keys()) + + if len(slice_indices) < config.num_trainer_slices + config.num_samplers_slices: + raise ValueError("Not enough slices for trainer and samplers") + + trainer_devices = [] + for i in range(config.num_trainer_slices): + trainer_devices.extend(devices_by_slice[slice_indices[i]]) + + sampler_devices = [] + for i in range(config.num_trainer_slices, config.num_trainer_slices + config.num_samplers_slices): + sampler_devices.extend(devices_by_slice[slice_indices[i]]) + + trainer_devices_per_slice = len(trainer_devices) // config.num_trainer_slices + trainer_fsdp = trainer_devices_per_slice + tp = config.ici_tensor_parallelism + if tp > 1: + if trainer_devices_per_slice % tp != 0: + raise ValueError( + f"trainer_devices_per_slice ({trainer_devices_per_slice}) must be divisible by tensor parallelism ({tp})" + ) + if config.ici_fsdp_parallelism != -1 and config.ici_fsdp_parallelism * tp != trainer_devices_per_slice: + raise ValueError( + f"ici_fsdp_parallelism ({config.ici_fsdp_parallelism}) * ici_tensor_parallelism ({tp}) must equal " + f"devices_per_slice ({trainer_devices_per_slice})" + ) + trainer_fsdp = trainer_devices_per_slice // tp + + trainer_update = { + "num_slices": config.num_trainer_slices, + "ici_fsdp_parallelism": trainer_fsdp, + "ici_tensor_parallelism": tp, + "dcn_data_parallelism": config.num_trainer_slices, + } + + sampler_update = { + "num_slices": config.num_samplers_slices, + "ici_fsdp_parallelism": len(sampler_devices) // config.num_samplers_slices, + "ici_tensor_parallelism": -1, + "dcn_data_parallelism": config.num_samplers_slices, + } + + trainer_config = pyconfig.initialize_pydantic(argv, **trainer_update) + sampler_config = pyconfig.initialize_pydantic(argv, **sampler_update) + + else: + raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive") + + return trainer_config, sampler_config, trainer_devices, sampler_devices + + +def get_rollout_kwargs_for_data_parallelism(sampler_config, num_sampler_devices): + """Get rollout kwargs for vLLM rollout when using data parallelism.""" + dp = sampler_config.rollout_data_parallelism + if dp == -1: + return {} + + rollout_kwargs = {} + tp = sampler_config.rollout_tensor_parallelism + + if tp == -1: + if num_sampler_devices % dp != 0: + raise ValueError( + f"num_sampler_devices({num_sampler_devices}) must be divisible by " + f"rollout_data_parallelism({dp}) " + f"when rollout_tensor_parallelism is -1." + ) + tp = num_sampler_devices // dp + elif tp * dp != num_sampler_devices: + raise ValueError( + f"rollout_tensor_parallelism({tp}) * " + f"rollout_data_parallelism({dp}) " + f"!= len(sampler_devices)({num_sampler_devices})" + ) + rollout_kwargs["tensor_parallel_size"] = tp + rollout_kwargs["data_parallel_size"] = dp + rollout_kwargs["rollout_vllm_async_scheduling"] = True + + return rollout_kwargs + + +def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): + """ + Run RL training with the provided configuration. + + Args: + trainer_config: MaxText configuration for the trainer. + sampler_config: MaxText configuration for the sampler. + trainer_devices: JAX devices for the trainer. + sampler_devices: JAX devices for the sampler. + """ + if not trainer_config.debug.rl: + # Apply filter to suppress noisy logs + noise_filter = max_logging.NoisyLogFilter() + logging.getLogger().addFilter(noise_filter) + absl_logging.get_absl_logger().addFilter(noise_filter) + + max_logging.log("Starting RL Resharding Debug Script") + + # Number of training steps. + max_train_steps = 1 + + # Create model tokenizer + model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path) + + # Load reference model + max_logging.log("Creating reference model and also meshes for reference and rollout") + reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices) + devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices) + # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh + # else rollout_mesh uses sampler_devices + rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes) + if trainer_config.debug.rl: + max_logging.log("Reference Model initialized successfully") + nnx.display(reference_model) + max_logging.log(f"Reference mesh shape: {reference_mesh.shape}") + + # Sanity check that weights are loaded correctly. + _maxtext_state_flatten = nnx.state(reference_model).flat_state() + maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten} + max_logging.log( + f"maxtext_state_flatten[base.token_embedder.embedding].value=\ + {maxtext_state_flatten['base.token_embedder.embedding'][...]}" + ) + + # TODO: @mazumdera: change this to use lora + if trainer_config.load_checkpoint_only_once: + max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.") + with reference_mesh: + actor_base_model = nnx.clone(reference_model.base) + use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config + actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings) + actor_model.config = None + actor_mesh = reference_mesh + else: + max_logging.log("Creating policy model with same config as reference model on trainer mesh") + actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices) + + if trainer_config.debug.rl: + max_logging.log("Policy Model initialized successfully") + nnx.display(actor_model) + max_logging.log(f"Policy mesh shape: {actor_mesh.shape}") + + # Setup optimizer + optimizer = utils_rl.get_optimizer(trainer_config, max_train_steps) + + # Setup checkpointing + checkpointing_options = ocp.CheckpointManagerOptions( + save_interval_steps=trainer_config.checkpoint_period, max_to_keep=trainer_config.max_num_checkpoints_to_keep + ) + + # Set up micro batching + micro_batch_size = None if trainer_config.micro_batch_size == -1 else trainer_config.micro_batch_size + + # Parse vllm_additional_config + rollout_additional_config = None + if trainer_config.vllm_additional_config: + if isinstance(trainer_config.vllm_additional_config, dict): + # It's already parsed into a dict + rollout_additional_config = trainer_config.vllm_additional_config + elif isinstance(trainer_config.vllm_additional_config, str): + # It's a string, so we need to parse it + try: + rollout_additional_config = json.loads(trainer_config.vllm_additional_config) + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse additional_config JSON: {e}") from e + + max_logging.log(f"Parsed additional config: {rollout_additional_config}") + + # We need to parse vLLM config to get the logical axis rules for the sampler config. + vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml") + argv_list = ["", str(vllm_config_path), "log_config=False"] + vllm_config = pyconfig.initialize(argv_list) + + # RL Cluster config + # Note that we use vLLM as the rollout engine. + # and we are using Tensor Parallelism for rollout + cluster_config = rl_cluster_lib.ClusterConfig( + role_to_mesh={ + rl_cluster_lib.Role.ACTOR: actor_mesh, + rl_cluster_lib.Role.REFERENCE: reference_mesh, + rl_cluster_lib.Role.ROLLOUT: rollout_mesh, + }, + role_to_logical_axis_rule={ + rl_cluster_lib.Role.ACTOR: trainer_config.logical_axis_rules, + rl_cluster_lib.Role.REFERENCE: trainer_config.logical_axis_rules, + rl_cluster_lib.Role.ROLLOUT: vllm_config.logical_axis_rules, + }, + rollout_engine="vllm", + offload_to_cpu=False, + training_config=rl_cluster_lib.RLTrainingConfig( + actor_optimizer=optimizer, + eval_every_n_steps=trainer_config.eval_interval, + max_steps=max_train_steps, + # Micro batching + mini_batch_size=trainer_config.batch_size, + train_micro_batch_size=micro_batch_size, + rollout_micro_batch_size=micro_batch_size, + # Checkpoint saving + checkpoint_root_directory=trainer_config.checkpoint_dir, + checkpointing_options=checkpointing_options, + ), + rollout_config=base_rollout.RolloutConfig( + max_tokens_to_generate=trainer_config.max_target_length - trainer_config.max_prefill_predict_length, + max_prompt_length=trainer_config.max_prefill_predict_length, + kv_cache_size=trainer_config.max_target_length + trainer_config.kv_cache_buffer, + temperature=trainer_config.decode_sampling_temperature, + top_p=trainer_config.decode_sampling_nucleus_p, + top_k=trainer_config.decode_sampling_top_k, + rollout_vllm_model_version=trainer_config.tokenizer_path, + rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm, + rollout_vllm_tpu_backend_type="jax", + rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb, + rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path, + rollout_vllm_additional_config=rollout_additional_config, + rollout_vllm_init_with_random_weights=True, + rollout_vllm_enable_dp_attention=trainer_config.enable_dp_attention, + rollout_vllm_max_num_batched_tokens=trainer_config.max_num_batched_tokens, + rollout_vllm_max_num_seqs=trainer_config.max_num_seqs, + rollout_vllm_kwargs={ + "hf_overrides": trainer_config.vllm_hf_overrides, + }, + **get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)), + ), + ) + # Create RL cluster + max_logging.log("Creating RL cluster...") + + rl_cluster = rl_cluster_lib.RLCluster( + actor=actor_model, + reference=reference_model, + tokenizer=model_tokenizer, + cluster_config=cluster_config, + ) + + max_logging.log( + "Calling rl_cluster.sync_weights() to reshard actor weights to rollout mesh..." + ) + + key = jax.random.PRNGKey(42) + for step in range(trainer_config.num_batches): + key, subkey = jax.random.split(key) + noise = jax.random.normal(subkey, ()) * 1e-3 + # Update all actor weights to trigger full resharding + state = nnx.state(actor_model, nnx.Param) + new_state = jax.tree_util.tree_map(lambda x: x + noise, state) + nnx.update(actor_model, new_state) + + show_hbm_usage(f"HBM before step {step}:") + rl_cluster.sync_weights() + jax.tree_util.tree_map(jax.block_until_ready, rl_cluster.rollout._sampler.transformer_state) + show_hbm_usage(f"HBM after step {step}:") + max_logging.log(f"Resharding via sync_weights() completed: step {step}") + + +def main(argv: Sequence[str]) -> None: + """Main function to run RL training. + + Args: + argv: Command-line arguments. + """ + pathwaysutils.initialize() + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + + max_utils.print_system_information() + trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv) + rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices) + + +if __name__ == "__main__": + app.run(main) \ No newline at end of file From ff8b01f6ccd05c28aaa39f79eafd1dafe836a031 Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Sat, 7 Mar 2026 00:34:37 +0000 Subject: [PATCH 02/10] EP --- .../trainers/post_train/rl/reshard_debug.py | 57 ++++++++++++++----- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/src/maxtext/trainers/post_train/rl/reshard_debug.py b/src/maxtext/trainers/post_train/rl/reshard_debug.py index 063a4fe79d..b941f8f52a 100644 --- a/src/maxtext/trainers/post_train/rl/reshard_debug.py +++ b/src/maxtext/trainers/post_train/rl/reshard_debug.py @@ -178,32 +178,56 @@ def setup_configs_and_devices(argv: list[str]): return trainer_config, sampler_config, trainer_devices, sampler_devices -def get_rollout_kwargs_for_data_parallelism(sampler_config, num_sampler_devices): +def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices): """Get rollout kwargs for vLLM rollout when using data parallelism.""" dp = sampler_config.rollout_data_parallelism - if dp == -1: - return {} - - rollout_kwargs = {} tp = sampler_config.rollout_tensor_parallelism + ep = sampler_config.rollout_expert_parallelism + + # -1 means "auto-derive from the other two". At most one can be -1. + num_auto = sum(1 for x in [tp, dp, ep] if x == -1) + if num_auto > 1: + raise ValueError( + "At most one of rollout_tensor_parallelism, rollout_data_parallelism, " + "rollout_expert_parallelism can be -1 (auto-derived)." + ) - if tp == -1: - if num_sampler_devices % dp != 0: + if dp == -1: + if num_sampler_devices % (tp * ep) != 0: + raise ValueError( + f"num_sampler_devices({num_sampler_devices}) must be divisible by " + f"rollout_tensor_parallelism({tp}) * rollout_expert_parallelism({ep}) " + f"when rollout_data_parallelism is -1." + ) + dp = num_sampler_devices // tp // ep + elif tp == -1: + if num_sampler_devices % (dp * ep) != 0: raise ValueError( f"num_sampler_devices({num_sampler_devices}) must be divisible by " - f"rollout_data_parallelism({dp}) " + f"rollout_data_parallelism({dp}) * rollout_expert_parallelism({ep}) " f"when rollout_tensor_parallelism is -1." ) - tp = num_sampler_devices // dp - elif tp * dp != num_sampler_devices: + tp = num_sampler_devices // dp // ep + elif ep == -1: + if num_sampler_devices % (tp * dp) != 0: + raise ValueError( + f"num_sampler_devices({num_sampler_devices}) must be divisible by " + f"rollout_tensor_parallelism({tp}) * rollout_data_parallelism({dp}) " + f"when rollout_expert_parallelism is -1." + ) + ep = num_sampler_devices // tp // dp + elif tp * dp * ep != num_sampler_devices: raise ValueError( f"rollout_tensor_parallelism({tp}) * " - f"rollout_data_parallelism({dp}) " + f"rollout_data_parallelism({dp}) * " + f"rollout_expert_parallelism({ep}) " f"!= len(sampler_devices)({num_sampler_devices})" ) + + rollout_kwargs = {} rollout_kwargs["tensor_parallel_size"] = tp rollout_kwargs["data_parallel_size"] = dp - rollout_kwargs["rollout_vllm_async_scheduling"] = True + rollout_kwargs["expert_parallel_size"] = ep return rollout_kwargs @@ -346,10 +370,17 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): rollout_vllm_enable_dp_attention=trainer_config.enable_dp_attention, rollout_vllm_max_num_batched_tokens=trainer_config.max_num_batched_tokens, rollout_vllm_max_num_seqs=trainer_config.max_num_seqs, + rollout_vllm_async_scheduling=trainer_config.async_scheduling, rollout_vllm_kwargs={ "hf_overrides": trainer_config.vllm_hf_overrides, + "enable_expert_parallel": sampler_config.rollout_expert_parallelism > 1, + }, + rollout_vllm_sampling_kwargs={ + "stop": trainer_config.stop_strings, + "detokenize": trainer_config.stop_strings is not None, + "include_stop_str_in_output": trainer_config.stop_strings is not None, }, - **get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)), + **get_rollout_kwargs_for_parallelism(sampler_config, len(sampler_devices)), ), ) # Create RL cluster From 7f42c812008634ab6a2db597c63e3abd2e643839 Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Wed, 11 Mar 2026 04:09:55 +0000 Subject: [PATCH 03/10] extract time from log --- .../trainers/post_train/rl/extract_time.py | 67 +++++++++++++++++++ .../trainers/post_train/rl/reshard_debug.py | 48 +++++++------ 2 files changed, 93 insertions(+), 22 deletions(-) create mode 100644 src/maxtext/trainers/post_train/rl/extract_time.py diff --git a/src/maxtext/trainers/post_train/rl/extract_time.py b/src/maxtext/trainers/post_train/rl/extract_time.py new file mode 100644 index 0000000000..f36835a0ac --- /dev/null +++ b/src/maxtext/trainers/post_train/rl/extract_time.py @@ -0,0 +1,67 @@ +import re +import pandas as pd +from google.cloud import logging +from google.cloud.logging import DESCENDING +from datetime import datetime, timedelta, timezone + +def get_reshard_data(): + client = logging.Client(project="cloud-tpu-multipod-dev") + + # 1. Define a narrow time window (last 24 hours) + # This prevents the API from searching the entire history of the project + start_time = (datetime.now(timezone.utc) - timedelta(days=5)).strftime('%Y-%m-%dT%H:%M:%SZ') + + # 2. Build the exact filter that worked in your UI + # We replace SEARCH() with textPayload: which is the API equivalent + log_filter = ( + f'resource.type="k8s_container" ' + f'resource.labels.location="us-central1" ' + f'resource.labels.cluster_name="zxhe-super-xpk-bid" ' + f'resource.labels.namespace_name="default" ' + f'resource.labels.pod_name:"sanbao-rl-0307-2" ' + f'severity>=DEFAULT ' + f'timestamp >= "{start_time}" ' + f'SEARCH("Reshard finished in")' + ) + + print(f"Querying logs from the last 24 hours (Newest first)...") + + # 3. Use order_by=DESCENDING to find recent logs immediately + entries = client.list_entries(filter_=log_filter, order_by=DESCENDING) + + pattern = r"Reshard finished in (\d+\.?\d*)s" + results = [] + + try: + for entry in entries: + payload = entry.payload + payload_str = None + if isinstance(payload, dict): + payload_str = payload.get("message") or str(payload) + else: + payload_str = str(payload) + if payload_str: + match = re.search(pattern, payload_str) + if match: + results.append({ + "timestamp": entry.timestamp, + "reshard_sec": float(match.group(1)), + "pod": entry.resource.labels.get("pod_name") + }) + except Exception as e: + print(f"Error during API call: {e}") + + if not results: + print("Still no logs found. Try this final check:") + print(f"1. Run: gcloud logging read '{log_filter}' --limit=1") + print("2. If that returns nothing, your local gcloud credentials don't have permission for this project.") + return None + + df = pd.DataFrame(results).sort_values("timestamp") + df.to_csv("reshard_times.csv", index=False) + + print(f"Success! Found {len(df)} events.") + print(df.describe()) + return df + +df = get_reshard_data() \ No newline at end of file diff --git a/src/maxtext/trainers/post_train/rl/reshard_debug.py b/src/maxtext/trainers/post_train/rl/reshard_debug.py index b941f8f52a..66e854eee3 100644 --- a/src/maxtext/trainers/post_train/rl/reshard_debug.py +++ b/src/maxtext/trainers/post_train/rl/reshard_debug.py @@ -13,7 +13,7 @@ # limitations under the License. """ -RL Trainer +Resharding Benchmark for the RL Trainer This module provides a unified `rl_train` function that consolidates the common RL training logic. It handles model loading, reward function setup, dataset @@ -22,24 +22,23 @@ Usage Examples: -# GRPO on Llama3.1-8B-Instruct -python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \ - model_name=llama3.1-8b \ - tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ - load_parameters_path=gs://path/to/checkpoint/0/items \ - run_name=$WORKLOAD \ - base_output_directory=$OUTPUT_PATH \ - hf_access_token=$HF_TOKEN - -# GSPO on Llama3.1-70B-Instruct -python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \ - model_name=llama3.1-70b \ - tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \ - load_parameters_path=gs://path/to/checkpoint/0/items \ - run_name=$WORKLOAD \ - base_output_directory=$OUTPUT_PATH \ - hf_access_token=$HF_TOKEN \ - loss_algo=gspo-token +# GRPO on Qwen3-30B +python3 -m src.maxtext.trainers.post_train.rl.reshard_debug src/maxtext/configs/post_train/rl.yml \ + model_name=qwen3-30b-a3b \ + tokenizer_path=Qwen/Qwen3-30B-A3B \ + run_name=sanbao-rl-0310-1 \ + base_output_directory=gs://sanbao-bucket/mlperf_rl/qwen3/sanbao-rl-0310-1 \ + batch_size=16 \ + rl.num_generations=8 \ + num_batches=4 \ + rollout_data_parallelism=4 \ + rollout_tensor_parallelism=1 \ + rollout_expert_parallelism=4 \ + hbm_utilization_vllm=0.4 \ + scan_layers=True \ + allow_split_physical_axes=True \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + vllm_additional_config='{maxtext_config: {model_name: qwen3-30b-a3b, allow_split_physical_axes: true, log_config: false, weight_dtype: bfloat16}}' """ @@ -49,6 +48,7 @@ import collections import jax import json +import time import logging import os import pathwaysutils @@ -120,8 +120,8 @@ def setup_configs_and_devices(argv: list[str]): max_logging.log(f"Using first {len(trainer_devices)} devices as Trainer devices") if config.sampler_devices_fraction != 1.0: max_logging.log(f"Using last {len(sampler_devices)} devices as Sampler devices") - trainer_config = config - sampler_config = config + trainer_config = config.model_copy() + sampler_config = config.model_copy() elif config.num_trainer_slices > 0 and config.num_samplers_slices > 0: max_logging.log("Running RL with Multislice") devices_by_slice = collections.defaultdict(list) @@ -175,6 +175,8 @@ def setup_configs_and_devices(argv: list[str]): else: raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive") + sampler_config.subslice_shape = "" # we are not using subslices in this script, set it to empty to avoid confusion + sampler_config.enable_single_controller = False # we are not using single controller in this script, set it to False to avoid confusion return trainer_config, sampler_config, trainer_devices, sampler_devices @@ -407,10 +409,12 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): nnx.update(actor_model, new_state) show_hbm_usage(f"HBM before step {step}:") + start_time = time.time() rl_cluster.sync_weights() jax.tree_util.tree_map(jax.block_until_ready, rl_cluster.rollout._sampler.transformer_state) + end_time = time.time() show_hbm_usage(f"HBM after step {step}:") - max_logging.log(f"Resharding via sync_weights() completed: step {step}") + max_logging.log(f"Resharding via sync_weights() completed: step {step}. Weight Syncing Time taken: {end_time - start_time:.4f}s") def main(argv: Sequence[str]) -> None: From e87b5efa185a963e2fd3f925c722d033ae7db4ff Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Wed, 11 Mar 2026 08:14:21 +0000 Subject: [PATCH 04/10] automatic --- src/maxtext/configs/post_train/rl.yml | 2 + src/maxtext/configs/types.py | 5 + .../trainers/post_train/rl/create_yaml.py | 342 ++++++++++++++++++ .../trainers/post_train/rl/extract_time.py | 33 +- .../trainers/post_train/rl/reshard_debug.py | 7 +- 5 files changed, 378 insertions(+), 11 deletions(-) create mode 100644 src/maxtext/trainers/post_train/rl/create_yaml.py diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index 7d3659ae59..19022ef23e 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -28,6 +28,8 @@ num_samplers_slices: -1 rollout_data_parallelism: -1 rollout_tensor_parallelism: -1 rollout_expert_parallelism: 1 +rollout_subslice_shape: "" # e.g. '2,2,1' for 4 chips with DP=2, TP=2, EP=1 +rollout_enable_single_controller: False # If True, use a single controller for rollout. This can help with stability when using more than 1 model replica in rollout. # ====== Reproducibility ====== data_shuffle_seed: 42 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 0bb6d49701..08ab152355 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1586,6 +1586,11 @@ class RLHardware(BaseModel): description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined.", ) rollout_expert_parallelism: int = Field(1, description="Expert parallelism per replica for rollout") + rollout_subslice_shape: str = Field("", description="Subslice shape for rollout in the form of 'x,y,z' for Pathways.") + rollout_enable_single_controller: bool = Field( + False, + description="Whether to enable single-controller mode for rollout. If True, the trainer will also run the rollout and sampling computations instead of launching separate processes. This is only recommended for debugging or if the rollout computation is very small and can be efficiently handled by a single controller.", + ) class VLLM(BaseModel): diff --git a/src/maxtext/trainers/post_train/rl/create_yaml.py b/src/maxtext/trainers/post_train/rl/create_yaml.py new file mode 100644 index 0000000000..bc52a2fdea --- /dev/null +++ b/src/maxtext/trainers/post_train/rl/create_yaml.py @@ -0,0 +1,342 @@ +import os +from jinja2 import Template +import argparse + +def generate_rl_config( + metadata_name, + batch_size, + rollout_data_parallelism, + rollout_tensor_parallelism, + rollout_expert_parallelism, + trainer_devices_fraction, + subslice_shape, + enable_single_controller, + sampler_devices_fraction, + base_output_directory, + run_name, + hf_token +): + yaml_template = """apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + labels: + kueue.x-k8s.io/queue-name: multislice-queue + name: {{ metadata_name }} + namespace: default +spec: + coordinator: + replicatedJob: pathways-head + failurePolicy: + maxRestarts: 1 + restartStrategy: Recreate + network: + enableDNSHostnames: true + publishNotReadyAddresses: true + replicatedJobs: + - name: pathways-head + replicas: 1 + template: + metadata: + annotations: + kueue.x-k8s.io/safe-to-forcefully-terminate: "true" + spec: + backoffLimit: 0 + completionMode: Indexed + completions: 1 + parallelism: 1 + template: + spec: + containers: + - command: + - bash + - -c + - | + echo XPK Start: $(date); + _sigterm() (kill -SIGTERM $! 2>/dev/null;); + trap _sigterm SIGTERM; + + (pip install --no-deps git+https://github.com/AI-Hypercomputer/pathways-utils.git@v0.1.4 && \\ + pip install src/maxtext/integration/vllm && \\ + HF_TOKEN={{ hf_token }} JAX_RANDOM_WEIGHTS=1 VLLM_ENABLE_V1_MULTIPROCESSING=0 NEW_MODEL_DESIGN=1 TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0 JAX_PLATFORMS=proxy,cpu JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 \\ + python3 -m src.maxtext.trainers.post_train.rl.reshard_debug src/maxtext/configs/post_train/rl.yml \\ + model_name=qwen3-30b-a3b \\ + tokenizer_path=Qwen/Qwen3-30B-A3B \\ + run_name={{ run_name }} \\ + base_output_directory={{ base_output_directory }} \\ + hf_access_token={{ hf_token }} \\ + batch_size={{ batch_size }} \\ + rl.num_generations={{ batch_size }} \\ + num_batches=4 \\ + rollout_data_parallelism={{ rollout_data_parallelism }} \\ + rollout_tensor_parallelism={{ rollout_tensor_parallelism }} \\ + rollout_expert_parallelism={{ rollout_expert_parallelism }} \\ + hbm_utilization_vllm=0.4 \\ + scan_layers=True \\ + allow_split_physical_axes=True \\ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \\ + vllm_additional_config='{maxtext_config: {model_name: qwen3-30b-a3b, allow_split_physical_axes: true, log_config: false, weight_dtype: bfloat16}}' \\ + trainer_devices_fraction={{ trainer_devices_fraction }} \\ + subslice_shape='{{ subslice_shape }}' \\ + enable_single_controller={{ enable_single_controller }} \\ + sampler_devices_fraction={{ sampler_devices_fraction }}) & PID=$!; + + while kill -0 $PID 2>/dev/null; + do sleep 5; + done; + wait $PID; + EXIT_CODE=$?; + + echo XPK End: $(date); + echo EXIT_CODE=$EXIT_CODE; + + exit $EXIT_CODE + env: + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: JAX_PLATFORMS + value: proxy + - name: XCLOUD_ENVIRONMENT + value: GCP + - name: JAX_BACKEND_TARGET + value: grpc://$(PATHWAYS_HEAD):29000 + image: gcr.io/cloud-tpu-multipod-dev/sanbao/maxtext_reshard_image:latest + imagePullPolicy: Always + name: jax-tpu + resources: + limits: + cpu: "24" + memory: 100G + securityContext: + privileged: true + volumeMounts: + - mountPath: /tmp + name: shared-tmp + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + initContainers: + - args: + - --server_port=29001 + - --gcs_scratch_location=gs://cloud-pathways-staging/tmp + - --node_type=resource_manager + - --instance_count=1 + - --instance_type=tpu7x:4x4x4 + env: + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: HOST_ADDRESS + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: TPU_SKIP_MDS_QUERY + value: "true" + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest + imagePullPolicy: Always + name: pathways-rm + ports: + - containerPort: 29001 + protocol: TCP + - containerPort: 29002 + protocol: TCP + resources: + limits: + cpu: "8" + memory: 32G + restartPolicy: Always + - args: + - --server_port=29000 + - --resource_manager_address=$(PATHWAYS_HEAD):29001 + - --gcs_scratch_location=gs://cloud-pathways-staging/tmp + env: + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest + imagePullPolicy: Always + name: pathways-proxy + ports: + - containerPort: 29000 + protocol: TCP + resources: + limits: + cpu: "16" + memory: 100G + restartPolicy: Always + nodeSelector: + cloud.google.com/gke-nodepool: cpu-np + restartPolicy: Never + volumes: + - hostPath: + path: /tmp + type: DirectoryOrCreate + name: shared-tmp + - name: worker + replicas: 1 + template: + metadata: + annotations: + cloud.google.com/gke-tpu-slice-topology: 4x4x4 + spec: + backoffLimit: 32 + completionMode: Indexed + completions: 16 + parallelism: 16 + template: + metadata: + annotations: + cloud.google.com/gke-tpu-slice-topology: 4x4x4 + spec: + tolerations: + - key: "google.com/tpu" + operator: "Equal" + value: "present" + effect: "NoSchedule" + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: cloud.google.com/gke-tpu-partition-4x4x4-state + operator: In + values: + - HEALTHY + - DEGRADED + containers: + - args: + - --server_port=29005 + - --resource_manager_address=$(PATHWAYS_HEAD):29001 + - --gcs_scratch_location=gs://cloud-pathways-staging/tmp + env: + - name: TPU_MIN_LOG_LEVEL + value: "0" + - name: TF_CPP_MIN_LOG_LEVEL + value: "0" + - name: XCLOUD_ENVIRONMENT + value: GCP + - name: MEGASCALE_GRPC_ENABLE_XOR_TRACER + value: "false" + - name: MEGASCALE_NUM_SLICES + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas'] + - name: JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: MEGASCALE_SLICE_ID + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/job-index'] + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: MEGASCALE_COORDINATOR_ADDRESS + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest + imagePullPolicy: Always + name: pathways-worker + ports: + - containerPort: 29005 + protocol: TCP + - containerPort: 29006 + protocol: TCP + - containerPort: 8471 + protocol: TCP + - containerPort: 8080 + protocol: TCP + resources: + limits: + google.com/tpu: "4" + volumeMounts: + - mountPath: /tmp + name: shared-tmp + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu7x + priorityClassName: medium + restartPolicy: OnFailure + terminationGracePeriodSeconds: 30 + volumes: + - hostPath: + path: /tmp + type: DirectoryOrCreate + name: shared-tmp + startupPolicy: + startupPolicyOrder: InOrder + successPolicy: + operator: All + targetReplicatedJobs: + - pathways-head""" + + t = Template(yaml_template) + rendered_yaml = t.render( + metadata_name=metadata_name, + batch_size=batch_size, + rollout_data_parallelism=rollout_data_parallelism, + rollout_tensor_parallelism=rollout_tensor_parallelism, + rollout_expert_parallelism=rollout_expert_parallelism, + trainer_devices_fraction=trainer_devices_fraction, + subslice_shape=subslice_shape, + enable_single_controller=enable_single_controller, + sampler_devices_fraction=sampler_devices_fraction, + base_output_directory=base_output_directory, + run_name=run_name + ) + + return rendered_yaml + +# Example Usage: +if __name__ == "__main__": + # add args for metadat_name, trainer_chips, sampler_chips, rollout_data_parallelism, rollout_tensor_parallelism, rollout_expert_parallelism + + parser = argparse.ArgumentParser() + parser.add_argument("--metadata_name", type=str, required=True) + parser.add_argument("--trainer_chips", type=int, required=True) + parser.add_argument("--number_of_sampler_chips_per_replica", type=int, required=True) + parser.add_argument("--sampler_sharding_per_replica", type=int, required=True) + parser.add_argument("--sampler_replicas", type=int, required=True) + parser.add_argument("--base_output_directory", type=str, required=True) + parser.add_argument("--hf_token", type=str, required=True) + args = parser.parse_args() + + # for v7x-128 + number_of_chips = 64 + batch_size = args.trainer_chips * 2 + trainer_devices_fraction = args.trainer_chips / number_of_chips + rollout_data_parallelism = args.sampler_replicas + sampler_chips = args.number_of_sampler_chips_per_replica * args.sampler_sharding_per_replica + rollout_tensor_parallelism = sampler_chips // batch_size + + result = generate_rl_config( + metadata_name=args.metadata_name, + batch_size=batch_size, + rollout_data_parallelism=args.rollout_data_parallelism, + rollout_tensor_parallelism=args.rollout_tensor_parallelism, + rollout_expert_parallelism=args.rollout_expert_parallelism, + trainer_devices_fraction=0.0625, + subslice_shape="2,2,1", + enable_single_controller="true", + sampler_devices_fraction=0.0625, + base_output_directory=args.base_output_directory, + run_name=args.metadata_name + hf_token=args.hf_token + ) + + with open("qwen3-30b-v7x-temp.yaml", "w") as f: + f.write(result) \ No newline at end of file diff --git a/src/maxtext/trainers/post_train/rl/extract_time.py b/src/maxtext/trainers/post_train/rl/extract_time.py index f36835a0ac..785be3f827 100644 --- a/src/maxtext/trainers/post_train/rl/extract_time.py +++ b/src/maxtext/trainers/post_train/rl/extract_time.py @@ -1,10 +1,11 @@ +import argparse import re import pandas as pd from google.cloud import logging from google.cloud.logging import DESCENDING from datetime import datetime, timedelta, timezone -def get_reshard_data(): +def get_reshard_data(args): client = logging.Client(project="cloud-tpu-multipod-dev") # 1. Define a narrow time window (last 24 hours) @@ -18,7 +19,7 @@ def get_reshard_data(): f'resource.labels.location="us-central1" ' f'resource.labels.cluster_name="zxhe-super-xpk-bid" ' f'resource.labels.namespace_name="default" ' - f'resource.labels.pod_name:"sanbao-rl-0307-2" ' + f'resource.labels.pod_name:"{args.pod_name}" ' f'severity>=DEFAULT ' f'timestamp >= "{start_time}" ' f'SEARCH("Reshard finished in")' @@ -58,10 +59,26 @@ def get_reshard_data(): return None df = pd.DataFrame(results).sort_values("timestamp") - df.to_csv("reshard_times.csv", index=False) - - print(f"Success! Found {len(df)} events.") - print(df.describe()) - return df -df = get_reshard_data() \ No newline at end of file + # Only keep the third - tenth ones and compute the mean of them + # Note: iloc[2:min(df.shape[0], args.max_steps)] gets indices 2 through 9 (8 items), corresponding to 3rd through 10th + selected_df = df.iloc[2:min(df.shape[0], args.max_steps)] + mean_reshard_time = selected_df["reshard_sec"].mean() + + result_df = pd.DataFrame([{"pod_name": args.pod_name, "mean_reshard_time": mean_reshard_time}]) + # If the csv file already exists, append to it instead of overwriting + try: + existing_df = pd.read_csv("reshard_stats.csv") + result_df = pd.concat([existing_df, result_df], ignore_index=True) + except FileNotFoundError: + pass + result_df.to_csv("reshard_stats.csv", index=False) + print(result_df) + return result_df + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pod_name", type=str, required=True, help="Pod name") + parser.add_argument("--max_steps", type=int, default=10, help="Max steps") + args = parser.parse_args() + get_reshard_data(args) \ No newline at end of file diff --git a/src/maxtext/trainers/post_train/rl/reshard_debug.py b/src/maxtext/trainers/post_train/rl/reshard_debug.py index 66e854eee3..d47958fa0c 100644 --- a/src/maxtext/trainers/post_train/rl/reshard_debug.py +++ b/src/maxtext/trainers/post_train/rl/reshard_debug.py @@ -175,8 +175,8 @@ def setup_configs_and_devices(argv: list[str]): else: raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive") - sampler_config.subslice_shape = "" # we are not using subslices in this script, set it to empty to avoid confusion - sampler_config.enable_single_controller = False # we are not using single controller in this script, set it to False to avoid confusion + sampler_config.subslice_shape = config.rollout_subslice_shape + sampler_config.enable_single_controller = config.rollout_enable_single_controller return trainer_config, sampler_config, trainer_devices, sampler_devices @@ -414,7 +414,8 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): jax.tree_util.tree_map(jax.block_until_ready, rl_cluster.rollout._sampler.transformer_state) end_time = time.time() show_hbm_usage(f"HBM after step {step}:") - max_logging.log(f"Resharding via sync_weights() completed: step {step}. Weight Syncing Time taken: {end_time - start_time:.4f}s") + max_logging.log(f"Weight Syncing Time taken: {end_time - start_time:.4f}s") + max_logging.log(f"Resharding via sync_weights() completed: step {step}") def main(argv: Sequence[str]) -> None: From a747922ba2b9b8bcabbcaa3b9b03162c9097dd0d Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Wed, 11 Mar 2026 09:25:42 +0000 Subject: [PATCH 05/10] fix extract_time --- .../trainers/post_train/rl/extract_time.py | 59 +++++++++++++------ 1 file changed, 41 insertions(+), 18 deletions(-) diff --git a/src/maxtext/trainers/post_train/rl/extract_time.py b/src/maxtext/trainers/post_train/rl/extract_time.py index 785be3f827..08d4199597 100644 --- a/src/maxtext/trainers/post_train/rl/extract_time.py +++ b/src/maxtext/trainers/post_train/rl/extract_time.py @@ -8,11 +8,11 @@ def get_reshard_data(args): client = logging.Client(project="cloud-tpu-multipod-dev") - # 1. Define a narrow time window (last 24 hours) + # 1. Define a narrow time window (last 5 days) # This prevents the API from searching the entire history of the project start_time = (datetime.now(timezone.utc) - timedelta(days=5)).strftime('%Y-%m-%dT%H:%M:%SZ') - # 2. Build the exact filter that worked in your UI + # 2. Build the filter to search for both reshard and weight sync times. # We replace SEARCH() with textPayload: which is the API equivalent log_filter = ( f'resource.type="k8s_container" ' @@ -22,16 +22,18 @@ def get_reshard_data(args): f'resource.labels.pod_name:"{args.pod_name}" ' f'severity>=DEFAULT ' f'timestamp >= "{start_time}" ' - f'SEARCH("Reshard finished in")' + f'(SEARCH("Reshard finished in") OR SEARCH("Weight Syncing Time taken:"))' ) - print(f"Querying logs from the last 24 hours (Newest first)...") + print(f"Querying logs from the last 5 days (Newest first)...") # 3. Use order_by=DESCENDING to find recent logs immediately entries = client.list_entries(filter_=log_filter, order_by=DESCENDING) - pattern = r"Reshard finished in (\d+\.?\d*)s" - results = [] + reshard_pattern = r"Reshard finished in (\d+\.?\d*)s" + weight_sync_pattern = r"Weight Syncing Time taken: (\d+\.?\d*)s" + reshard_results = [] + weight_sync_results = [] try: for entry in entries: @@ -42,30 +44,49 @@ def get_reshard_data(args): else: payload_str = str(payload) if payload_str: - match = re.search(pattern, payload_str) - if match: - results.append({ + reshard_match = re.search(reshard_pattern, payload_str) + if reshard_match: + reshard_results.append({ "timestamp": entry.timestamp, - "reshard_sec": float(match.group(1)), + "reshard_sec": float(reshard_match.group(1)), + "pod": entry.resource.labels.get("pod_name") + }) + + weight_sync_match = re.search(weight_sync_pattern, payload_str) + if weight_sync_match: + weight_sync_results.append({ + "timestamp": entry.timestamp, + "weight_sync_sec": float(weight_sync_match.group(1)), "pod": entry.resource.labels.get("pod_name") }) except Exception as e: print(f"Error during API call: {e}") - if not results: + if not reshard_results and not weight_sync_results: print("Still no logs found. Try this final check:") print(f"1. Run: gcloud logging read '{log_filter}' --limit=1") print("2. If that returns nothing, your local gcloud credentials don't have permission for this project.") return None - df = pd.DataFrame(results).sort_values("timestamp") + mean_reshard_time = float('nan') + if reshard_results: + df = pd.DataFrame(reshard_results).sort_values("timestamp") + # Only keep the third - tenth ones and compute the mean of them + # Note: iloc[2:min(df.shape[0], args.max_steps)] gets indices 2 through 9 (8 items), corresponding to 3rd through 10th + selected_df = df.iloc[3:min(df.shape[0], args.max_steps)] + mean_reshard_time = selected_df["reshard_sec"].mean() - # Only keep the third - tenth ones and compute the mean of them - # Note: iloc[2:min(df.shape[0], args.max_steps)] gets indices 2 through 9 (8 items), corresponding to 3rd through 10th - selected_df = df.iloc[2:min(df.shape[0], args.max_steps)] - mean_reshard_time = selected_df["reshard_sec"].mean() + mean_weight_sync_time = float('nan') + if weight_sync_results: + df = pd.DataFrame(weight_sync_results).sort_values("timestamp") + selected_df = df.iloc[3:min(df.shape[0], args.max_steps)] + mean_weight_sync_time = selected_df["weight_sync_sec"].mean() - result_df = pd.DataFrame([{"pod_name": args.pod_name, "mean_reshard_time": mean_reshard_time}]) + result_df = pd.DataFrame([{ + "pod_name": args.pod_name, + "mean_reshard_time": mean_reshard_time, + "mean_weight_sync_time": mean_weight_sync_time + }]) # If the csv file already exists, append to it instead of overwriting try: existing_df = pd.read_csv("reshard_stats.csv") @@ -81,4 +102,6 @@ def get_reshard_data(args): parser.add_argument("--pod_name", type=str, required=True, help="Pod name") parser.add_argument("--max_steps", type=int, default=10, help="Max steps") args = parser.parse_args() - get_reshard_data(args) \ No newline at end of file + get_reshard_data(args) + +# python ./src/maxtext/trainers/post_train/rl/extract_time.py --pod_name sanbao-rl-0310-19 \ No newline at end of file From 7e19218ceaa9e9514fb50535d005563e6bf43f29 Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Fri, 13 Mar 2026 08:34:38 +0000 Subject: [PATCH 06/10] auto pipeline --- .../trainers/post_train/rl/create_yaml.py | 54 +++++++++++------- .../trainers/post_train/rl/extract_time.py | 57 +++++++++++++++++-- 2 files changed, 86 insertions(+), 25 deletions(-) diff --git a/src/maxtext/trainers/post_train/rl/create_yaml.py b/src/maxtext/trainers/post_train/rl/create_yaml.py index bc52a2fdea..3507f4f434 100644 --- a/src/maxtext/trainers/post_train/rl/create_yaml.py +++ b/src/maxtext/trainers/post_train/rl/create_yaml.py @@ -65,12 +65,12 @@ def generate_rl_config( base_output_directory={{ base_output_directory }} \\ hf_access_token={{ hf_token }} \\ batch_size={{ batch_size }} \\ - rl.num_generations={{ batch_size }} \\ - num_batches=4 \\ + rl.num_generations=8 \\ + num_batches=10 \\ rollout_data_parallelism={{ rollout_data_parallelism }} \\ rollout_tensor_parallelism={{ rollout_tensor_parallelism }} \\ rollout_expert_parallelism={{ rollout_expert_parallelism }} \\ - hbm_utilization_vllm=0.4 \\ + hbm_utilization_vllm=0.6 \\ scan_layers=True \\ allow_split_physical_axes=True \\ vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \\ @@ -296,23 +296,22 @@ def generate_rl_config( enable_single_controller=enable_single_controller, sampler_devices_fraction=sampler_devices_fraction, base_output_directory=base_output_directory, - run_name=run_name + run_name=run_name, + hf_token=hf_token ) return rendered_yaml # Example Usage: if __name__ == "__main__": - # add args for metadat_name, trainer_chips, sampler_chips, rollout_data_parallelism, rollout_tensor_parallelism, rollout_expert_parallelism - parser = argparse.ArgumentParser() parser.add_argument("--metadata_name", type=str, required=True) parser.add_argument("--trainer_chips", type=int, required=True) parser.add_argument("--number_of_sampler_chips_per_replica", type=int, required=True) - parser.add_argument("--sampler_sharding_per_replica", type=int, required=True) parser.add_argument("--sampler_replicas", type=int, required=True) parser.add_argument("--base_output_directory", type=str, required=True) parser.add_argument("--hf_token", type=str, required=True) + parser.add_argument("--store_directory", type=str, required=True) args = parser.parse_args() # for v7x-128 @@ -320,23 +319,40 @@ def generate_rl_config( batch_size = args.trainer_chips * 2 trainer_devices_fraction = args.trainer_chips / number_of_chips rollout_data_parallelism = args.sampler_replicas - sampler_chips = args.number_of_sampler_chips_per_replica * args.sampler_sharding_per_replica - rollout_tensor_parallelism = sampler_chips // batch_size + sampler_chips = args.number_of_sampler_chips_per_replica * args.sampler_replicas + assert sampler_chips + args.trainer_chips <= number_of_chips, "Total number of chips used by trainer and sampler must be less than or equal to available chips" + rollout_tensor_parallelism = args.number_of_sampler_chips_per_replica * 2 + rollout_expert_parallelism = rollout_tensor_parallelism // 4 if rollout_tensor_parallelism >= 4 else 1 + assert rollout_tensor_parallelism % rollout_expert_parallelism == 0, "rollout_tensor_parallelism must be divisible by rollout_expert_parallelism" + rollout_tensor_parallelism = 4 if rollout_tensor_parallelism >= 4 else rollout_tensor_parallelism + sampler_devices_fraction = sampler_chips / number_of_chips + if args.trainer_chips == 4: + subslice_shape = "2,2,1" + enable_single_controller = "true" + else: + subslice_shape = "" + enable_single_controller = "false" + + output_directory = os.path.join(args.base_output_directory, args.metadata_name) result = generate_rl_config( metadata_name=args.metadata_name, batch_size=batch_size, - rollout_data_parallelism=args.rollout_data_parallelism, - rollout_tensor_parallelism=args.rollout_tensor_parallelism, - rollout_expert_parallelism=args.rollout_expert_parallelism, - trainer_devices_fraction=0.0625, - subslice_shape="2,2,1", - enable_single_controller="true", - sampler_devices_fraction=0.0625, - base_output_directory=args.base_output_directory, - run_name=args.metadata_name + rollout_data_parallelism=rollout_data_parallelism, + rollout_tensor_parallelism=rollout_tensor_parallelism, + rollout_expert_parallelism=rollout_expert_parallelism, + trainer_devices_fraction=trainer_devices_fraction, + subslice_shape=subslice_shape, + enable_single_controller=enable_single_controller, + sampler_devices_fraction=sampler_devices_fraction, + base_output_directory=output_directory, + run_name=args.metadata_name, hf_token=args.hf_token ) + # if the yaml directory does not exist, create it + if not os.path.exists(args.store_directory): + os.makedirs(args.store_directory) + output_yaml_path = os.path.join(args.store_directory, f"{args.metadata_name}.yaml") - with open("qwen3-30b-v7x-temp.yaml", "w") as f: + with open(output_yaml_path, "w") as f: f.write(result) \ No newline at end of file diff --git a/src/maxtext/trainers/post_train/rl/extract_time.py b/src/maxtext/trainers/post_train/rl/extract_time.py index 08d4199597..0a9c30b316 100644 --- a/src/maxtext/trainers/post_train/rl/extract_time.py +++ b/src/maxtext/trainers/post_train/rl/extract_time.py @@ -1,9 +1,14 @@ import argparse import re +import urllib.parse import pandas as pd from google.cloud import logging from google.cloud.logging import DESCENDING from datetime import datetime, timedelta, timezone +import os + +# Example usage: +# python ./maxtext/src/maxtext/trainers/post_train/rl/extract_time.py --pod_name sanbao-rl-0312-2 def get_reshard_data(args): client = logging.Client(project="cloud-tpu-multipod-dev") @@ -22,7 +27,7 @@ def get_reshard_data(args): f'resource.labels.pod_name:"{args.pod_name}" ' f'severity>=DEFAULT ' f'timestamp >= "{start_time}" ' - f'(SEARCH("Reshard finished in") OR SEARCH("Weight Syncing Time taken:"))' + f'(SEARCH("Reshard finished in") OR SEARCH("Weight Syncing Time taken:") OR (SEARCH("Using") AND SEARCH("GiB on")))' ) print(f"Querying logs from the last 5 days (Newest first)...") @@ -32,8 +37,10 @@ def get_reshard_data(args): reshard_pattern = r"Reshard finished in (\d+\.?\d*)s" weight_sync_pattern = r"Weight Syncing Time taken: (\d+\.?\d*)s" + hbm_pattern = r"Using (\d+\.?\d*) GiB on" reshard_results = [] weight_sync_results = [] + hbm_results = [] try: for entry in entries: @@ -59,10 +66,18 @@ def get_reshard_data(args): "weight_sync_sec": float(weight_sync_match.group(1)), "pod": entry.resource.labels.get("pod_name") }) + + hbm_match = re.search(hbm_pattern, payload_str) + if hbm_match: + hbm_results.append({ + "timestamp": entry.timestamp, + "hbm_gib": float(hbm_match.group(1)), + "pod": entry.resource.labels.get("pod_name") + }) except Exception as e: print(f"Error during API call: {e}") - if not reshard_results and not weight_sync_results: + if not reshard_results and not weight_sync_results and not hbm_results: print("Still no logs found. Try this final check:") print(f"1. Run: gcloud logging read '{log_filter}' --limit=1") print("2. If that returns nothing, your local gcloud credentials don't have permission for this project.") @@ -82,18 +97,48 @@ def get_reshard_data(args): selected_df = df.iloc[3:min(df.shape[0], args.max_steps)] mean_weight_sync_time = selected_df["weight_sync_sec"].mean() + trainer_hbm = float('nan') + sampler_hbm = float('nan') + if hbm_results: + df_hbm = pd.DataFrame(hbm_results).sort_values("timestamp") + if not df_hbm.empty: + trainer_hbm = df_hbm.iloc[0]["hbm_gib"] + sampler_hbm = df_hbm.iloc[-1]["hbm_gib"] + + log_query = ( + f'resource.type="k8s_container" ' + f'resource.labels.project_id="cloud-tpu-multipod-dev" ' + f'resource.labels.location="us-central1" ' + f'resource.labels.cluster_name="zxhe-super-xpk-bid" ' + f'resource.labels.namespace_name="default" ' + f'resource.labels.pod_name:"{args.pod_name}" ' + f'severity>=DEFAULT' + ) + log_link = f"https://console.cloud.google.com/logs/query;query={urllib.parse.quote(log_query)}?project=cloud-tpu-multipod-dev" + result_df = pd.DataFrame([{ "pod_name": args.pod_name, "mean_reshard_time": mean_reshard_time, - "mean_weight_sync_time": mean_weight_sync_time + "mean_weight_sync_time": mean_weight_sync_time, + "trainer_hbm": trainer_hbm, + "sampler_hbm": sampler_hbm, + "log_link": log_link }]) + + # if args.store_directory is not exist, create it + if not os.path.exists(args.store_directory): + os.makedirs(args.store_directory) + output_csv_path = os.path.join(args.store_directory, "reshard_stats.csv") + # If the csv file already exists, append to it instead of overwriting try: - existing_df = pd.read_csv("reshard_stats.csv") + existing_df = pd.read_csv(output_csv_path) result_df = pd.concat([existing_df, result_df], ignore_index=True) except FileNotFoundError: pass - result_df.to_csv("reshard_stats.csv", index=False) + + # Save the results to a CSV file for later analysis + result_df.to_csv(output_csv_path, index=False) print(result_df) return result_df @@ -101,7 +146,7 @@ def get_reshard_data(args): parser = argparse.ArgumentParser() parser.add_argument("--pod_name", type=str, required=True, help="Pod name") parser.add_argument("--max_steps", type=int, default=10, help="Max steps") + parser.add_argument("--store_directory", type=str, required=True) args = parser.parse_args() get_reshard_data(args) -# python ./src/maxtext/trainers/post_train/rl/extract_time.py --pod_name sanbao-rl-0310-19 \ No newline at end of file From 054ffbbc65523338f431530a332b3b44f080f9ee Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Fri, 13 Mar 2026 11:09:28 +0000 Subject: [PATCH 07/10] with 2x2x1 --- .../trainers/post_train/rl/create_yaml.py | 42 ++++++++++++++----- .../trainers/post_train/rl/extract_time.py | 16 ++++--- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/src/maxtext/trainers/post_train/rl/create_yaml.py b/src/maxtext/trainers/post_train/rl/create_yaml.py index 3507f4f434..f266f6f73f 100644 --- a/src/maxtext/trainers/post_train/rl/create_yaml.py +++ b/src/maxtext/trainers/post_train/rl/create_yaml.py @@ -303,6 +303,18 @@ def generate_rl_config( return rendered_yaml # Example Usage: +""" +python ./maxtext/src/maxtext/trainers/post_train/rl/create_yaml.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --hf_token "${hf_token}" \ + --store_directory "${store_path}" \ + --enable_tp "${enable_tp}" +""" + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--metadata_name", type=str, required=True) @@ -312,6 +324,7 @@ def generate_rl_config( parser.add_argument("--base_output_directory", type=str, required=True) parser.add_argument("--hf_token", type=str, required=True) parser.add_argument("--store_directory", type=str, required=True) + parser.add_argument("--enable_tp", type=bool, default=True) args = parser.parse_args() # for v7x-128 @@ -321,17 +334,26 @@ def generate_rl_config( rollout_data_parallelism = args.sampler_replicas sampler_chips = args.number_of_sampler_chips_per_replica * args.sampler_replicas assert sampler_chips + args.trainer_chips <= number_of_chips, "Total number of chips used by trainer and sampler must be less than or equal to available chips" - rollout_tensor_parallelism = args.number_of_sampler_chips_per_replica * 2 - rollout_expert_parallelism = rollout_tensor_parallelism // 4 if rollout_tensor_parallelism >= 4 else 1 - assert rollout_tensor_parallelism % rollout_expert_parallelism == 0, "rollout_tensor_parallelism must be divisible by rollout_expert_parallelism" - rollout_tensor_parallelism = 4 if rollout_tensor_parallelism >= 4 else rollout_tensor_parallelism - sampler_devices_fraction = sampler_chips / number_of_chips - if args.trainer_chips == 4: - subslice_shape = "2,2,1" - enable_single_controller = "true" + if args.enable_tp: + rollout_tensor_parallelism = args.number_of_sampler_chips_per_replica * 2 + rollout_expert_parallelism = rollout_tensor_parallelism // 4 if rollout_tensor_parallelism >= 4 else 1 + assert rollout_tensor_parallelism % rollout_expert_parallelism == 0, "rollout_tensor_parallelism must be divisible by rollout_expert_parallelism" + rollout_tensor_parallelism = 4 if rollout_tensor_parallelism >= 4 else rollout_tensor_parallelism else: - subslice_shape = "" - enable_single_controller = "false" + rollout_tensor_parallelism = 1 + rollout_expert_parallelism = args.number_of_sampler_chips_per_replica * 2 + + sampler_devices_fraction = sampler_chips / number_of_chips + enable_single_controller = "true" + + subslice_shape_status = { + 4: "2,2,1", + 8: "2,2,2", + 16: "2,2,4", + 32: "2,4,4", + 64: "4,4,4", + 128: "4,4,8"} + subslice_shape = subslice_shape_status.get(args.trainer_chips, "") output_directory = os.path.join(args.base_output_directory, args.metadata_name) diff --git a/src/maxtext/trainers/post_train/rl/extract_time.py b/src/maxtext/trainers/post_train/rl/extract_time.py index 0a9c30b316..7dd253b0cf 100644 --- a/src/maxtext/trainers/post_train/rl/extract_time.py +++ b/src/maxtext/trainers/post_train/rl/extract_time.py @@ -7,9 +7,6 @@ from datetime import datetime, timedelta, timezone import os -# Example usage: -# python ./maxtext/src/maxtext/trainers/post_train/rl/extract_time.py --pod_name sanbao-rl-0312-2 - def get_reshard_data(args): client = logging.Client(project="cloud-tpu-multipod-dev") @@ -81,7 +78,6 @@ def get_reshard_data(args): print("Still no logs found. Try this final check:") print(f"1. Run: gcloud logging read '{log_filter}' --limit=1") print("2. If that returns nothing, your local gcloud credentials don't have permission for this project.") - return None mean_reshard_time = float('nan') if reshard_results: @@ -125,10 +121,7 @@ def get_reshard_data(args): "log_link": log_link }]) - # if args.store_directory is not exist, create it - if not os.path.exists(args.store_directory): - os.makedirs(args.store_directory) - output_csv_path = os.path.join(args.store_directory, "reshard_stats.csv") + output_csv_path = args.store_cvs_file # If the csv file already exists, append to it instead of overwriting try: @@ -142,11 +135,16 @@ def get_reshard_data(args): print(result_df) return result_df +# Example usage: +""" +python ./maxtext/src/maxtext/trainers/post_train/rl/extract_time.py --pod_name sanbao-rl-0312-2 +""" + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--pod_name", type=str, required=True, help="Pod name") parser.add_argument("--max_steps", type=int, default=10, help="Max steps") - parser.add_argument("--store_directory", type=str, required=True) + parser.add_argument("--store_cvs_file", type=str, required=True) args = parser.parse_args() get_reshard_data(args) From b511479dff915dc85b160c688e09b8fa2b3c79ad Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Fri, 13 Mar 2026 11:28:53 +0000 Subject: [PATCH 08/10] fix --- src/maxtext/trainers/post_train/rl/create_yaml.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/maxtext/trainers/post_train/rl/create_yaml.py b/src/maxtext/trainers/post_train/rl/create_yaml.py index f266f6f73f..24fa3ad8be 100644 --- a/src/maxtext/trainers/post_train/rl/create_yaml.py +++ b/src/maxtext/trainers/post_train/rl/create_yaml.py @@ -344,7 +344,10 @@ def generate_rl_config( rollout_expert_parallelism = args.number_of_sampler_chips_per_replica * 2 sampler_devices_fraction = sampler_chips / number_of_chips - enable_single_controller = "true" + if args.trainer_chips <= 4: + enable_single_controller = "true" + else: + enable_single_controller = "false" subslice_shape_status = { 4: "2,2,1", From ad989ad23ce5619a37d3a4033dd1207c74f9873f Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Fri, 13 Mar 2026 17:01:31 +0000 Subject: [PATCH 09/10] fix create_yaml --- src/maxtext/trainers/post_train/rl/create_yaml.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/maxtext/trainers/post_train/rl/create_yaml.py b/src/maxtext/trainers/post_train/rl/create_yaml.py index 24fa3ad8be..8610598f75 100644 --- a/src/maxtext/trainers/post_train/rl/create_yaml.py +++ b/src/maxtext/trainers/post_train/rl/create_yaml.py @@ -344,12 +344,14 @@ def generate_rl_config( rollout_expert_parallelism = args.number_of_sampler_chips_per_replica * 2 sampler_devices_fraction = sampler_chips / number_of_chips - if args.trainer_chips <= 4: + if args.trainer_chips == 4: enable_single_controller = "true" else: enable_single_controller = "false" subslice_shape_status = { + 1: "1,1,1", + 2: "2,1,1", 4: "2,2,1", 8: "2,2,2", 16: "2,2,4", From 052b38e522880614d6af9681ad6c73be70a67dff Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Fri, 13 Mar 2026 17:17:17 +0000 Subject: [PATCH 10/10] reshard_auto --- .../trainers/post_train/rl/reshard_auto.sh | 165 ++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 src/maxtext/trainers/post_train/rl/reshard_auto.sh diff --git a/src/maxtext/trainers/post_train/rl/reshard_auto.sh b/src/maxtext/trainers/post_train/rl/reshard_auto.sh new file mode 100644 index 0000000000..43da7e421f --- /dev/null +++ b/src/maxtext/trainers/post_train/rl/reshard_auto.sh @@ -0,0 +1,165 @@ +#!/bin/bash + +# Define your configurations: "trainer_chips:number_of_sampler_chips_per_replica" +configs=( + "1:1" + "4:1" + "4:2" + "4:4" + "8:1" + "8:2" + "8:4" + "8:8" + "16:1" + "16:2" + "16:4" + "16:8" + "16:16" + "32:2" + "32:4" + "32:8" + "32:16" + "32:32" +) + +# Global variables +base_output_directory="gs://sanbao-bucket/mlperf_rl/reshard" +store_path="./reshard" +project="cloud-tpu-multipod-dev" +zone="us-central1" +cluster="zxhe-super-xpk-bid" +store_cvs_file="${store_path}/reshard_stats_tp.csv" + +mkdir -p ${store_path} + +# Function to handle errors and ensure cleanup +handle_error() { + echo "Error occurred during config ${workload_name}. Cleaning up..." + python ~/xpk/xpk.py workload delete --workload "${workload_name}" --cluster "${cluster}" --project "${project}" --zone "${zone}" + # Continue to next iteration rather than exiting the whole script +} + +for config in "${configs[@]}"; do + # Split the config string into variables + IFS=":" read -r trainer_chips sampler_chips <<< "$config" + + # Generate a unique workload name based on config and date + # timestamp=$(date +"%m-%d") + timestamp="tp0313" + workload_name="sanbao-${trainer_chips}-${sampler_chips}-${timestamp}" + + echo "----------------------------------------------------------" + echo "Running Config: Trainer=${trainer_chips}, Sampler=${sampler_chips}" + echo "Workload Name: ${workload_name}" + echo "----------------------------------------------------------" + + # Trap errors specifically for this iteration + trap 'handle_error' ERR + + # 1. Create the YAML + python ./maxtext/src/maxtext/trainers/post_train/rl/create_yaml.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --HF_TOKEN "${HF_TOKEN}" \ + --store_directory "${store_path}" + + # 2. Apply Kubernetes YAML + echo "Applying Kubernetes YAML..." + kubectl apply -f "${store_path}/${workload_name}.yaml" + + # 3. Wait for workload to run + echo "Waiting 10 minutes for workload execution..." + sleep 600 + + # 5. Extract Timing Data + echo "Extracting timing data..." + python ./maxtext/src/maxtext/trainers/post_train/rl/extract_time.py \ + --pod_name "${workload_name}" \ + --store_cvs_file "${store_cvs_file}" + + echo "Finished: ${workload_name}. Data in ${store_cvs_file}" + + # Small buffer before starting the next config + sleep 120 + + # 4. Cleanup Workload + echo "Deleting workload..." + python ~/xpk/xpk.py workload delete --workload "${workload_name}" --cluster "${cluster}" --project "${project}" --zone "${zone}" + + # Clear trap for next iteration + trap - ERR +done + +echo "All configurations completed." + + +store_cvs_file="${store_path}/reshard_stats_ep.csv" + +mkdir -p ${store_path} + +# Function to handle errors and ensure cleanup +handle_error() { + echo "Error occurred during config ${workload_name}. Cleaning up..." + python ~/xpk/xpk.py workload delete --workload "${workload_name}" --cluster "${cluster}" --project "${project}" --zone "${zone}" + # Continue to next iteration rather than exiting the whole script +} + +for config in "${configs[@]}"; do + # Split the config string into variables + IFS=":" read -r trainer_chips sampler_chips <<< "$config" + + # Generate a unique workload name based on config and date + # timestamp=$(date +"%m-%d") + timestamp="ep0313" + workload_name="sanbao-${trainer_chips}-${sampler_chips}-${timestamp}" + + echo "----------------------------------------------------------" + echo "Running Config: Trainer=${trainer_chips}, Sampler=${sampler_chips}" + echo "Workload Name: ${workload_name}" + echo "----------------------------------------------------------" + + # Trap errors specifically for this iteration + trap 'handle_error' ERR + + # 1. Create the YAML + python ./maxtext/src/maxtext/trainers/post_train/rl/create_yaml.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --HF_TOKEN "${HF_TOKEN}" \ + --store_directory "${store_path}" \ + --enable_tp False + + # 2. Apply Kubernetes YAML + echo "Applying Kubernetes YAML..." + kubectl apply -f "${store_path}/${workload_name}.yaml" + + # 3. Wait for workload to run + echo "Waiting 10 minutes for workload execution..." + sleep 600 + + # 5. Extract Timing Data + echo "Extracting timing data..." + python ./maxtext/src/maxtext/trainers/post_train/rl/extract_time.py \ + --pod_name "${workload_name}" \ + --store_cvs_file "${store_cvs_file}" + + echo "Finished: ${workload_name}. Data in ${store_cvs_file}" + + # Small buffer before starting the next config + sleep 120 + + # 4. Cleanup Workload + echo "Deleting workload..." + python ~/xpk/xpk.py workload delete --workload "${workload_name}" --cluster "${cluster}" --project "${project}" --zone "${zone}" + + # Clear trap for next iteration + trap - ERR +done + +echo "All configurations completed." \ No newline at end of file