From 209bd7a42460d290c0e4aa90c93c3c8548b1f5a8 Mon Sep 17 00:00:00 2001 From: Surbhi Jain Date: Mon, 18 May 2026 23:57:08 +0000 Subject: [PATCH] Deprecate cloud_tpu_diagnostics support and update requirements lockfile --- .../features_and_diagnostics.md | 18 - .../architecture/architecture_overview.md | 1 - docs/run_maxtext/decoupled_mode.md | 9 +- .../base_requirements/requirements.txt | 1 - .../tpu-post-train-requirements.txt | 2 +- .../cuda12-requirements.txt | 17 +- .../tpu-post-train-requirements.txt | 34 +- .../tpu-requirements.txt | 17 +- src/maxtext/common/gcloud_stub.py | 77 +- src/maxtext/configs/base.yml | 5 - src/maxtext/configs/types.py | 9 - src/maxtext/experimental/rl/grpo_trainer.py | 20 +- .../scratch_code/demo_from_config.ipynb | 934 ++++-------------- src/maxtext/trainers/pre_train/train.py | 48 +- tests/end_to_end/test_checkpointing.sh | 13 +- .../tpu/test_checkpoint_resharding.sh | 4 +- .../integration/checkpoint_resharding_test.py | 1 - tests/unit/gcloud_stub_test.py | 21 - 18 files changed, 273 insertions(+), 958 deletions(-) diff --git a/docs/guides/monitoring_and_debugging/features_and_diagnostics.md b/docs/guides/monitoring_and_debugging/features_and_diagnostics.md index a6952fae04..691ba240b6 100644 --- a/docs/guides/monitoring_and_debugging/features_and_diagnostics.md +++ b/docs/guides/monitoring_and_debugging/features_and_diagnostics.md @@ -16,24 +16,6 @@ # Features and diagnostics -## Collect stack traces - -When running a Single Program, Multiple Data (SPMD) job on accelerators, the overall process can hang if there is any error or any VM hangs/crashes for some reason. In this scenario, capturing stack traces will help to identify and troubleshoot the issues for the jobs running on TPU VMs. - -The following configurations will help to debug a fault or when a program is stuck or hung somewhere by collecting stack traces. Change the parameter values accordingly in `src/maxtext/configs/base.yml`: - -1. Set `collect_stack_trace: True` to enable collection of stack traces on faults or when the program is hung. This setting will periodically dump the traces for the program to help in debugging. To disable this, set `collect_stack_trace: False`. -2. Set `stack_trace_to_cloud: False` to display stack traces on console. `stack_trace_to_cloud: True` will create a temporary file in `/tmp/debugging` in the TPUs to store the stack traces. There is an agent running on TPU VMs that will periodically upload the traces from the temporary directory to cloud logging in the gcp project. You can view the traces in Logs Explorer on Cloud Logging using the following query: - -``` -logName="projects//logs/tpu.googleapis.com%2Fruntime_monitor" -jsonPayload.verb="stacktraceanalyzer" -``` - -3. `stack_trace_interval_seconds` signifies the duration in seconds between each stack trace collection event. Setting `stack_trace_interval_seconds: 600` will collect the stack traces every 600 seconds (10 minutes). - -Here is the related PyPI package: https://pypi.org/project/cloud-tpu-diagnostics. - (aot-compilation)= ## Ahead of Time compilation (AOT) diff --git a/docs/reference/architecture/architecture_overview.md b/docs/reference/architecture/architecture_overview.md index 1b73145dcd..f019982728 100644 --- a/docs/reference/architecture/architecture_overview.md +++ b/docs/reference/architecture/architecture_overview.md @@ -177,6 +177,5 @@ The critical technology enabling this strategy is the suite of checkpoint conver Debugging performance issues in a distributed system with thousands of accelerators is a notoriously difficult challenge. MaxText incorporates several built-in diagnostic features designed to provide visibility into the system's behavior at scale. -- Stack trace collection: To diagnose program hangs or faults, users can set `collect_stack_trace: True` in the configuration. This feature will periodically dump the Python stack traces from all worker processes. The traces can be directed to the console for immediate inspection or, more scalably, uploaded to Cloud Logging, where they can be aggregated and queried to identify misbehaving nodes. - HLO dumping: For deep, low-level performance analysis, MaxText allows users to dump the XLA High-Level Optimizer (HLO) graph. By setting the `dump_hlo` flag, the compiled graph for a specific training step can be saved to a local directory or uploaded to Cloud Storage. This HLO representation is invaluable for compiler engineers and advanced users who need to understand exactly how XLA is interpreting and optimizing the model, making it possible to debug subtle performance regressions or compiler-related issues. - Goodput monitoring: The framework integrates with the ml-goodput-measurement library, which provides a more holistic view of job efficiency than simple TFLOPs calculations. This allows for the tracking of metrics that capture overall "goodput," accounting for factors like data loading time, compilation overhead, and idle time, giving a truer picture of end-to-end performance. diff --git a/docs/run_maxtext/decoupled_mode.md b/docs/run_maxtext/decoupled_mode.md index b6e6423ca3..29c334b4a6 100644 --- a/docs/run_maxtext/decoupled_mode.md +++ b/docs/run_maxtext/decoupled_mode.md @@ -54,17 +54,12 @@ Optional environment variables: MaxText exposes a single module `maxtext.common.gcloud_stub` to avoid scattering environment checks: ```python -from maxtext.common.gcloud_stub import is_decoupled, cloud_diagnostics, jetstream +from maxtext.common.gcloud_stub import is_decoupled, jetstream if is_decoupled(): # Skip optional integrations or use local fallbacks pass -# Cloud diagnostics (returns diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration) -diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = ( - cloud_diagnostics() -) - # JetStream (serving) components config_lib, engine_api, token_utils, tokenizer_api, token_params_ns = jetstream() TokenizerParameters = getattr(token_params_ns, "TokenizerParameters", object) @@ -78,7 +73,7 @@ Behavior when `DECOUPLE_GCLOUD=TRUE`: ## Guidelines: -- Prefer calling `jetstream()` / `cloud_diagnostics()` once at module import and branching on `is_decoupled()` for functionality that truly requires the dependency. +- Prefer calling `jetstream()` once at module import and branching on `is_decoupled()` for functionality that truly requires the dependency. - Use `is_decoupled()` to avoid direct `os.environ["DECOUPLE_GCLOUD"]` checking. - Use `get_test_config_path()` instead of hard-coded `base.yml`. - Prefer conditional local fallbacks for cloud buckets and avoid introducing direct `gs://...` paths. diff --git a/src/dependencies/requirements/base_requirements/requirements.txt b/src/dependencies/requirements/base_requirements/requirements.txt index 4f17b24aa8..5ba8ee5093 100644 --- a/src/dependencies/requirements/base_requirements/requirements.txt +++ b/src/dependencies/requirements/base_requirements/requirements.txt @@ -3,7 +3,6 @@ aqtp array-record chex cloud-accelerator-diagnostics -cloud-tpu-diagnostics!=1.1.14 datasets drjax flax diff --git a/src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt b/src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt index 911be1cd49..9fb229fa4b 100644 --- a/src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt +++ b/src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt @@ -36,7 +36,7 @@ setuptools<81.0.0 sortedcontainers torch==2.11.0 torchax==0.0.11 -torchvision==0.25.0 +torchvision==0.26.0 tpu-info watchfiles xgrammar diff --git a/src/dependencies/requirements/generated_requirements/cuda12-requirements.txt b/src/dependencies/requirements/generated_requirements/cuda12-requirements.txt index 95dd84b5cf..262b59b1dc 100644 --- a/src/dependencies/requirements/generated_requirements/cuda12-requirements.txt +++ b/src/dependencies/requirements/generated_requirements/cuda12-requirements.txt @@ -23,9 +23,8 @@ cffi>=2.0.0 ; platform_python_implementation != 'PyPy' cfgv>=3.5.0 charset-normalizer>=3.4.7 chex>=0.1.91 -click>=8.3.3 +click>=8.4.0 cloud-accelerator-diagnostics>=0.1.1 -cloud-tpu-diagnostics>=0.1.5 cloudpickle>=3.1.2 clu>=0.0.12 colorama>=0.4.6 @@ -33,7 +32,7 @@ contourpy>=1.3.3 cryptography>=48.0.0 cycler>=0.12.1 datasets>=4.8.5 -decorator>=5.2.1 +decorator>=5.3.1 dill>=0.4.1 distlib>=0.4.0 distro>=1.9.0 @@ -56,10 +55,10 @@ gast>=0.7.0 gcsfs>=2026.2.0 google-api-core>=2.30.3 google-api-python-client>=2.196.0 -google-auth>=2.52.0 +google-auth>=2.53.0 google-auth-httplib2>=0.4.0 google-auth-oauthlib>=1.4.0 -google-cloud-aiplatform>=1.152.0 +google-cloud-aiplatform>=1.153.1 google-cloud-appengine-logging>=1.9.0 google-cloud-audit-log>=0.5.0 google-cloud-bigquery>=3.41.0 @@ -71,7 +70,7 @@ google-cloud-resource-manager>=1.17.0 google-cloud-storage>=3.10.1 google-cloud-storage-control>=1.11.0 google-crc32c>=1.8.0 -google-genai>=1.75.0 +google-genai>=2.4.0 google-pasta>=0.2.0 google-resumable-media>=2.9.0 googleapis-common-protos>=1.75.0 @@ -86,7 +85,7 @@ hf-xet>=1.5.0 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or httpcore>=1.0.9 httplib2>=0.31.2 httpx>=0.28.1 -huggingface-hub>=1.14.0 +huggingface-hub>=1.15.0 humanize>=4.15.0 hypothesis>=6.142.1 identify>=2.6.19 @@ -221,7 +220,7 @@ tensorflow-metadata>=1.17.3 tensorflow-text>=2.20.1 tensorstore>=0.1.82 termcolor>=3.3.0 -tiktoken>=0.12.0 +tiktoken>=0.13.0 tokamax>=0.0.12 tokenizers>=0.22.2 toml>=0.10.2 @@ -240,7 +239,7 @@ typing-inspection>=0.4.2 tzdata>=2026.2 ; sys_platform == 'emscripten' or sys_platform == 'win32' uritemplate>=4.2.0 urllib3>=2.6.3 -uvicorn>=0.46.0 +uvicorn>=0.47.0 uvloop>=0.22.1 virtualenv>=21.3.3 wadler-lindig>=0.1.7 diff --git a/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt b/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt index 0ff978bbd2..010ec6be09 100644 --- a/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt +++ b/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt @@ -23,10 +23,10 @@ astunparse>=1.6.3 attrs>=26.1.0 auditwheel>=6.6.0 black>=25.12.0 -boto3>=1.43.7 -botocore>=1.43.7 +boto3>=1.43.10 +botocore>=1.43.10 build>=1.4.3 -cachetools>=7.1.1 +cachetools>=7.1.3 cbor2>=6.1.1 certifi>=2026.2.25 cffi>=2.0.0 ; implementation_name == 'pypy' or platform_python_implementation != 'PyPy' @@ -34,9 +34,8 @@ cfgv>=3.5.0 charset-normalizer>=3.4.7 cheroot>=11.1.2 chex>=0.1.91 -click>=8.3.3 +click>=8.4.0 cloud-accelerator-diagnostics>=0.1.1 -cloud-tpu-diagnostics>=0.1.5 cloudpickle>=3.1.2 clu>=0.0.12 colorama>=0.4.6 @@ -52,7 +51,7 @@ dataclasses>=0.5 dataclasses-json>=0.0.1 datasets>=4.8.5 debugpy>=1.8.20 -decorator>=5.2.1 +decorator>=5.3.1 dill>=0.4.1 distlib>=0.4.0 distro>=1.9.0 @@ -80,10 +79,10 @@ gepa>=0.1.1 gguf>=0.19.0 google-api-core>=2.30.3 google-api-python-client>=2.196.0 -google-auth>=2.52.0 +google-auth>=2.53.0 google-auth-httplib2>=0.4.0 google-auth-oauthlib>=1.4.0 -google-cloud-aiplatform>=1.152.0 +google-cloud-aiplatform>=1.153.1 google-cloud-appengine-logging>=1.9.0 google-cloud-audit-log>=0.5.0 google-cloud-bigquery>=3.41.0 @@ -95,7 +94,7 @@ google-cloud-resource-manager>=1.17.0 google-cloud-storage>=3.10.1 google-cloud-storage-control>=1.11.0 google-crc32c>=1.8.0 -google-genai>=1.75.0 +google-genai>=2.4.0 google-metrax>=0.2.3 google-pasta>=0.2.0 google-resumable-media>=2.9.0 @@ -114,7 +113,7 @@ hf-xet>=1.5.0 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or httpcore>=1.0.9 httplib2>=0.31.2 httpx>=0.28.1 -huggingface-hub>=1.14.0 +huggingface-hub>=1.15.0 humanize>=4.15.0 hypothesis>=6.142.1 identify>=2.6.19 @@ -129,7 +128,7 @@ ipython>=9.13.0 ipython-pygments-lexers>=1.1.1 ipywidgets>=8.1.8 isort>=8.0.1 -jaraco-functools>=4.4.0 +jaraco-functools>=4.5.0 jax>=0.10.0 jaxlib>=0.10.0 jaxtyping>=0.3.9 @@ -154,7 +153,7 @@ libtpu>=0.0.40 ; platform_machine == 'x86_64' and sys_platform == 'linux' llguidance>=1.7.5 llvmlite>=0.47.0 loguru>=0.7.3 -lxml>=6.1.0 +lxml>=6.1.1 markdown>=3.10.2 markdown-it-py>=4.0.0 markupsafe>=3.0.3 @@ -203,7 +202,7 @@ nvidia-nvshmem-cu13>=3.4.5 ; sys_platform == 'linux' nvidia-nvtx>=13.0.85 ; sys_platform == 'linux' oauthlib>=3.3.1 omegaconf>=2.3.0 -openai>=2.36.0 +openai>=2.37.0 openai-harmony>=0.0.8 opentelemetry-api>=1.41.1 opt-einsum>=3.4.0 @@ -309,7 +308,7 @@ tensorflow-metadata>=1.17.3 tensorflow-text>=2.20.1 tensorstore>=0.1.82 termcolor>=3.3.0 -tiktoken>=0.12.0 +tiktoken>=0.13.0 tokamax>=0.0.12 tokenizers>=0.22.2 toml>=0.10.2 @@ -317,6 +316,7 @@ tomlkit>=0.15.0 toolz>=1.1.0 torch>=2.11.0 torchax>=0.0.11 +torchvision>=0.26.0 tornado>=6.5.5 tpu-info>=0.11.0 tqdm>=4.66.3 @@ -331,11 +331,11 @@ typing-inspection>=0.4.2 tzdata>=2026.2 ; sys_platform == 'emscripten' or sys_platform == 'win32' uritemplate>=4.2.0 urllib3>=2.6.3 -uvicorn>=0.46.0 +uvicorn>=0.47.0 uvloop>=0.22.1 virtualenv>=21.3.3 wadler-lindig>=0.1.7 -watchfiles>=1.1.1 +watchfiles>=1.2.0 wcwidth>=0.7.0 websockets>=16.0 werkzeug>=3.1.8 @@ -343,7 +343,7 @@ wheel>=0.46.3 widgetsnbextension>=4.0.15 win32-setctime>=1.2.0 ; sys_platform == 'win32' wrapt>=2.1.2 -xgrammar>=0.2.0 +xgrammar>=0.2.1 xprof>=2.22.3 xxhash>=3.7.0 yapf>=0.43.0 diff --git a/src/dependencies/requirements/generated_requirements/tpu-requirements.txt b/src/dependencies/requirements/generated_requirements/tpu-requirements.txt index 339751c77f..cca048842d 100644 --- a/src/dependencies/requirements/generated_requirements/tpu-requirements.txt +++ b/src/dependencies/requirements/generated_requirements/tpu-requirements.txt @@ -23,9 +23,8 @@ cffi>=2.0.0 ; platform_python_implementation != 'PyPy' cfgv>=3.5.0 charset-normalizer>=3.4.7 chex>=0.1.91 -click>=8.3.3 +click>=8.4.0 cloud-accelerator-diagnostics>=0.1.1 -cloud-tpu-diagnostics>=0.1.5 cloudpickle>=3.1.2 clu>=0.0.12 colorama>=0.4.6 @@ -33,7 +32,7 @@ contourpy>=1.3.3 cryptography>=48.0.0 cycler>=0.12.1 datasets>=4.8.5 -decorator>=5.2.1 +decorator>=5.3.1 dill>=0.4.1 distlib>=0.4.0 distro>=1.9.0 @@ -56,10 +55,10 @@ gast>=0.7.0 gcsfs>=2026.2.0 google-api-core>=2.30.3 google-api-python-client>=2.196.0 -google-auth>=2.52.0 +google-auth>=2.53.0 google-auth-httplib2>=0.4.0 google-auth-oauthlib>=1.4.0 -google-cloud-aiplatform>=1.152.0 +google-cloud-aiplatform>=1.153.1 google-cloud-appengine-logging>=1.9.0 google-cloud-audit-log>=0.5.0 google-cloud-bigquery>=3.41.0 @@ -71,7 +70,7 @@ google-cloud-resource-manager>=1.17.0 google-cloud-storage>=3.10.1 google-cloud-storage-control>=1.11.0 google-crc32c>=1.8.0 -google-genai>=1.75.0 +google-genai>=2.4.0 google-pasta>=0.2.0 google-resumable-media>=2.9.0 googleapis-common-protos>=1.75.0 @@ -86,7 +85,7 @@ hf-xet>=1.5.0 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or httpcore>=1.0.9 httplib2>=0.31.2 httpx>=0.28.1 -huggingface-hub>=1.14.0 +huggingface-hub>=1.15.0 humanize>=4.15.0 hypothesis>=6.142.1 identify>=2.6.19 @@ -207,7 +206,7 @@ tensorflow-metadata>=1.17.3 tensorflow-text>=2.20.1 tensorstore>=0.1.82 termcolor>=3.3.0 -tiktoken>=0.12.0 +tiktoken>=0.13.0 tokamax>=0.0.12 tokenizers>=0.22.2 toml>=0.10.2 @@ -223,7 +222,7 @@ typing-inspection>=0.4.2 tzdata>=2026.2 ; sys_platform == 'emscripten' or sys_platform == 'win32' uritemplate>=4.2.0 urllib3>=2.6.3 -uvicorn>=0.46.0 +uvicorn>=0.47.0 uvloop>=0.22.1 virtualenv>=21.3.3 wadler-lindig>=0.1.7 diff --git a/src/maxtext/common/gcloud_stub.py b/src/maxtext/common/gcloud_stub.py index 2ecc96bac7..fe6b00cbb5 100644 --- a/src/maxtext/common/gcloud_stub.py +++ b/src/maxtext/common/gcloud_stub.py @@ -18,8 +18,6 @@ integrations while still allowing local unit tests to import modules. This module provides: - is_decoupled(): returns True if decoupled flag set. -- cloud_diagnostics(): tuple(diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration) - providing either real objects or lightweight stubs. - jetstream(): returns a namespace-like object exposing Engine, Devices, ResultTokens etc. or stubs. - gcs_storage(): returns google.cloud.storage module or stub namespace with Client/Blob/Bucket. - goodput_modules(): returns (goodput, monitoring, is_stub) for ml_goodput_measurement integration or stubs. @@ -73,79 +71,6 @@ def _import_or_stub( raise -# ---------------- Cloud Diagnostics ----------------- - - -def _cloud_diag_stubs(): - """Return lightweight stubs for cloud TPU diagnostics.""" - import contextlib # pylint: disable=import-outside-toplevel - - class _StubDiag: - """Stub diagnostic object returning skip metadata.""" - - def run(self, *_a, **_k): - return {"status": "skipped"} - - def diagnose(self, *_a, **_k): - """Return a context manager that swallows diagnostic errors in stub mode.""" - - @contextlib.contextmanager - def _graceful_diagnose(): - try: - yield - except Exception as exc: # pylint: disable=broad-exception-caught - print("Warning: using stubs for cloud_diagnostics diagnose() - " f"caught: {exc}") - - return _graceful_diagnose() - - class _StubDebugConfig: - """Stub debug configuration.""" - - def __init__(self, *a, **k): # pylint: disable=unused-argument - pass - - class _StubStackTraceConfig: - """Stub stack trace configuration.""" - - def __init__(self, *a, **k): # pylint: disable=unused-argument - pass - - class _StubDiagnosticConfig: - """Stub diagnostic configuration wrapper.""" - - def __init__(self, *a, debug_config=None, **k): # pylint: disable=unused-argument - del a, k - self.debug_config = debug_config - - return ( - _StubDiag(), - SimpleNamespace(DebugConfig=_StubDebugConfig, StackTraceConfig=_StubStackTraceConfig), - SimpleNamespace(DiagnosticConfig=_StubDiagnosticConfig), - SimpleNamespace(StackTraceConfig=_StubStackTraceConfig), - ) - - -def cloud_diagnostics(): - """Return real cloud diagnostics modules or stubs. - - If a dependency is missing and we are decoupled, return stubs. Otherwise - re-raise the import error so callers fail fast. - """ - - def _import(): - from cloud_tpu_diagnostics import diagnostic # type: ignore # pylint: disable=import-outside-toplevel - from cloud_tpu_diagnostics.configuration import ( # type: ignore # pylint: disable=import-outside-toplevel - debug_configuration, - diagnostic_configuration, - stack_trace_configuration, - ) - - return diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration - - # Only stub on import failures if running decoupled; otherwise fail fast. - return _import_or_stub(_import, _cloud_diag_stubs, label="cloud_diagnostics", stub_if_decoupled=False) - - # ---------------- JetStream ----------------- @@ -390,7 +315,7 @@ def _import(): ) -__all__ = ["is_decoupled", "cloud_diagnostics", "jetstream", "gcs_storage", "goodput_modules"] +__all__ = ["is_decoupled", "jetstream", "gcs_storage", "goodput_modules"] # ---------------- Cloud Monitoring (monitoring_v3 / metric_pb2) ----------------- diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 670d155974..8165b57ee7 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -888,11 +888,6 @@ muon_beta: 0.95 # Decay rate for the exponentially weighted average of grads. muon_weight_decay: 0 # Strength of the weight decay regularization. This is multiplied with the learning rate. muon_consistent_rms: None # If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2). -# Stack trace parameters -collect_stack_trace: False -stack_trace_to_cloud: False # Uploads to cloud logging if True, else to the console if False. -stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds. - # Use iota operator in Embed use_iota_embed: False # use positional embedding diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index bb18a81a5f..f2165fdde6 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1696,14 +1696,6 @@ class HloDump(BaseModel): dump_jaxpr_gcs_dir: PathStr = Field("", description="GCS directory to upload jaxpr dumps.") -class StackTrace(BaseModel): - """Configuration for collecting and logging stack traces.""" - - collect_stack_trace: bool = Field(False, description="Enable periodic stack trace collection.") - stack_trace_to_cloud: bool = Field(False, description="Upload stack traces to cloud logging instead of console.") - stack_trace_interval_seconds: int = Field(600, description="Frequency of stack trace collection in seconds.") - - class Metrics(BaseModel): """General configuration for metrics and monitoring.""" @@ -2277,7 +2269,6 @@ class MaxTextConfig( DevelopmentAndDebugging, Profiling, HloDump, - StackTrace, # Metrics and Monitoring Metrics, Goodput, diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 28eef21cb0..4f5356686b 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -57,11 +57,6 @@ from flax import struct from flax.nnx import TrainState -from cloud_tpu_diagnostics import diagnostic -from cloud_tpu_diagnostics.configuration import debug_configuration -from cloud_tpu_diagnostics.configuration import diagnostic_configuration -from cloud_tpu_diagnostics.configuration import stack_trace_configuration - import transformers from ml_goodput_measurement.src.goodput import GoodputRecorder @@ -973,20 +968,9 @@ def main(argv: Sequence[str]) -> None: # Create the Goodput recorder recorder = create_goodput_recorder(config) - # Stack traces configurations - debug_config = debug_configuration.DebugConfig( - stack_trace_config=stack_trace_configuration.StackTraceConfig( - collect_stack_trace=config.collect_stack_trace, - stack_trace_to_cloud=config.stack_trace_to_cloud, - stack_trace_interval_seconds=config.stack_trace_interval_seconds, - ) - ) - diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config) - record_goodput(recorder, RECORD_JOB_START_TIME) - with diagnostic.diagnose(diagnostic_config): - with maybe_monitor_goodput(config): - train_loop(config, config_inference, recorder) + with maybe_monitor_goodput(config): + train_loop(config, config_inference, recorder) if __name__ == "__main__": diff --git a/src/maxtext/scratch_code/demo_from_config.ipynb b/src/maxtext/scratch_code/demo_from_config.ipynb index 19fcc47cfb..ed50575931 100644 --- a/src/maxtext/scratch_code/demo_from_config.ipynb +++ b/src/maxtext/scratch_code/demo_from_config.ipynb @@ -1,720 +1,220 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "a8e986cb", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Added '/home/mazumdera/maxtext' to sys.path\n" - ] - } - ], - "source": [ - "import os\n", - "import sys\n", - "\n", - "from maxtext.utils.globals import MAXTEXT_REPO_ROOT\n", - "\n", - "# Add the project root to the system path if it's not already there\n", - "if MAXTEXT_REPO_ROOT not in sys.path:\n", - " sys.path.insert(0, MAXTEXT_REPO_ROOT)\n", - " print(f\"Added '{MAXTEXT_REPO_ROOT}' to sys.path\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0ab2e1dd", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2025-06-18 21:34:12.489183: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", - "E0000 00:00:1750282452.508183 1726814 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "E0000 00:00:1750282452.513660 1726814 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "W0000 00:00:1750282452.528073 1726814 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "W0000 00:00:1750282452.528091 1726814 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "W0000 00:00:1750282452.528093 1726814 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "W0000 00:00:1750282452.528094 1726814 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n" - ] - } - ], - "source": [ - "import maxtext as mt\n", - "from maxtext.configs import pyconfig\n", - "import numpy as np\n", - "from maxtext.input_pipeline import input_pipeline_utils\n", - "import os\n", - "from maxtext.common import common_types\n", - "import jax\n", - "from maxtext.inference import inference_utils\n", - "from maxtext.utils import max_logging\n", - "from maxtext.utils import maxtext_utils" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d2d2de93", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Updating keys from env and command line: ['run_name', 'enable_checkpointing', 'base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_num_decoder_layers', 'per_device_batch_size', 'max_target_length', 'max_prefill_predict_length']\n", - "Running Model: default\n", - "Updating keys from model: []\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:2025-06-18 21:34:16,611:jax._src.xla_bridge:913: A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.\n", - "WARNING:jax._src.xla_bridge:A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Not using emergency checkpoint, ignoring local_checkpoint_directory, local_checkpoint_period, use_replicator_service and replicator_backup_interval_minutes\n", - "dataset_type set to tfds, will use keys['dataset_path']='' and keys['dataset_name']='c4/en:3.0.1'\n", - "Config param activations_in_float32: False\n", - "Config param adam_b1: 0.9\n", - "Config param adam_b2: 0.95\n", - "Config param adam_eps: 1e-08\n", - "Config param adam_eps_root: 0.0\n", - "Config param adam_weight_decay: 0.1\n", - "Config param add_bos: True\n", - "Config param add_eos: True\n", - "Config param allow_split_physical_axes: False\n", - "Config param ar_cache_axis_order: 1,2,0,3\n", - "Config param async_checkpointing: True\n", - "Config param attention: autoselected\n", - "Config param attention_type: global\n", - "Config param attn_logits_soft_cap: None\n", - "Config param autoregressive_decode_assert: \n", - "Config param base_emb_dim: 256\n", - "Config param base_mlp_dim: 7168\n", - "Config param base_moe_mlp_dim: 7168\n", - "Config param base_num_decoder_layers: 2\n", - "Config param base_num_kv_heads: 2\n", - "Config param base_num_query_heads: 2\n", - "Config param base_output_directory: \n", - "Config param beta_fast: 32\n", - "Config param beta_slow: 1\n", - "Config param capacity_factor: -1.0\n", - "Config param cast_logits_to_fp32: True\n", - "Config param checkpoint_dir: test/checkpoints/\n", - "Config param checkpoint_is_quantized: False\n", - "Config param checkpoint_period: 10000\n", - "Config param checkpoint_storage_concurrent_gb: 96\n", - "Config param checkpoint_storage_target_data_file_size_bytes: 2147483648\n", - "Config param checkpoint_storage_use_ocdbt: True\n", - "Config param checkpoint_storage_use_zarr3: True\n", - "Config param chunk_attn_window_size: 0\n", - "Config param collect_stack_trace: False\n", - "Config param colocated_python_data_input: False\n", - "Config param compile_topology: \n", - "Config param compile_topology_num_slices: -1\n", - "Config param compiled_trainstep_file: \n", - "Config param compute_axis_order: 0,1,2,3\n", - "Config param context: remat\n", - "Config param context_parallel_load_balance: True\n", - "Config param cosine_learning_rate_final_fraction: 0.1\n", - "Config param custom_mesh: \n", - "Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'),)\n", - "Config param data_shuffle_seed: 0\n", - "Config param dataset_name: c4/en:3.0.1\n", - "Config param dataset_path: \n", - "Config param dataset_type: tfds\n", - "Config param dcn_autoregressive_parallelism: 1\n", - "Config param dcn_context_autoregressive_parallelism: 1\n", - "Config param dcn_context_parallelism: 1\n", - "Config param dcn_data_parallelism: -1\n", - "Config param dcn_expert_parallelism: 1\n", - "Config param dcn_fsdp_parallelism: 1\n", - "Config param dcn_fsdp_transpose_parallelism: 1\n", - "Config param dcn_parallelism: [-1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n", - "Config param dcn_pipeline_parallelism: 1\n", - "Config param dcn_sequence_parallelism: 1\n", - "Config param dcn_tensor_parallelism: 1\n", - "Config param dcn_tensor_sequence_parallelism: 1\n", - "Config param dcn_tensor_transpose_parallelism: 1\n", - "Config param decode_sampling_nucleus_p: -1\n", - "Config param decode_sampling_strategy: greedy\n", - "Config param decode_sampling_temperature: 1.0\n", - "Config param decode_sampling_top_k: 0\n", - "Config param decoder_block: DecoderBlockType.LLAMA2\n", - "Config param decoder_layer_input: device\n", - "Config param dpo_beta: 0.1\n", - "Config param dpo_label_smoothing: 0.0\n", - "Config param dropout_rate: 0.0\n", - "Config param dtype: bfloat16\n", - "Config param dtype_mm: float32\n", - "Config param dump_hlo: False\n", - "Config param dump_hlo_delete_local_after: True\n", - "Config param dump_hlo_gcs_dir: \n", - "Config param dump_hlo_local_dir: /tmp/xla_dump/\n", - "Config param dump_hlo_module_name: jit_train_step\n", - "Config param dump_hlo_upload_all: False\n", - "Config param dump_hlo_xla_flags: \n", - "Config param dump_step: -1\n", - "Config param emb_dim: 256\n", - "Config param enable_checkpoint_cloud_logger: False\n", - "Config param enable_checkpointing: False\n", - "Config param enable_data_shuffling: True\n", - "Config param enable_dropout: True\n", - "Config param enable_emergency_checkpoint: False\n", - "Config param enable_gcp_goodput_metrics: True\n", - "Config param enable_gcp_step_deviation_metrics: True\n", - "Config param enable_goodput_recording: False\n", - "Config param enable_jax_profiler: False\n", - "Config param enable_llm_inference_pool: False\n", - "Config param enable_model_warmup: False\n", - "Config param enable_padding_causal_mask: True\n", - "Config param enable_pathways_goodput: False\n", - "Config param enable_prefix_caching: False\n", - "Config param enable_single_controller: False\n", - "Config param enable_single_replica_ckpt_restoring: False\n", - "Config param enable_tensorboard: True\n", - "Config param eval_data_columns: ['text']\n", - "Config param eval_dataset_name: c4/en:3.0.1\n", - "Config param eval_interval: -1\n", - "Config param eval_per_device_batch_size: 1.0\n", - "Config param eval_split: validation\n", - "Config param eval_steps: -1\n", - "Config param expansion_factor_real_data: -1\n", - "Config param final_logits_soft_cap: None\n", - "Config param first_num_dense_layers: 0\n", - "Config param float32_logits: False\n", - "Config param float32_qk_product: False\n", - "Config param force_unroll: False\n", - "Config param freeze_vision_encoder_params: True\n", - "Config param fused_mlp: False\n", - "Config param fused_qkv: False\n", - "Config param gcs_metrics: False\n", - "Config param generate_slice: v5e-16\n", - "Config param global_batch_size_to_eval_on: 1\n", - "Config param global_batch_size_to_load: 1\n", - "Config param global_batch_size_to_load_eval: 1\n", - "Config param global_batch_size_to_train_on: 1\n", - "Config param global_parameter_scale: 1\n", - "Config param goodput_upload_interval_seconds: 30\n", - "Config param gradient_accumulation_steps: 1\n", - "Config param gradient_clipping_threshold: 1.0\n", - "Config param grain_eval_files: \n", - "Config param grain_file_type: arrayrecord\n", - "Config param grain_train_files: \n", - "Config param grain_worker_count: 1\n", - "Config param grain_worker_count_eval: 1\n", - "Config param hardware: tpu\n", - "Config param head_dim: 128\n", - "Config param heartbeat_reporting_interval_in_seconds: 5\n", - "Config param hf_data_dir: \n", - "Config param hf_eval_files: \n", - "Config param hf_eval_split: \n", - "Config param hf_path: \n", - "Config param hf_train_files: \n", - "Config param hidden_size_for_vit: 1408\n", - "Config param ici_autoregressive_parallelism: 1\n", - "Config param ici_context_autoregressive_parallelism: 1\n", - "Config param ici_context_parallelism: 1\n", - "Config param ici_data_parallelism: 1\n", - "Config param ici_expert_parallelism: 1\n", - "Config param ici_fsdp_parallelism: -1\n", - "Config param ici_fsdp_transpose_parallelism: 1\n", - "Config param ici_parallelism: [1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n", - "Config param ici_pipeline_parallelism: 1\n", - "Config param ici_sequence_parallelism: 1\n", - "Config param ici_tensor_parallelism: 1\n", - "Config param ici_tensor_sequence_parallelism: 1\n", - "Config param ici_tensor_transpose_parallelism: 1\n", - "Config param image_path: \n", - "Config param image_size_for_vit: 896\n", - "Config param inference_benchmark_test: False\n", - "Config param inference_metadata_file: \n", - "Config param inference_microbenchmark_log_file_path: \n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Config param inference_microbenchmark_loop_iters: 10\n", - "Config param inference_microbenchmark_num_samples: [1, 2, 3, 4, 5]\n", - "Config param inference_microbenchmark_prefill_lengths: 64,128,256,512,1024\n", - "Config param inference_microbenchmark_stages: prefill,generate\n", - "Config param inference_server: MaxtextInterleavedServer\n", - "Config param inhomogeneous_layer_cycle_interval: 1\n", - "Config param init_weights_seed: 0\n", - "Config param input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']\n", - "Config param interleave_moe_layer_step: 1\n", - "Config param intermediate_size_for_vit: 5632\n", - "Config param jax_cache_dir: ~/jax_cache\n", - "Config param jax_debug_log_modules: \n", - "Config param jax_distributed_initialization_timeout: 300\n", - "Config param jax_profiler_port: 9999\n", - "Config param key_proj: remat\n", - "Config param kv_lora_rank: 512\n", - "Config param kv_quant_axis: heads_and_dkv\n", - "Config param kv_quant_dtype: int8\n", - "Config param learning_rate: 3e-05\n", - "Config param learning_rate_schedule_steps: 150001\n", - "Config param load_balance_loss_weight: 0.01\n", - "Config param load_from_prefill_dir: False\n", - "Config param load_full_state_path: \n", - "Config param load_parameters_path: \n", - "Config param local_checkpoint_directory: \n", - "Config param local_checkpoint_period: 0\n", - "Config param local_rope_max_timescale: -1\n", - "Config param log_config: True\n", - "Config param log_period: 100\n", - "Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_batch_no_exp', ('data', 'fsdp', 'fsdp_transpose')), ('activation_embed_and_logits_batch', ('data', 'stage', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_heads', ('tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive')), ('activation_kv_heads', ('tensor', 'tensor_transpose', 'sequence', 'tensor_sequence')), ('activation_length', ('sequence', 'context')), ('activation_length', ('context',)), ('activation_norm_length', ('tensor_sequence', 'context', 'sequence')), ('activation_q_length', ('context',)), ('activation_kv_length', ()), ('activation_embed', ('tensor', 'tensor_transpose')), ('activation_mlp', ('tensor', 'tensor_transpose', 'tensor_sequence')), ('activation_kv', ('tensor', 'tensor_transpose', 'tensor_sequence')), ('activation_prefill_kv_batch', ('data', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_kv_batch', ('data', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_kv_head_dim', ('tensor', 'tensor_transpose', 'tensor_sequence')), ('activation_vocab', ('tensor', 'tensor_transpose', 'sequence', 'tensor_sequence')), ('activation_vocab', ('tensor', 'tensor_transpose')), ('activation_vocab', 'tensor_sequence'), ('activation_vocab', ('sequence', 'context')), ('activation_stage', 'stage'), ('activation_exp', ('expert',)), ('decode_batch', ('data', 'fsdp', 'fsdp_transpose', 'expert')), ('decode_length', ('sequence',)), ('mlp', ('fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive')), ('vocab', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')), ('heads', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')), ('q_heads', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')), ('kv_heads', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert')), ('embed', ('fsdp', 'sequence', 'tensor_transpose', 'context', 'expert')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert')), ('embed', ('fsdp', 'sequence', 'context', 'expert')), ('embed_no_exp', ('fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context')), ('embed_no_exp', ('fsdp', 'sequence', 'tensor_transpose', 'context')), ('embed_no_exp', ('fsdp', 'fsdp_transpose', 'sequence', 'context')), ('embed_no_exp', ('fsdp', 'sequence', 'context')), ('q_lora', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert')), ('q_lora', ('fsdp', 'sequence', 'context', 'tensor_transpose', 'expert')), ('q_lora', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert')), ('q_lora', ('fsdp', 'sequence', 'context', 'expert')), ('kv_lora', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert')), ('kv_lora', ('fsdp', 'sequence', 'context', 'tensor_transpose', 'expert')), ('kv_lora', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert')), ('kv_lora', ('fsdp', 'sequence', 'context', 'expert')), ('norm', ('tensor', 'tensor_transpose', 'tensor_sequence')), ('layers', 'stage'), ('kv', ()), ('kv_head_dim', ()), ('cache_batch_prefill', ()), ('cache_batch', ()), ('cache_heads_none', ()), ('cache_heads', ('autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence')), ('cache_heads', ('autoregressive', 'tensor', 'tensor_sequence')), ('cache_kv', ()), ('cache_sequence', ()), ('exp', 'expert'), ('paged_kv_heads', ('tensor',)), ('num_pages', ()), ('tokens_per_page', ()), ('paged_kv_head_dim_size', ()))\n", - "Config param logits_dot_in_fp32: False\n", - "Config param logits_via_embedding: False\n", - "Config param lora_input_adapters_path: \n", - "Config param matmul_precision: default\n", - "Config param max_checkify: False\n", - "Config param max_corpus_chars: 10000000\n", - "Config param max_position_embeddings: 163840\n", - "Config param max_prefill_predict_length: 4\n", - "Config param max_target_length: 4\n", - "Config param megablox: True\n", - "Config param mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']\n", - "Config param metrics_dir: test/metrics/\n", - "Config param metrics_file: \n", - "Config param micro_batch_size_to_eval_on: 1\n", - "Config param micro_batch_size_to_train_on: 1\n", - "Config param mla_naive_kvcache: True\n", - "Config param mlp_activations: ['silu', 'linear']\n", - "Config param mlp_dim: 7168\n", - "Config param mlpwi: remat\n", - "Config param mlpwi_0: remat\n", - "Config param mlpwi_1: remat\n", - "Config param mlpwo: remat\n", - "Config param model_call_mode: \n", - "Config param model_name: default\n", - "Config param moe_mlp_dim: 7168\n", - "Config param monitor_goodput: False\n", - "Config param monitor_step_time_deviation: True\n", - "Config param mscale: 1.0\n", - "Config param mu_dtype: float32\n", - "Config param multi_sampling: False\n", - "Config param n_routing_groups: -1\n", - "Config param nope_layer_interval: -1\n", - "Config param normalization_layer_epsilon: 1e-05\n", - "Config param normalize_embedding_logits: True\n", - "Config param num_attention_heads_for_vit: 16\n", - "Config param num_channels_for_vit: 3\n", - "Config param num_decoder_layers: 2\n", - "Config param num_epoch: 1\n", - "Config param num_experts: 1\n", - "Config param num_experts_per_tok: 1\n", - "Config param num_hidden_layers_for_vit: 34\n", - "Config param num_kv_heads: 2\n", - "Config param num_layers_per_pipeline_stage: 1\n", - "Config param num_pipeline_microbatches: -1\n", - "Config param num_pipeline_repeats: -1\n", - "Config param num_query_heads: 2\n", - "Config param num_slices: 1\n", - "Config param opt_type: adamw\n", - "Config param optimize_mesh_for_tpu_v6e: False\n", - "Config param optimizer_memory_host_offload: False\n", - "Config param original_max_position_embeddings: 4096\n", - "Config param out_proj: remat\n", - "Config param override_model_config: False\n", - "Config param packing: True\n", - "Config param pagedattn_max_pages_per_group: 1\n", - "Config param pagedattn_num_pages: 64\n", - "Config param pagedattn_pages_per_compute_block: 4\n", - "Config param pagedattn_tokens_per_page: 32\n", - "Config param param_scan_axis: 1\n", - "Config param parameter_memory_host_offload: False\n", - "Config param patch_size_for_vit: 14\n", - "Config param per_device_batch_size: 1.0\n", - "Config param pipeline_delay_activation_forwarding: False\n", - "Config param pipeline_fsdp_ag_once: False\n", - "Config param pipeline_parallel_layers: -1\n", - "Config param pixel_shuffle_ratio_for_vit: 0.5\n", - "Config param prefill_cache_axis_order: 1,2,0,3\n", - "Config param prefill_cache_dir: \n", - "Config param prefill_chunk_size: 256\n", - "Config param prefill_slice: v5e-16\n", - "Config param prefix_caching_dram_byte: 100000000000\n", - "Config param prefix_caching_hbm_byte: 10000000000\n", - "Config param profile_cleanly: True\n", - "Config param profile_periodically_period: -1\n", - "Config param profiler: \n", - "Config param profiler_steps: 5\n", - "Config param projector_dropout_for_vit: 0.0\n", - "Config param projector_input_dim_for_vit: 4096\n", - "Config param projector_output_dim_for_vit: 4096\n", - "Config param prometheus_port: 0\n", - "Config param prompt: I love to\n", - "Config param q_lora_rank: 0\n", - "Config param qk_nope_head_dim: 128\n", - "Config param qk_rope_head_dim: 64\n", - "Config param qkv_proj: remat\n", - "Config param quant_cfg_path: \n", - "Config param quantization: \n", - "Config param quantization_local_shard_count: 1\n", - "Config param quantize_kvcache: False\n", - "Config param query_proj: remat\n", - "Config param ragged_block_size: 256\n", - "Config param record_internal_nn_metrics: 0\n", - "Config param remat_policy: full\n", - "Config param remat_policy_for_vit: minimal\n", - "Config param replicate_quant_scale: False\n", - "Config param replicator_backup_interval_minutes: 0\n", - "Config param report_heartbeat_metric_for_gcp_monitoring: False\n", - "Config param report_performance_metric_for_gcp_monitoring: False\n", - "Config param reshape_q: False\n", - "Config param return_log_prob: False\n", - "Config param reuse_example_batch: 0\n", - "Config param rope_factor: 40\n", - "Config param rope_max_timescale: 10000\n", - "Config param rope_min_timescale: 1\n", - "Config param rope_theta_for_vit: 10000\n", - "Config param rope_type: default\n", - "Config param rope_use_scale: True\n", - "Config param routed_bias: False\n", - "Config param routed_scaling_factor: 1.0\n", - "Config param routed_score_func: \n", - "Config param run_name: test\n", - "Config param sa_block_kv: 512\n", - "Config param sa_block_kv_compute: 512\n", - "Config param sa_block_kv_dkv: 512\n", - "Config param sa_block_kv_dkv_compute: 512\n", - "Config param sa_block_kv_dq: 512\n", - "Config param sa_block_q: 512\n", - "Config param sa_block_q_dkv: 512\n", - "Config param sa_block_q_dq: 512\n", - "Config param sa_k_layout: HEAD_DIM_MINOR\n", - "Config param sa_q_layout: HEAD_DIM_MINOR\n", - "Config param sa_use_fused_bwd_kernel: False\n", - "Config param sa_v_layout: HEAD_DIM_MINOR\n", - "Config param save_config_to_gcs: False\n", - "Config param save_quantized_params_path: \n", - "Config param scan_layers: True\n", - "Config param scan_layers_per_stage: False\n", - "Config param scan_pipeline_iterations: True\n", - "Config param set_remat_policy_on_layers_per_stage: False\n", - "Config param set_remat_policy_on_pipeline_iterations: True\n", - "Config param sft_train_on_completion_only: False\n", - "Config param sharding_tolerance: 0.02\n", - "Config param shared_experts: 1\n", - "Config param skip_first_n_steps_for_profiler: 1\n", - "Config param skip_jax_distributed_system: False\n", - "Config param sliding_window_size: 0\n", - "Config param sparse_matmul: True\n", - "Config param stack_prefill_result_cache: False\n", - "Config param stack_trace_interval_seconds: 600\n", - "Config param stack_trace_to_cloud: False\n", - "Config param step_deviation_interval_seconds: 30\n", - "Config param steps: 150001\n", - "Config param target_eval_loss: 0.0\n", - "Config param temperature_tuning: False\n", - "Config param tensorboard_dir: test/tensorboard/\n", - "Config param tile_activation_dim: 1024\n", - "Config param tile_batch_seq: 512\n", - "Config param tile_weight_dim: 1024\n", - "Config param tokenize_eval_data: True\n", - "Config param tokenize_train_data: True\n", - "Config param tokenizer_path: assets/tokenizer.llama2\n", - "Config param tokenizer_type: sentencepiece\n", - "Config param topk_routing_group: -1\n", - "Config param train_data_columns: ['text']\n", - "Config param train_split: train\n", - "Config param trainable_position_size: -1\n", - "Config param upload_all_profiler_results: False\n", - "Config param use_chat_template: False\n", - "Config param use_chunked_prefill: False\n", - "Config param use_dpo: False\n", - "Config param use_iota_embed: False\n", - "Config param use_multimodal: False\n", - "Config param use_post_attn_norm: False\n", - "Config param use_post_ffw_norm: False\n", - "Config param use_qk_norm: False\n", - "Config param use_ragged_attention: False\n", - "Config param use_random_routing: False\n", - "Config param use_replicator_service: False\n", - "Config param use_sft: False\n", - "Config param use_untrainable_positional_embedding: False\n", - "Config param use_vertex_tensorboard: False\n", - "Config param using_pipeline_parallelism: False\n", - "Config param v_head_dim: 128\n", - "Config param value_proj: remat\n", - "Config param vertex_tensorboard_project: \n", - "Config param vertex_tensorboard_region: \n", - "Config param vision_output_dim_for_vit: 4096\n", - "Config param vocab_size: 32000\n", - "Config param warmup_steps_fraction: 0.1\n", - "Config param weight_dtype: float32\n", - "Num_devices: 1, shape (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)\n" - ] - }, - { - "ename": "NameError", - "evalue": "name 'global_store' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[3], line 18\u001b[0m\n\u001b[1;32m 1\u001b[0m config \u001b[38;5;241m=\u001b[39m pyconfig\u001b[38;5;241m.\u001b[39minitialize(\n\u001b[1;32m 2\u001b[0m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdecode.py\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m../configs/base.yml\u001b[39m\u001b[38;5;124m\"\u001b[39m], \u001b[38;5;66;03m#TODO: @mazumdera: why decode.py?\u001b[39;00m\n\u001b[1;32m 3\u001b[0m per_device_batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1.0\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 14\u001b[0m \n\u001b[1;32m 15\u001b[0m )\n\u001b[1;32m 17\u001b[0m model \u001b[38;5;241m=\u001b[39m mt\u001b[38;5;241m.\u001b[39mfrom_pretrained(config)\n\u001b[0;32m---> 18\u001b[0m mesh, init_rng \u001b[38;5;241m=\u001b[39m \u001b[43mglobal_store\u001b[49m\u001b[38;5;241m.\u001b[39mget_global_mesh_and_init_rng()\n\u001b[1;32m 19\u001b[0m state, _ \u001b[38;5;241m=\u001b[39m maxtext_utils\u001b[38;5;241m.\u001b[39msetup_decode_state(model, config, init_rng, mesh, \u001b[38;5;28;01mNone\u001b[39;00m)\n", - "\u001b[0;31mNameError\u001b[0m: name 'global_store' is not defined" - ] - } - ], - "source": [ - "from maxtext.utils.globals import MAXTEXT_PKG_DIR\n", - "\n", - "config = pyconfig.initialize(\n", - " [os.path.join(MAXTEXT_PKG_DIR, \"decode.py\"), os.path.join(MAXTEXT_PKG_DIR, \"configs\", \"base.yml\")],\n", - " per_device_batch_size=1.0,\n", - " run_name=\"test\",\n", - " enable_checkpointing=False,\n", - " base_num_decoder_layers=2,\n", - " max_target_length=4,\n", - " base_emb_dim=256,\n", - " base_num_query_heads=2,\n", - " base_num_kv_heads=2,\n", - " max_prefill_predict_length=4,\n", - " # tokenizer_path=\"assets/tokenizers/llama3.1-tokenizer/\",\n", - " # model_name=\"llama3.1-7b\",\n", - ")\n", - "\n", - "model = mt.from_config(config)\n", - "mesh = model.mesh\n", - "init_rng = jax.random.PRNGKey(config.init_weights_seed)\n", - "state, _ = maxtext_utils.setup_decode_state(model, config, init_rng, mesh, None)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d2d2d0c5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tokenizer path: /home/mazumdera/maxtext/assets/tokenizer_llama3.tiktoken\n", - "Reloaded tiktoken model from /home/mazumdera/maxtext/assets/tokenizer_llama3.tiktoken\n", - "#words: 128256 - BOS ID: 128000 - EOS ID: 128001\n", - "input_ids=[128000, 40, 3021, 311], ids=[[128000 40 3021 311]], decoder_segment_ids = [[1. 1. 1. 1.]], decoder_positions= [[0 1 2 3]]\n" - ] - } - ], - "source": [ - "from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT\n", - "\n", - "source_tokenizer = _input_pipeline_utils.get_tokenizer(\n", - " os.path.join(MAXTEXT_ASSETS_ROOT, \"tokenizers\", \"tokenizer_llama3.tiktoken\"),\n", - " \"tiktoken\",\n", - " add_bos=True,\n", - " add_eos=False,\n", - ")\n", - "\n", - "\n", - "# TODO: @mazumdera: any way to geto segment and position ids like HF tokenizer gives us?\n", - "input_ids = source_tokenizer.encode(config.prompt) # .numpy()\n", - "ids = np.asarray(input_ids, dtype=np.int32)\n", - "s = (config.global_batch_size_to_train_on, config.max_target_length)\n", - "decoder_segment_ids = np.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR\n", - "decoder_positions = np.stack(\n", - " [np.arange(config.max_target_length, dtype=np.int32) for _ in range(config.global_batch_size_to_train_on)]\n", - ")\n", - "\n", - "# TODO: @mazumdera: simplify this config.global_batch_size_to_train_on=1\n", - "ids = np.stack([ids for _ in range(config.global_batch_size_to_train_on)])\n", - "max_logging.log(\n", - " f\"input_ids={input_ids}, ids={ids}, decoder_segment_ids = {decoder_segment_ids}, decoder_positions= {decoder_positions}\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e5a1fe11", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[CpuDevice(id=0)]" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import jax\n", - "\n", - "!export TPU_LIBRARY_PATH=/home/mazumdera/.local/lib/python3.10/site-packages/libtpu/libtpu.so\n", - "\n", - "jax.devices()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8d42b156", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/mazumdera/.local/lib/python3.10/site-packages/libtpu/libtpu.so\n" - ] - } - ], - "source": [ - "!ls /home/mazumdera/.local/lib/python3.10/site-packages/libtpu/libtpu.so" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7436751b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "full_train_logits[0, 0, :]=array([[ 0.6484375 , -1.09375 , -1.3359375 , ..., 0.0177002 ,\n", - " -0.8984375 , -0.57421875],\n", - " [ 0.8125 , -0.53125 , -0.3125 , ..., 1.34375 ,\n", - " 1.078125 , -1.3828125 ],\n", - " [ 0.6171875 , -2. , -2.0625 , ..., 0.13867188,\n", - " -0.9375 , -0.796875 ],\n", - " [-0.27734375, -1.3203125 , -0.765625 , ..., 1.1171875 ,\n", - " -0.26953125, 0.4296875 ]], dtype=float32)\n" - ] - } - ], - "source": [ - "import jax.experimental.multihost_utils\n", - "\n", - "full_train_logits = model.apply(\n", - " state.params,\n", - " ids,\n", - " decoder_positions,\n", - " decoder_segment_ids,\n", - " enable_dropout=False,\n", - " rngs={\"aqt\": init_rng},\n", - ")\n", - "full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits)\n", - "max_logging.log(f\"{full_train_logits[0, 0, :]=}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bb06c0c9", - "metadata": {}, - "outputs": [], - "source": [ - "selected_logits = jax.lax.dynamic_slice(\n", - " full_train_logits, (0, 0, full_train_logits.shape[2] - 1, 0), (1, 1, 1, full_train_logits.shape[3])\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "308f2a57", - "metadata": {}, - "outputs": [], - "source": [ - "init_rng, new_rng = jax.random.split(init_rng)\n", - "first_generated_token = inference_utils.sampling(\n", - " selected_logits,\n", - " new_rng,\n", - " config.decode_sampling_strategy,\n", - " topk=config.decode_sampling_top_k,\n", - " nucleus_topp=config.decode_sampling_nucleus_p,\n", - " temperature=config.decode_sampling_temperature,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32555a83", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "26831" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "first_generated_token.item()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3de52746", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'-ad'" - ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "source_tokenizer.decode([first_generated_token.item()])" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "a8e986cb", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "\n", + "from maxtext.utils.globals import MAXTEXT_REPO_ROOT\n", + "\n", + "# Add the project root to the system path if it's not already there\n", + "if MAXTEXT_REPO_ROOT not in sys.path:\n", + " sys.path.insert(0, MAXTEXT_REPO_ROOT)\n", + " print(f\"Added '{MAXTEXT_REPO_ROOT}' to sys.path\")" + ] }, - "nbformat": 4, - "nbformat_minor": 5 + { + "cell_type": "code", + "execution_count": null, + "id": "0ab2e1dd", + "metadata": {}, + "outputs": [], + "source": [ + "import maxtext as mt\n", + "from maxtext.configs import pyconfig\n", + "import numpy as np\n", + "from maxtext.input_pipeline import input_pipeline_utils\n", + "import os\n", + "from maxtext.common import common_types\n", + "import jax\n", + "from maxtext.inference import inference_utils\n", + "from maxtext.utils import max_logging\n", + "from maxtext.utils import maxtext_utils" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2d2de93", + "metadata": {}, + "outputs": [], + "source": [ + "from maxtext.utils.globals import MAXTEXT_PKG_DIR\n", + "\n", + "config = pyconfig.initialize(\n", + " [os.path.join(MAXTEXT_PKG_DIR, \"decode.py\"), os.path.join(MAXTEXT_PKG_DIR, \"configs\", \"base.yml\")],\n", + " per_device_batch_size=1.0,\n", + " run_name=\"test\",\n", + " enable_checkpointing=False,\n", + " base_num_decoder_layers=2,\n", + " max_target_length=4,\n", + " base_emb_dim=256,\n", + " base_num_query_heads=2,\n", + " base_num_kv_heads=2,\n", + " max_prefill_predict_length=4,\n", + " # tokenizer_path=\"assets/tokenizers/llama3.1-tokenizer/\",\n", + " # model_name=\"llama3.1-7b\",\n", + ")\n", + "\n", + "model = mt.from_config(config)\n", + "mesh = model.mesh\n", + "init_rng = jax.random.PRNGKey(config.init_weights_seed)\n", + "state, _ = maxtext_utils.setup_decode_state(model, config, init_rng, mesh, None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2d2d0c5", + "metadata": {}, + "outputs": [], + "source": [ + "from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT\n", + "\n", + "source_tokenizer = _input_pipeline_utils.get_tokenizer(\n", + " os.path.join(MAXTEXT_ASSETS_ROOT, \"tokenizers\", \"tokenizer_llama3.tiktoken\"),\n", + " \"tiktoken\",\n", + " add_bos=True,\n", + " add_eos=False,\n", + ")\n", + "\n", + "\n", + "# TODO: @mazumdera: any way to geto segment and position ids like HF tokenizer gives us?\n", + "input_ids = source_tokenizer.encode(config.prompt) # .numpy()\n", + "ids = np.asarray(input_ids, dtype=np.int32)\n", + "s = (config.global_batch_size_to_train_on, config.max_target_length)\n", + "decoder_segment_ids = np.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR\n", + "decoder_positions = np.stack(\n", + " [np.arange(config.max_target_length, dtype=np.int32) for _ in range(config.global_batch_size_to_train_on)]\n", + ")\n", + "\n", + "# TODO: @mazumdera: simplify this config.global_batch_size_to_train_on=1\n", + "ids = np.stack([ids for _ in range(config.global_batch_size_to_train_on)])\n", + "max_logging.log(\n", + " f\"input_ids={input_ids}, ids={ids}, decoder_segment_ids = {decoder_segment_ids}, decoder_positions= {decoder_positions}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5a1fe11", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "\n", + "!export TPU_LIBRARY_PATH=/home/mazumdera/.local/lib/python3.10/site-packages/libtpu/libtpu.so\n", + "\n", + "jax.devices()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d42b156", + "metadata": {}, + "outputs": [], + "source": [ + "!ls /home/mazumdera/.local/lib/python3.10/site-packages/libtpu/libtpu.so" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7436751b", + "metadata": {}, + "outputs": [], + "source": [ + "import jax.experimental.multihost_utils\n", + "\n", + "full_train_logits = model.apply(\n", + " state.params,\n", + " ids,\n", + " decoder_positions,\n", + " decoder_segment_ids,\n", + " enable_dropout=False,\n", + " rngs={\"aqt\": init_rng},\n", + ")\n", + "full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits)\n", + "max_logging.log(f\"{full_train_logits[0, 0, :]=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb06c0c9", + "metadata": {}, + "outputs": [], + "source": [ + "selected_logits = jax.lax.dynamic_slice(\n", + " full_train_logits, (0, 0, full_train_logits.shape[2] - 1, 0), (1, 1, 1, full_train_logits.shape[3])\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "308f2a57", + "metadata": {}, + "outputs": [], + "source": [ + "init_rng, new_rng = jax.random.split(init_rng)\n", + "first_generated_token = inference_utils.sampling(\n", + " selected_logits,\n", + " new_rng,\n", + " config.decode_sampling_strategy,\n", + " topk=config.decode_sampling_top_k,\n", + " nucleus_topp=config.decode_sampling_nucleus_p,\n", + " temperature=config.decode_sampling_temperature,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32555a83", + "metadata": {}, + "outputs": [], + "source": [ + "first_generated_token.item()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3de52746", + "metadata": {}, + "outputs": [], + "source": [ + "source_tokenizer.decode([first_generated_token.item()])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 69d27487d2..6dd0731cc6 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -19,7 +19,6 @@ # See github.com/google/maxtext/issues/20 for more from typing import Any, Sequence -import contextlib import datetime import functools import os @@ -57,7 +56,6 @@ maybe_record_goodput, record_goodput, ) -from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag, is_decoupled from maxtext.common.gcloud_stub import vertex_tensorboard_modules from maxtext.common import metric_logger from maxtext.common.metric_logger import record_activation_metrics @@ -74,8 +72,6 @@ from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss -_diag_modules = _cloud_diag() -diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = _diag_modules VertexTensorboardManager, _vertex_tb_is_stub = vertex_tensorboard_modules() @@ -793,41 +789,16 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any] # Create the Goodput recorder recorder = create_goodput_recorder(config) - # Stack traces configurations - debug_config = debug_configuration.DebugConfig( - stack_trace_config=stack_trace_configuration.StackTraceConfig( - collect_stack_trace=config.collect_stack_trace, - stack_trace_to_cloud=config.stack_trace_to_cloud, - stack_trace_interval_seconds=config.stack_trace_interval_seconds, - ) - ) - diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config) - return config, recorder, diagnostic_config - - -def run(config, recorder, diagnostic_config): - """Run the job given hyperparameters and utilities. - - In decoupled mode (DECOUPLE_GCLOUD=TRUE) cloud diagnostics may be stubbed; if so, skip wrapping. - """ - # Use nullcontext when diagnostics are stubbed or in decoupled mode - diagnostics_context = ( - contextlib.nullcontext() - if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag" - else diagnostic.diagnose(diagnostic_config) - ) + return config, recorder - if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag": - max_logging.log("[DECOUPLED NO-OP] skipping cloud diagnostics wrapper.") - with ( - diagnostics_context, - max_utils.maybe_get_transformer_engine_context(config), - ): +def run(config, recorder): + """Run the job given hyperparameters and utilities.""" + with (max_utils.maybe_get_transformer_engine_context(config),): train_loop(config, recorder) -def get_train_func(config, recorder, diagnostic_config, argv): +def get_train_func(config, recorder, argv): """Returns the train function, wrapping in elastic_retry if elastic training is enabled.""" if config.elastic_enabled: max_logging.log("Elastic utils: Elastic training enabled.") @@ -840,11 +811,10 @@ def on_slices_ready(): def elastic_train_wrapper(argv: Sequence[str]) -> None: """Wrapper for elastic training initializes variables and runs the train loop.""" - elastic_config, elastic_recorder, elastic_diagnostic_config = initialize(argv) + elastic_config, elastic_recorder = initialize(argv) run( elastic_config, elastic_recorder, - elastic_diagnostic_config, ) train_func = elastic_utils.elastic_retry( @@ -855,15 +825,15 @@ def elastic_train_wrapper(argv: Sequence[str]) -> None: else: # Use the already initialized variables def train_func(): - run(config, recorder, diagnostic_config) + run(config, recorder) return train_func def main(argv: Sequence[str]) -> None: - config, recorder, diagnostic_config = initialize(argv) + config, recorder = initialize(argv) record_goodput(recorder, RECORD_JOB_START_TIME) - train_func = get_train_func(config, recorder, diagnostic_config, argv) + train_func = get_train_func(config, recorder, argv) with maybe_monitor_goodput(config): train_func() diff --git a/tests/end_to_end/test_checkpointing.sh b/tests/end_to_end/test_checkpointing.sh index 2841fd88b6..8405f7d535 100644 --- a/tests/end_to_end/test_checkpointing.sh +++ b/tests/end_to_end/test_checkpointing.sh @@ -14,13 +14,12 @@ fi RUN_NAME=${1}-${4} OUTPUT_PATH=${2} DATASET_PATH=${3} -COLLECT_STACK_TRACE=${4} -DATASET_TYPE=${5} -ATTENTION=${6} -if [ -z "${6}" ]; then +DATASET_TYPE=${4} +ATTENTION=${5} +if [ -z "${5}" ]; then ATTENTION='autoselected' fi -ASYNC_CHECKPOINTING=${7:-true} +ASYNC_CHECKPOINTING=${6:-true} eval_metrics=checkpoint_save_restore model_params=" base_emb_dim=384 base_num_query_heads=8 base_num_kv_heads=8 base_mlp_dim=192 base_num_decoder_layers=8 head_dim=128" CMD_DATA="" @@ -38,7 +37,7 @@ fi # This command runs training for some steps and saves a checkpoint. CMD1="python3 -m maxtext.trainers.pre_train.train ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml run_name=$RUN_NAME steps=5 max_target_length=128 per_device_batch_size=1\ metrics_file=saved_metrics.txt checkpoint_period=3 base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ - async_checkpointing=$ASYNC_CHECKPOINTING collect_stack_trace=$COLLECT_STACK_TRACE attention=$ATTENTION" + async_checkpointing=$ASYNC_CHECKPOINTING attention=$ATTENTION" CMD1+=$model_params CMD1+=$CMD_DATA @@ -46,7 +45,7 @@ CMD1+=$CMD_DATA # This ensures actual new training steps are executed after restoring checkpoint from the above training run. CMD2="python3 -m maxtext.trainers.pre_train.train ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml run_name=$RUN_NAME steps=10 max_target_length=128 per_device_batch_size=1\ metrics_file=restored_metrics.txt base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ - async_checkpointing=$ASYNC_CHECKPOINTING collect_stack_trace=$COLLECT_STACK_TRACE attention=$ATTENTION" + async_checkpointing=$ASYNC_CHECKPOINTING attention=$ATTENTION" CMD2+=$model_params CMD2+=$CMD_DATA diff --git a/tests/end_to_end/tpu/test_checkpoint_resharding.sh b/tests/end_to_end/tpu/test_checkpoint_resharding.sh index fa4558686a..3f31602d6b 100644 --- a/tests/end_to_end/tpu/test_checkpoint_resharding.sh +++ b/tests/end_to_end/tpu/test_checkpoint_resharding.sh @@ -8,11 +8,11 @@ DATASET_PATH=${3} # Train and save checkpoint - sharded with DCN Data Parallelism + ICI FSDP Parallelism python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=$RUN_NAME steps=101\ metrics_file='saved_metrics.txt' checkpoint_period=20 base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ - dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=4 ici_tensor_parallelism=1 collect_stack_trace=False + dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=4 ici_tensor_parallelism=1 # Retrieve checkpoint - sharded with DCN Data Parallelism + ICI FSDP + Tensor Parallelism python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=$RUN_NAME steps=102\ metrics_file='restored_metrics.txt' base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ - dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=2 ici_tensor_parallelism=2 collect_stack_trace=False + dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=2 ici_tensor_parallelism=2 python3 tests/end_to_end/tpu/eval_assert.py checkpoint_save_restore metrics.txt learning/loss diff --git a/tests/integration/checkpoint_resharding_test.py b/tests/integration/checkpoint_resharding_test.py index 0f0566ba92..215ef8606e 100644 --- a/tests/integration/checkpoint_resharding_test.py +++ b/tests/integration/checkpoint_resharding_test.py @@ -54,7 +54,6 @@ def get_resharding_command(run_date, steps, metrics_file, base_output_directory, f"dataset_path={dataset_path}", "dataset_type=synthetic", "grain_worker_count=0", - "collect_stack_trace=False", ] + model_params + parallelism_args diff --git a/tests/unit/gcloud_stub_test.py b/tests/unit/gcloud_stub_test.py index caf77cc3a7..4d173120d3 100644 --- a/tests/unit/gcloud_stub_test.py +++ b/tests/unit/gcloud_stub_test.py @@ -81,27 +81,6 @@ def test_jetstream_returns_stubs_when_deps_missing_and_decoupled(self): self.assertTrue(hasattr(engine_api, "Engine")) self.assertTrue(hasattr(engine_api, "ResultTokens")) - def test_cloud_diagnostics_contract_in_decoupled_mode(self): - """cloud_diagnostics() returns 4-tuple; content can be real or stub.""" - with mock.patch.dict(os.environ, {"DECOUPLE_GCLOUD": "TRUE"}): - diag, debug_cfg, diag_cfg, stack_cfg = gcloud_stub.cloud_diagnostics() - self.assertIsNotNone(diag) - self.assertIsNotNone(debug_cfg) - self.assertIsNotNone(diag_cfg) - self.assertIsNotNone(stack_cfg) - - def test_cloud_diagnostics_returns_stub_object_when_missing_and_decoupled(self): - """Force stub branch -> diag is stub object with .run().""" - with mock.patch.dict(os.environ, {"DECOUPLE_GCLOUD": "TRUE"}): - with mock.patch("maxtext.common.gcloud_stub._import_or_stub") as _ios: - _ios.side_effect = lambda import_fn, stub_fn, **kwargs: stub_fn() - diag, debug_cfg, diag_cfg, stack_cfg = gcloud_stub.cloud_diagnostics() - - self.assertTrue(hasattr(diag, "run")) - self.assertIsNotNone(debug_cfg) - self.assertIsNotNone(diag_cfg) - self.assertIsNotNone(stack_cfg) - def test_monitoring_modules_returns_stub_tuple_when_decoupled_and_missing(self): # Force stub path regardless of installed deps. with mock.patch.dict(os.environ, {"DECOUPLE_GCLOUD": "TRUE"}):