diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index 7d3659ae59..19022ef23e 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -28,6 +28,8 @@ num_samplers_slices: -1 rollout_data_parallelism: -1 rollout_tensor_parallelism: -1 rollout_expert_parallelism: 1 +rollout_subslice_shape: "" # e.g. '2,2,1' for 4 chips with DP=2, TP=2, EP=1 +rollout_enable_single_controller: False # If True, use a single controller for rollout. This can help with stability when using more than 1 model replica in rollout. # ====== Reproducibility ====== data_shuffle_seed: 42 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 0bb6d49701..08ab152355 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1586,6 +1586,11 @@ class RLHardware(BaseModel): description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined.", ) rollout_expert_parallelism: int = Field(1, description="Expert parallelism per replica for rollout") + rollout_subslice_shape: str = Field("", description="Subslice shape for rollout in the form of 'x,y,z' for Pathways.") + rollout_enable_single_controller: bool = Field( + False, + description="Whether to enable single-controller mode for rollout. If True, the trainer will also run the rollout and sampling computations instead of launching separate processes. This is only recommended for debugging or if the rollout computation is very small and can be efficiently handled by a single controller.", + ) class VLLM(BaseModel): diff --git a/src/maxtext/trainers/post_train/rl/create_yaml.py b/src/maxtext/trainers/post_train/rl/create_yaml.py new file mode 100644 index 0000000000..bc52a2fdea --- /dev/null +++ b/src/maxtext/trainers/post_train/rl/create_yaml.py @@ -0,0 +1,342 @@ +import os +from jinja2 import Template +import argparse + +def generate_rl_config( + metadata_name, + batch_size, + rollout_data_parallelism, + rollout_tensor_parallelism, + rollout_expert_parallelism, + trainer_devices_fraction, + subslice_shape, + enable_single_controller, + sampler_devices_fraction, + base_output_directory, + run_name, + hf_token +): + yaml_template = """apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + labels: + kueue.x-k8s.io/queue-name: multislice-queue + name: {{ metadata_name }} + namespace: default +spec: + coordinator: + replicatedJob: pathways-head + failurePolicy: + maxRestarts: 1 + restartStrategy: Recreate + network: + enableDNSHostnames: true + publishNotReadyAddresses: true + replicatedJobs: + - name: pathways-head + replicas: 1 + template: + metadata: + annotations: + kueue.x-k8s.io/safe-to-forcefully-terminate: "true" + spec: + backoffLimit: 0 + completionMode: Indexed + completions: 1 + parallelism: 1 + template: + spec: + containers: + - command: + - bash + - -c + - | + echo XPK Start: $(date); + _sigterm() (kill -SIGTERM $! 2>/dev/null;); + trap _sigterm SIGTERM; + + (pip install --no-deps git+https://github.com/AI-Hypercomputer/pathways-utils.git@v0.1.4 && \\ + pip install src/maxtext/integration/vllm && \\ + HF_TOKEN={{ hf_token }} JAX_RANDOM_WEIGHTS=1 VLLM_ENABLE_V1_MULTIPROCESSING=0 NEW_MODEL_DESIGN=1 TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0 JAX_PLATFORMS=proxy,cpu JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 \\ + python3 -m src.maxtext.trainers.post_train.rl.reshard_debug src/maxtext/configs/post_train/rl.yml \\ + model_name=qwen3-30b-a3b \\ + tokenizer_path=Qwen/Qwen3-30B-A3B \\ + run_name={{ run_name }} \\ + base_output_directory={{ base_output_directory }} \\ + hf_access_token={{ hf_token }} \\ + batch_size={{ batch_size }} \\ + rl.num_generations={{ batch_size }} \\ + num_batches=4 \\ + rollout_data_parallelism={{ rollout_data_parallelism }} \\ + rollout_tensor_parallelism={{ rollout_tensor_parallelism }} \\ + rollout_expert_parallelism={{ rollout_expert_parallelism }} \\ + hbm_utilization_vllm=0.4 \\ + scan_layers=True \\ + allow_split_physical_axes=True \\ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \\ + vllm_additional_config='{maxtext_config: {model_name: qwen3-30b-a3b, allow_split_physical_axes: true, log_config: false, weight_dtype: bfloat16}}' \\ + trainer_devices_fraction={{ trainer_devices_fraction }} \\ + subslice_shape='{{ subslice_shape }}' \\ + enable_single_controller={{ enable_single_controller }} \\ + sampler_devices_fraction={{ sampler_devices_fraction }}) & PID=$!; + + while kill -0 $PID 2>/dev/null; + do sleep 5; + done; + wait $PID; + EXIT_CODE=$?; + + echo XPK End: $(date); + echo EXIT_CODE=$EXIT_CODE; + + exit $EXIT_CODE + env: + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: JAX_PLATFORMS + value: proxy + - name: XCLOUD_ENVIRONMENT + value: GCP + - name: JAX_BACKEND_TARGET + value: grpc://$(PATHWAYS_HEAD):29000 + image: gcr.io/cloud-tpu-multipod-dev/sanbao/maxtext_reshard_image:latest + imagePullPolicy: Always + name: jax-tpu + resources: + limits: + cpu: "24" + memory: 100G + securityContext: + privileged: true + volumeMounts: + - mountPath: /tmp + name: shared-tmp + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + initContainers: + - args: + - --server_port=29001 + - --gcs_scratch_location=gs://cloud-pathways-staging/tmp + - --node_type=resource_manager + - --instance_count=1 + - --instance_type=tpu7x:4x4x4 + env: + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: HOST_ADDRESS + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: TPU_SKIP_MDS_QUERY + value: "true" + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest + imagePullPolicy: Always + name: pathways-rm + ports: + - containerPort: 29001 + protocol: TCP + - containerPort: 29002 + protocol: TCP + resources: + limits: + cpu: "8" + memory: 32G + restartPolicy: Always + - args: + - --server_port=29000 + - --resource_manager_address=$(PATHWAYS_HEAD):29001 + - --gcs_scratch_location=gs://cloud-pathways-staging/tmp + env: + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest + imagePullPolicy: Always + name: pathways-proxy + ports: + - containerPort: 29000 + protocol: TCP + resources: + limits: + cpu: "16" + memory: 100G + restartPolicy: Always + nodeSelector: + cloud.google.com/gke-nodepool: cpu-np + restartPolicy: Never + volumes: + - hostPath: + path: /tmp + type: DirectoryOrCreate + name: shared-tmp + - name: worker + replicas: 1 + template: + metadata: + annotations: + cloud.google.com/gke-tpu-slice-topology: 4x4x4 + spec: + backoffLimit: 32 + completionMode: Indexed + completions: 16 + parallelism: 16 + template: + metadata: + annotations: + cloud.google.com/gke-tpu-slice-topology: 4x4x4 + spec: + tolerations: + - key: "google.com/tpu" + operator: "Equal" + value: "present" + effect: "NoSchedule" + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: cloud.google.com/gke-tpu-partition-4x4x4-state + operator: In + values: + - HEALTHY + - DEGRADED + containers: + - args: + - --server_port=29005 + - --resource_manager_address=$(PATHWAYS_HEAD):29001 + - --gcs_scratch_location=gs://cloud-pathways-staging/tmp + env: + - name: TPU_MIN_LOG_LEVEL + value: "0" + - name: TF_CPP_MIN_LOG_LEVEL + value: "0" + - name: XCLOUD_ENVIRONMENT + value: GCP + - name: MEGASCALE_GRPC_ENABLE_XOR_TRACER + value: "false" + - name: MEGASCALE_NUM_SLICES + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas'] + - name: JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: MEGASCALE_SLICE_ID + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/job-index'] + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: MEGASCALE_COORDINATOR_ADDRESS + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest + imagePullPolicy: Always + name: pathways-worker + ports: + - containerPort: 29005 + protocol: TCP + - containerPort: 29006 + protocol: TCP + - containerPort: 8471 + protocol: TCP + - containerPort: 8080 + protocol: TCP + resources: + limits: + google.com/tpu: "4" + volumeMounts: + - mountPath: /tmp + name: shared-tmp + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu7x + priorityClassName: medium + restartPolicy: OnFailure + terminationGracePeriodSeconds: 30 + volumes: + - hostPath: + path: /tmp + type: DirectoryOrCreate + name: shared-tmp + startupPolicy: + startupPolicyOrder: InOrder + successPolicy: + operator: All + targetReplicatedJobs: + - pathways-head""" + + t = Template(yaml_template) + rendered_yaml = t.render( + metadata_name=metadata_name, + batch_size=batch_size, + rollout_data_parallelism=rollout_data_parallelism, + rollout_tensor_parallelism=rollout_tensor_parallelism, + rollout_expert_parallelism=rollout_expert_parallelism, + trainer_devices_fraction=trainer_devices_fraction, + subslice_shape=subslice_shape, + enable_single_controller=enable_single_controller, + sampler_devices_fraction=sampler_devices_fraction, + base_output_directory=base_output_directory, + run_name=run_name + ) + + return rendered_yaml + +# Example Usage: +if __name__ == "__main__": + # add args for metadat_name, trainer_chips, sampler_chips, rollout_data_parallelism, rollout_tensor_parallelism, rollout_expert_parallelism + + parser = argparse.ArgumentParser() + parser.add_argument("--metadata_name", type=str, required=True) + parser.add_argument("--trainer_chips", type=int, required=True) + parser.add_argument("--number_of_sampler_chips_per_replica", type=int, required=True) + parser.add_argument("--sampler_sharding_per_replica", type=int, required=True) + parser.add_argument("--sampler_replicas", type=int, required=True) + parser.add_argument("--base_output_directory", type=str, required=True) + parser.add_argument("--hf_token", type=str, required=True) + args = parser.parse_args() + + # for v7x-128 + number_of_chips = 64 + batch_size = args.trainer_chips * 2 + trainer_devices_fraction = args.trainer_chips / number_of_chips + rollout_data_parallelism = args.sampler_replicas + sampler_chips = args.number_of_sampler_chips_per_replica * args.sampler_sharding_per_replica + rollout_tensor_parallelism = sampler_chips // batch_size + + result = generate_rl_config( + metadata_name=args.metadata_name, + batch_size=batch_size, + rollout_data_parallelism=args.rollout_data_parallelism, + rollout_tensor_parallelism=args.rollout_tensor_parallelism, + rollout_expert_parallelism=args.rollout_expert_parallelism, + trainer_devices_fraction=0.0625, + subslice_shape="2,2,1", + enable_single_controller="true", + sampler_devices_fraction=0.0625, + base_output_directory=args.base_output_directory, + run_name=args.metadata_name + hf_token=args.hf_token + ) + + with open("qwen3-30b-v7x-temp.yaml", "w") as f: + f.write(result) \ No newline at end of file diff --git a/src/maxtext/trainers/post_train/rl/extract_time.py b/src/maxtext/trainers/post_train/rl/extract_time.py new file mode 100644 index 0000000000..08d4199597 --- /dev/null +++ b/src/maxtext/trainers/post_train/rl/extract_time.py @@ -0,0 +1,107 @@ +import argparse +import re +import pandas as pd +from google.cloud import logging +from google.cloud.logging import DESCENDING +from datetime import datetime, timedelta, timezone + +def get_reshard_data(args): + client = logging.Client(project="cloud-tpu-multipod-dev") + + # 1. Define a narrow time window (last 5 days) + # This prevents the API from searching the entire history of the project + start_time = (datetime.now(timezone.utc) - timedelta(days=5)).strftime('%Y-%m-%dT%H:%M:%SZ') + + # 2. Build the filter to search for both reshard and weight sync times. + # We replace SEARCH() with textPayload: which is the API equivalent + log_filter = ( + f'resource.type="k8s_container" ' + f'resource.labels.location="us-central1" ' + f'resource.labels.cluster_name="zxhe-super-xpk-bid" ' + f'resource.labels.namespace_name="default" ' + f'resource.labels.pod_name:"{args.pod_name}" ' + f'severity>=DEFAULT ' + f'timestamp >= "{start_time}" ' + f'(SEARCH("Reshard finished in") OR SEARCH("Weight Syncing Time taken:"))' + ) + + print(f"Querying logs from the last 5 days (Newest first)...") + + # 3. Use order_by=DESCENDING to find recent logs immediately + entries = client.list_entries(filter_=log_filter, order_by=DESCENDING) + + reshard_pattern = r"Reshard finished in (\d+\.?\d*)s" + weight_sync_pattern = r"Weight Syncing Time taken: (\d+\.?\d*)s" + reshard_results = [] + weight_sync_results = [] + + try: + for entry in entries: + payload = entry.payload + payload_str = None + if isinstance(payload, dict): + payload_str = payload.get("message") or str(payload) + else: + payload_str = str(payload) + if payload_str: + reshard_match = re.search(reshard_pattern, payload_str) + if reshard_match: + reshard_results.append({ + "timestamp": entry.timestamp, + "reshard_sec": float(reshard_match.group(1)), + "pod": entry.resource.labels.get("pod_name") + }) + + weight_sync_match = re.search(weight_sync_pattern, payload_str) + if weight_sync_match: + weight_sync_results.append({ + "timestamp": entry.timestamp, + "weight_sync_sec": float(weight_sync_match.group(1)), + "pod": entry.resource.labels.get("pod_name") + }) + except Exception as e: + print(f"Error during API call: {e}") + + if not reshard_results and not weight_sync_results: + print("Still no logs found. Try this final check:") + print(f"1. Run: gcloud logging read '{log_filter}' --limit=1") + print("2. If that returns nothing, your local gcloud credentials don't have permission for this project.") + return None + + mean_reshard_time = float('nan') + if reshard_results: + df = pd.DataFrame(reshard_results).sort_values("timestamp") + # Only keep the third - tenth ones and compute the mean of them + # Note: iloc[2:min(df.shape[0], args.max_steps)] gets indices 2 through 9 (8 items), corresponding to 3rd through 10th + selected_df = df.iloc[3:min(df.shape[0], args.max_steps)] + mean_reshard_time = selected_df["reshard_sec"].mean() + + mean_weight_sync_time = float('nan') + if weight_sync_results: + df = pd.DataFrame(weight_sync_results).sort_values("timestamp") + selected_df = df.iloc[3:min(df.shape[0], args.max_steps)] + mean_weight_sync_time = selected_df["weight_sync_sec"].mean() + + result_df = pd.DataFrame([{ + "pod_name": args.pod_name, + "mean_reshard_time": mean_reshard_time, + "mean_weight_sync_time": mean_weight_sync_time + }]) + # If the csv file already exists, append to it instead of overwriting + try: + existing_df = pd.read_csv("reshard_stats.csv") + result_df = pd.concat([existing_df, result_df], ignore_index=True) + except FileNotFoundError: + pass + result_df.to_csv("reshard_stats.csv", index=False) + print(result_df) + return result_df + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pod_name", type=str, required=True, help="Pod name") + parser.add_argument("--max_steps", type=int, default=10, help="Max steps") + args = parser.parse_args() + get_reshard_data(args) + +# python ./src/maxtext/trainers/post_train/rl/extract_time.py --pod_name sanbao-rl-0310-19 \ No newline at end of file diff --git a/src/maxtext/trainers/post_train/rl/reshard_debug.py b/src/maxtext/trainers/post_train/rl/reshard_debug.py new file mode 100644 index 0000000000..d47958fa0c --- /dev/null +++ b/src/maxtext/trainers/post_train/rl/reshard_debug.py @@ -0,0 +1,436 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Resharding Benchmark for the RL Trainer + +This module provides a unified `rl_train` function that consolidates the common +RL training logic. It handles model loading, reward function setup, dataset +processing, and training orchestration. By default, we run Group Relative Policy Optimization (GRPO) on +GSM8K math reasoning benchmark. The script is also flexible enough to run Group Sequence Policy Optimization (GSPO). + +Usage Examples: + +# GRPO on Qwen3-30B +python3 -m src.maxtext.trainers.post_train.rl.reshard_debug src/maxtext/configs/post_train/rl.yml \ + model_name=qwen3-30b-a3b \ + tokenizer_path=Qwen/Qwen3-30B-A3B \ + run_name=sanbao-rl-0310-1 \ + base_output_directory=gs://sanbao-bucket/mlperf_rl/qwen3/sanbao-rl-0310-1 \ + batch_size=16 \ + rl.num_generations=8 \ + num_batches=4 \ + rollout_data_parallelism=4 \ + rollout_tensor_parallelism=1 \ + rollout_expert_parallelism=4 \ + hbm_utilization_vllm=0.4 \ + scan_layers=True \ + allow_split_physical_axes=True \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + vllm_additional_config='{maxtext_config: {model_name: qwen3-30b-a3b, allow_split_physical_axes: true, log_config: false, weight_dtype: bfloat16}}' + +""" + +from __future__ import annotations +from typing import Sequence + +import collections +import jax +import json +import time +import logging +import os +import pathwaysutils + +from absl import app +from absl import logging as absl_logging +from flax import nnx +from jax.sharding import Mesh +from orbax import checkpoint as ocp +from transformers import AutoTokenizer +from tunix.rl import rl_cluster as rl_cluster_lib +from tunix.rl.rollout import base_rollout +from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner +from tunix.sft import metrics_logger, profiler +from tunix.sft.utils import show_hbm_usage + +# for vLLM we can skip JAX precompilation with this flag, it makes startup faster +os.environ["SKIP_JAX_PRECOMPILE"] = "1" + +from maxtext.configs import pyconfig +from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR +from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter +from maxtext.trainers.post_train.rl import utils_rl +from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils + + +def get_maxtext_model(config, devices=None): + """ + Load MaxText model with Tunix adapter. + # Note: pass the path to your scanned checkpoint for 'load_parameters_path'. + # To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if + # using Pathways, please set `checkpoint_storage_use_ocdbt=False checkpoint_storage_use_zarr3=False` + # python src/MaxText/checkpoint_conversion/to_maxtext.py \ + # --model_name="gemma2-2b" \ + # --base_output_directory="/path/to/your/output/directory" \ + # --scan_layers=True \ + # --checkpoint_storage_use_ocdbt=False\ + # checkpoint_storage_use_zarr3=False + # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e., + # load_parameters_path=/path/to/your/output/directory/0/items + """ + model, mesh = model_creation_utils.create_nnx_model(config, devices=devices) + with mesh: + use_no_op_mappings = "maxtext_config" in config.vllm_additional_config + tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings) + tunix_model.config = None + return tunix_model, mesh + + +def setup_configs_and_devices(argv: list[str]): + """Setup device allocation and configs for training and inference.""" + config = pyconfig.initialize_pydantic(argv) + devices = jax.devices() + if config.num_trainer_slices == -1 and config.num_samplers_slices == -1: + max_logging.log("Running RL on a single slice") + num_vms = len(devices) // config.chips_per_vm + trainer_devices = devices + sampler_devices = devices + if num_vms >= 2 and config.use_pathways: + # Multiple hosts with Pathways - potentially split devices for trainer and sampler + # based on trainer_devices_fraction and sampler_devices_fraction + max_logging.log(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.") + num_devices = len(devices) + num_trainer_devices = int(num_devices * config.trainer_devices_fraction) + num_sampler_devices = int(num_devices * config.sampler_devices_fraction) + trainer_devices = devices[:num_trainer_devices] + sampler_devices = devices[num_devices - num_sampler_devices :] + if config.trainer_devices_fraction != 1.0: + max_logging.log(f"Using first {len(trainer_devices)} devices as Trainer devices") + if config.sampler_devices_fraction != 1.0: + max_logging.log(f"Using last {len(sampler_devices)} devices as Sampler devices") + trainer_config = config.model_copy() + sampler_config = config.model_copy() + elif config.num_trainer_slices > 0 and config.num_samplers_slices > 0: + max_logging.log("Running RL with Multislice") + devices_by_slice = collections.defaultdict(list) + for d in devices: + devices_by_slice[d.slice_index].append(d) + slice_indices = sorted(devices_by_slice.keys()) + + if len(slice_indices) < config.num_trainer_slices + config.num_samplers_slices: + raise ValueError("Not enough slices for trainer and samplers") + + trainer_devices = [] + for i in range(config.num_trainer_slices): + trainer_devices.extend(devices_by_slice[slice_indices[i]]) + + sampler_devices = [] + for i in range(config.num_trainer_slices, config.num_trainer_slices + config.num_samplers_slices): + sampler_devices.extend(devices_by_slice[slice_indices[i]]) + + trainer_devices_per_slice = len(trainer_devices) // config.num_trainer_slices + trainer_fsdp = trainer_devices_per_slice + tp = config.ici_tensor_parallelism + if tp > 1: + if trainer_devices_per_slice % tp != 0: + raise ValueError( + f"trainer_devices_per_slice ({trainer_devices_per_slice}) must be divisible by tensor parallelism ({tp})" + ) + if config.ici_fsdp_parallelism != -1 and config.ici_fsdp_parallelism * tp != trainer_devices_per_slice: + raise ValueError( + f"ici_fsdp_parallelism ({config.ici_fsdp_parallelism}) * ici_tensor_parallelism ({tp}) must equal " + f"devices_per_slice ({trainer_devices_per_slice})" + ) + trainer_fsdp = trainer_devices_per_slice // tp + + trainer_update = { + "num_slices": config.num_trainer_slices, + "ici_fsdp_parallelism": trainer_fsdp, + "ici_tensor_parallelism": tp, + "dcn_data_parallelism": config.num_trainer_slices, + } + + sampler_update = { + "num_slices": config.num_samplers_slices, + "ici_fsdp_parallelism": len(sampler_devices) // config.num_samplers_slices, + "ici_tensor_parallelism": -1, + "dcn_data_parallelism": config.num_samplers_slices, + } + + trainer_config = pyconfig.initialize_pydantic(argv, **trainer_update) + sampler_config = pyconfig.initialize_pydantic(argv, **sampler_update) + + else: + raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive") + + sampler_config.subslice_shape = config.rollout_subslice_shape + sampler_config.enable_single_controller = config.rollout_enable_single_controller + return trainer_config, sampler_config, trainer_devices, sampler_devices + + +def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices): + """Get rollout kwargs for vLLM rollout when using data parallelism.""" + dp = sampler_config.rollout_data_parallelism + tp = sampler_config.rollout_tensor_parallelism + ep = sampler_config.rollout_expert_parallelism + + # -1 means "auto-derive from the other two". At most one can be -1. + num_auto = sum(1 for x in [tp, dp, ep] if x == -1) + if num_auto > 1: + raise ValueError( + "At most one of rollout_tensor_parallelism, rollout_data_parallelism, " + "rollout_expert_parallelism can be -1 (auto-derived)." + ) + + if dp == -1: + if num_sampler_devices % (tp * ep) != 0: + raise ValueError( + f"num_sampler_devices({num_sampler_devices}) must be divisible by " + f"rollout_tensor_parallelism({tp}) * rollout_expert_parallelism({ep}) " + f"when rollout_data_parallelism is -1." + ) + dp = num_sampler_devices // tp // ep + elif tp == -1: + if num_sampler_devices % (dp * ep) != 0: + raise ValueError( + f"num_sampler_devices({num_sampler_devices}) must be divisible by " + f"rollout_data_parallelism({dp}) * rollout_expert_parallelism({ep}) " + f"when rollout_tensor_parallelism is -1." + ) + tp = num_sampler_devices // dp // ep + elif ep == -1: + if num_sampler_devices % (tp * dp) != 0: + raise ValueError( + f"num_sampler_devices({num_sampler_devices}) must be divisible by " + f"rollout_tensor_parallelism({tp}) * rollout_data_parallelism({dp}) " + f"when rollout_expert_parallelism is -1." + ) + ep = num_sampler_devices // tp // dp + elif tp * dp * ep != num_sampler_devices: + raise ValueError( + f"rollout_tensor_parallelism({tp}) * " + f"rollout_data_parallelism({dp}) * " + f"rollout_expert_parallelism({ep}) " + f"!= len(sampler_devices)({num_sampler_devices})" + ) + + rollout_kwargs = {} + rollout_kwargs["tensor_parallel_size"] = tp + rollout_kwargs["data_parallel_size"] = dp + rollout_kwargs["expert_parallel_size"] = ep + + return rollout_kwargs + + +def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): + """ + Run RL training with the provided configuration. + + Args: + trainer_config: MaxText configuration for the trainer. + sampler_config: MaxText configuration for the sampler. + trainer_devices: JAX devices for the trainer. + sampler_devices: JAX devices for the sampler. + """ + if not trainer_config.debug.rl: + # Apply filter to suppress noisy logs + noise_filter = max_logging.NoisyLogFilter() + logging.getLogger().addFilter(noise_filter) + absl_logging.get_absl_logger().addFilter(noise_filter) + + max_logging.log("Starting RL Resharding Debug Script") + + # Number of training steps. + max_train_steps = 1 + + # Create model tokenizer + model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path) + + # Load reference model + max_logging.log("Creating reference model and also meshes for reference and rollout") + reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices) + devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices) + # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh + # else rollout_mesh uses sampler_devices + rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes) + if trainer_config.debug.rl: + max_logging.log("Reference Model initialized successfully") + nnx.display(reference_model) + max_logging.log(f"Reference mesh shape: {reference_mesh.shape}") + + # Sanity check that weights are loaded correctly. + _maxtext_state_flatten = nnx.state(reference_model).flat_state() + maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten} + max_logging.log( + f"maxtext_state_flatten[base.token_embedder.embedding].value=\ + {maxtext_state_flatten['base.token_embedder.embedding'][...]}" + ) + + # TODO: @mazumdera: change this to use lora + if trainer_config.load_checkpoint_only_once: + max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.") + with reference_mesh: + actor_base_model = nnx.clone(reference_model.base) + use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config + actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings) + actor_model.config = None + actor_mesh = reference_mesh + else: + max_logging.log("Creating policy model with same config as reference model on trainer mesh") + actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices) + + if trainer_config.debug.rl: + max_logging.log("Policy Model initialized successfully") + nnx.display(actor_model) + max_logging.log(f"Policy mesh shape: {actor_mesh.shape}") + + # Setup optimizer + optimizer = utils_rl.get_optimizer(trainer_config, max_train_steps) + + # Setup checkpointing + checkpointing_options = ocp.CheckpointManagerOptions( + save_interval_steps=trainer_config.checkpoint_period, max_to_keep=trainer_config.max_num_checkpoints_to_keep + ) + + # Set up micro batching + micro_batch_size = None if trainer_config.micro_batch_size == -1 else trainer_config.micro_batch_size + + # Parse vllm_additional_config + rollout_additional_config = None + if trainer_config.vllm_additional_config: + if isinstance(trainer_config.vllm_additional_config, dict): + # It's already parsed into a dict + rollout_additional_config = trainer_config.vllm_additional_config + elif isinstance(trainer_config.vllm_additional_config, str): + # It's a string, so we need to parse it + try: + rollout_additional_config = json.loads(trainer_config.vllm_additional_config) + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse additional_config JSON: {e}") from e + + max_logging.log(f"Parsed additional config: {rollout_additional_config}") + + # We need to parse vLLM config to get the logical axis rules for the sampler config. + vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml") + argv_list = ["", str(vllm_config_path), "log_config=False"] + vllm_config = pyconfig.initialize(argv_list) + + # RL Cluster config + # Note that we use vLLM as the rollout engine. + # and we are using Tensor Parallelism for rollout + cluster_config = rl_cluster_lib.ClusterConfig( + role_to_mesh={ + rl_cluster_lib.Role.ACTOR: actor_mesh, + rl_cluster_lib.Role.REFERENCE: reference_mesh, + rl_cluster_lib.Role.ROLLOUT: rollout_mesh, + }, + role_to_logical_axis_rule={ + rl_cluster_lib.Role.ACTOR: trainer_config.logical_axis_rules, + rl_cluster_lib.Role.REFERENCE: trainer_config.logical_axis_rules, + rl_cluster_lib.Role.ROLLOUT: vllm_config.logical_axis_rules, + }, + rollout_engine="vllm", + offload_to_cpu=False, + training_config=rl_cluster_lib.RLTrainingConfig( + actor_optimizer=optimizer, + eval_every_n_steps=trainer_config.eval_interval, + max_steps=max_train_steps, + # Micro batching + mini_batch_size=trainer_config.batch_size, + train_micro_batch_size=micro_batch_size, + rollout_micro_batch_size=micro_batch_size, + # Checkpoint saving + checkpoint_root_directory=trainer_config.checkpoint_dir, + checkpointing_options=checkpointing_options, + ), + rollout_config=base_rollout.RolloutConfig( + max_tokens_to_generate=trainer_config.max_target_length - trainer_config.max_prefill_predict_length, + max_prompt_length=trainer_config.max_prefill_predict_length, + kv_cache_size=trainer_config.max_target_length + trainer_config.kv_cache_buffer, + temperature=trainer_config.decode_sampling_temperature, + top_p=trainer_config.decode_sampling_nucleus_p, + top_k=trainer_config.decode_sampling_top_k, + rollout_vllm_model_version=trainer_config.tokenizer_path, + rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm, + rollout_vllm_tpu_backend_type="jax", + rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb, + rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path, + rollout_vllm_additional_config=rollout_additional_config, + rollout_vllm_init_with_random_weights=True, + rollout_vllm_enable_dp_attention=trainer_config.enable_dp_attention, + rollout_vllm_max_num_batched_tokens=trainer_config.max_num_batched_tokens, + rollout_vllm_max_num_seqs=trainer_config.max_num_seqs, + rollout_vllm_async_scheduling=trainer_config.async_scheduling, + rollout_vllm_kwargs={ + "hf_overrides": trainer_config.vllm_hf_overrides, + "enable_expert_parallel": sampler_config.rollout_expert_parallelism > 1, + }, + rollout_vllm_sampling_kwargs={ + "stop": trainer_config.stop_strings, + "detokenize": trainer_config.stop_strings is not None, + "include_stop_str_in_output": trainer_config.stop_strings is not None, + }, + **get_rollout_kwargs_for_parallelism(sampler_config, len(sampler_devices)), + ), + ) + # Create RL cluster + max_logging.log("Creating RL cluster...") + + rl_cluster = rl_cluster_lib.RLCluster( + actor=actor_model, + reference=reference_model, + tokenizer=model_tokenizer, + cluster_config=cluster_config, + ) + + max_logging.log( + "Calling rl_cluster.sync_weights() to reshard actor weights to rollout mesh..." + ) + + key = jax.random.PRNGKey(42) + for step in range(trainer_config.num_batches): + key, subkey = jax.random.split(key) + noise = jax.random.normal(subkey, ()) * 1e-3 + # Update all actor weights to trigger full resharding + state = nnx.state(actor_model, nnx.Param) + new_state = jax.tree_util.tree_map(lambda x: x + noise, state) + nnx.update(actor_model, new_state) + + show_hbm_usage(f"HBM before step {step}:") + start_time = time.time() + rl_cluster.sync_weights() + jax.tree_util.tree_map(jax.block_until_ready, rl_cluster.rollout._sampler.transformer_state) + end_time = time.time() + show_hbm_usage(f"HBM after step {step}:") + max_logging.log(f"Weight Syncing Time taken: {end_time - start_time:.4f}s") + max_logging.log(f"Resharding via sync_weights() completed: step {step}") + + +def main(argv: Sequence[str]) -> None: + """Main function to run RL training. + + Args: + argv: Command-line arguments. + """ + pathwaysutils.initialize() + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + + max_utils.print_system_information() + trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv) + rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices) + + +if __name__ == "__main__": + app.run(main) \ No newline at end of file