diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 78a58c2d..8c3c30b8 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -251,6 +251,10 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): s0 = time.perf_counter() + # Disable profiler for the first two runs to avoid duplicate uploads + original_enable_profiler = config.enable_profiler if "enable_profiler" in config.get_keys() else False + config.get_keys()["enable_profiler"] = False + # Using global_batch_size_to_train_on so not to create more config variables prompt = [config.prompt] * config.global_batch_size_to_train_on negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on @@ -321,6 +325,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): max_logging.log("\n".join(summary)) s0 = time.perf_counter() + # Restore original profiler setting for the profiling run + config.get_keys()["enable_profiler"] = original_enable_profiler if max_utils.profiler_enabled(config): # Injecting user requested XLA tracing flags xla_flags = os.environ.get("XLA_FLAGS", "") diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 3a885fba..8cff92a3 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -154,6 +154,24 @@ def stop(self): if _jax_profiler_enabled(self.config): jax.profiler.stop_trace() + trace_dir = self.config.tensorboard_dir + if trace_dir.startswith("gs://"): + local_dir = os.path.join("/tmp/profiler_traces", self.config.run_name) + if os.path.exists(local_dir): + max_logging.log(f"Uploading profiler traces from {local_dir} to {trace_dir}...") + client = storage.Client() + bucket_name, prefix = parse_gcs_bucket_and_prefix(trace_dir) + bucket = client.bucket(bucket_name) + + for root, _, files in os.walk(local_dir): + for file in files: + local_file = os.path.join(root, file) + rel_path = os.path.relpath(local_file, local_dir) + blob_name = os.path.join(prefix, rel_path) + blob = bucket.blob(blob_name) + blob.upload_from_filename(local_file) + max_logging.log(f"Uploaded {local_file} to gs://{bucket_name}/{blob_name}") + def __enter__(self): self.start() return self @@ -161,24 +179,6 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.stop() - trace_dir = self.config.tensorboard_dir - if trace_dir.startswith("gs://"): - local_dir = os.path.join("/tmp/profiler_traces", self.config.run_name) - if os.path.exists(local_dir): - max_logging.log(f"Uploading profiler traces from {local_dir} to {trace_dir}...") - client = storage.Client() - bucket_name, prefix = parse_gcs_bucket_and_prefix(trace_dir) - bucket = client.bucket(bucket_name) - - for root, _, files in os.walk(local_dir): - for file in files: - local_file = os.path.join(root, file) - rel_path = os.path.relpath(local_file, local_dir) - blob_name = os.path.join(prefix, rel_path) - blob = bucket.blob(blob_name) - blob.upload_from_filename(local_file) - max_logging.log(f"Uploaded {local_file} to gs://{bucket_name}/{blob_name}") - def initialize_summary_writer(config): return writer.SummaryWriter(config.tensorboard_dir) if jax.process_index() == 0 else None diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index 355ba6ae..f86f97b7 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -469,6 +469,7 @@ def scan_body(carry, t): if config and max_utils.profiler_enabled(config) and step == last_profiling_step: if profiler: + latents.block_until_ready() profiler.stop() return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 2c294a12..27be57ec 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -712,5 +712,6 @@ def scan_body(carry, t): if config and max_utils.profiler_enabled(config) and step == last_profiling_step: if profiler: + latents.block_until_ready() profiler.stop() return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index aa4bbba2..37c3ef85 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -478,5 +478,6 @@ def scan_body(carry, t): if config and max_utils.profiler_enabled(config) and step == last_profiling_step: if profiler: + latents.block_until_ready() profiler.stop() return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index 1ba54f2e..63af476f 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -820,5 +820,6 @@ def scan_body(carry, t): if config and max_utils.profiler_enabled(config) and step == last_profiling_step: if profiler: + latents.block_until_ready() profiler.stop() return latents diff --git a/src/maxdiffusion/tests/profiler_test.py b/src/maxdiffusion/tests/profiler_test.py index 7d4b1987..7b7683b3 100644 --- a/src/maxdiffusion/tests/profiler_test.py +++ b/src/maxdiffusion/tests/profiler_test.py @@ -16,6 +16,9 @@ import unittest from unittest.mock import patch +import os + +real_exists = os.path.exists from maxdiffusion import max_utils @@ -80,6 +83,82 @@ def test_profiler_disabled(self, mock_ml_run): max_utils.ensure_machinelearning_job_runs(config) mock_ml_run.assert_not_called() + @patch("maxdiffusion.max_utils.storage.Client") + @patch("maxdiffusion.max_utils.os.path.exists") + @patch("maxdiffusion.max_utils.os.walk", return_value=[("/tmp/profiler_traces/test_run", [], ["file1.trace"])]) + @patch("jax.profiler.start_trace") + @patch("jax.profiler.stop_trace") + @patch("jax.process_index", return_value=0) + def test_jax_profiler_manual_gcs( + self, + mock_process_index, + mock_stop_trace, + mock_start_trace, + mock_os_walk, + mock_os_exists, + mock_storage_client, + ): + """Tests manual start/stop with GCS upload.""" + mock_os_exists.side_effect = lambda path: True if path == "/tmp/profiler_traces/test_run" else real_exists(path) + config = MockConfig( + enable_ml_diagnostics=False, + enable_profiler=True, + tensorboard_dir="gs://test-bucket/tensorboard", + run_name="test_run", + ) + + profiler = max_utils.Profiler(config) + profiler.start() + mock_start_trace.assert_called_once() + + profiler.stop() + mock_stop_trace.assert_called_once() + + # Verify GCS upload was attempted + mock_storage_client.assert_called_once() + mock_bucket = mock_storage_client.return_value.bucket + mock_bucket.assert_called_once_with("test-bucket") + mock_blob = mock_bucket.return_value.blob + mock_blob.assert_called_once_with("tensorboard/file1.trace") + mock_blob.return_value.upload_from_filename.assert_called_once_with("/tmp/profiler_traces/test_run/file1.trace") + + @patch("maxdiffusion.max_utils.storage.Client") + @patch("maxdiffusion.max_utils.os.path.exists") + @patch("maxdiffusion.max_utils.os.walk", return_value=[("/tmp/profiler_traces/test_run", [], ["file1.trace"])]) + @patch("jax.profiler.start_trace") + @patch("jax.profiler.stop_trace") + @patch("jax.process_index", return_value=0) + def test_jax_profiler_context_gcs( + self, + mock_process_index, + mock_stop_trace, + mock_start_trace, + mock_os_walk, + mock_os_exists, + mock_storage_client, + ): + """Tests context manager with GCS upload.""" + mock_os_exists.side_effect = lambda path: True if path == "/tmp/profiler_traces/test_run" else real_exists(path) + config = MockConfig( + enable_ml_diagnostics=False, + enable_profiler=True, + tensorboard_dir="gs://test-bucket/tensorboard", + run_name="test_run", + ) + + with max_utils.Profiler(config): + mock_start_trace.assert_called_once() + + mock_stop_trace.assert_called_once() + + # Verify GCS upload was attempted + mock_storage_client.assert_called_once() + mock_bucket = mock_storage_client.return_value.bucket + mock_bucket.assert_called_once_with("test-bucket") + mock_blob = mock_bucket.return_value.blob + mock_blob.assert_called_once_with("tensorboard/file1.trace") + mock_blob.return_value.upload_from_filename.assert_called_once_with("/tmp/profiler_traces/test_run/file1.trace") + if __name__ == "__main__": unittest.main()