diff --git a/tests/buffer/experience_pipeline_test.py b/tests/buffer/experience_pipeline_test.py index 8f7aae740e..0605dc0d73 100644 --- a/tests/buffer/experience_pipeline_test.py +++ b/tests/buffer/experience_pipeline_test.py @@ -4,7 +4,7 @@ import ray import torch -from tests.tools import RayUnittestBaseAysnc, get_template_config +from tests.tools import RayUnittestBaseAsync, get_template_config from trinity.buffer import get_buffer_reader from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline from trinity.common.config import ( @@ -34,7 +34,7 @@ def get_experiences(task_num: int, repeat_times: int = 1, step_num: int = 1) -> ] -class TestExperiencePipeline(RayUnittestBaseAysnc): +class TestExperiencePipeline(RayUnittestBaseAsync): def setUp(self): if os.path.exists(BUFFER_FILE_PATH): os.remove(BUFFER_FILE_PATH) diff --git a/tests/buffer/experience_storage_test.py b/tests/buffer/experience_storage_test.py index 0071fabf39..158542afa9 100644 --- a/tests/buffer/experience_storage_test.py +++ b/tests/buffer/experience_storage_test.py @@ -7,7 +7,7 @@ import torch from parameterized import parameterized -from tests.tools import RayUnittestBaseAysnc +from tests.tools import RayUnittestBaseAsync from trinity.buffer.reader.sql_reader import SQLReader from trinity.buffer.writer.sql_writer import SQLWriter from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig @@ -17,7 +17,7 @@ DB_PATH = os.path.join(os.path.dirname(__file__), "test.db") -class ExperienceStorageTest(RayUnittestBaseAysnc): +class ExperienceStorageTest(RayUnittestBaseAsync): def setUp(self): self.total_num = 8 self.put_batch_size = 2 diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index 537514222f..695a5ce619 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -7,7 +7,7 @@ import torch from parameterized import parameterized -from tests.tools import RayUnittestBaseAysnc +from tests.tools import RayUnittestBaseAsync from trinity.buffer.reader.queue_reader import QueueReader from trinity.buffer.writer.queue_writer import QueueWriter from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig @@ -17,7 +17,7 @@ BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl") -class TestQueueBuffer(RayUnittestBaseAysnc): +class TestQueueBuffer(RayUnittestBaseAsync): @parameterized.expand( [ ( diff --git a/tests/buffer/reader_test.py b/tests/buffer/reader_test.py index c1291720b0..05984932af 100644 --- a/tests/buffer/reader_test.py +++ b/tests/buffer/reader_test.py @@ -1,4 +1,4 @@ -from tests.tools import RayUnittestBaseAysnc, get_unittest_dataset_config +from tests.tools import RayUnittestBaseAsync, get_unittest_dataset_config from trinity.buffer.buffer import get_buffer_reader from trinity.buffer.reader import READER from trinity.buffer.reader.file_reader import FileReader, TaskFileReader @@ -12,7 +12,7 @@ def __init__(self, config): super().__init__(config) -class TestBufferReader(RayUnittestBaseAysnc): +class TestBufferReader(RayUnittestBaseAsync): async def test_buffer_reader_registration(self) -> None: config = get_unittest_dataset_config("countdown", "train") config.batch_size = 2 diff --git a/tests/buffer/sample_strategy_test.py b/tests/buffer/sample_strategy_test.py index c3a9af9179..f56d6f523e 100644 --- a/tests/buffer/sample_strategy_test.py +++ b/tests/buffer/sample_strategy_test.py @@ -5,7 +5,7 @@ import torch from parameterized import parameterized_class -from tests.tools import RayUnittestBaseAysnc, 
get_template_config +from tests.tools import RayUnittestBaseAsync, get_template_config from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy from trinity.buffer.buffer import get_buffer_writer @@ -21,7 +21,7 @@ (6,), ], ) -class ExperienceStorageTest(RayUnittestBaseAysnc): +class ExperienceStorageTest(RayUnittestBaseAsync): def setUp(self): self.config = get_template_config() self.num_steps = 20 @@ -249,5 +249,5 @@ async def test_sql_staleness_control_sample_strategy(self): def tearDown(self): asyncio.run(self.buffer_writer.release()) - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) return super().tearDown() diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index b2cdd5712d..4d63f041e4 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -4,7 +4,7 @@ import torch from parameterized import parameterized -from tests.tools import RayUnittestBaseAysnc +from tests.tools import RayUnittestBaseAsync from trinity.buffer import get_buffer_reader from trinity.buffer.reader.sql_reader import SQLReader from trinity.buffer.writer.sql_writer import SQLWriter @@ -19,7 +19,7 @@ db_path = os.path.join(os.path.dirname(__file__), "test.db") -class TestSQLBuffer(RayUnittestBaseAysnc): +class TestSQLBuffer(RayUnittestBaseAsync): @parameterized.expand( [ (True,), diff --git a/tests/buffer/task_scheduler_test.py b/tests/buffer/task_scheduler_test.py index 6b5785a48e..901bd51aa3 100644 --- a/tests/buffer/task_scheduler_test.py +++ b/tests/buffer/task_scheduler_test.py @@ -33,7 +33,7 @@ def setUpClass(cls): def tearDownClass(cls): super().tearDownClass() if os.path.exists(cls.temp_output_path): - shutil.rmtree(cls.temp_output_path) + shutil.rmtree(cls.temp_output_path, ignore_errors=True) def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, int]]) -> None: for task, index in zip(batch_tasks, indices): diff --git a/tests/common/config_test.py b/tests/common/config_test.py index 2502dae824..4abe8b7774 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -184,4 +184,4 @@ def test_chat_template_path(self): def tearDown(self): if os.path.exists(CHECKPOINT_ROOT_DIR): - shutil.rmtree(CHECKPOINT_ROOT_DIR) + shutil.rmtree(CHECKPOINT_ROOT_DIR, ignore_errors=True) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index b6c675f550..b1deb90dff 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -8,7 +8,7 @@ from transformers import AutoTokenizer from tests.tools import ( - RayUnittestBaseAysnc, + RayUnittestBaseAsync, get_api_model_path, get_model_path, get_template_config, @@ -113,7 +113,7 @@ async def prepare_engines(engines, auxiliary_engines): (2, 1, 3, True, True), ], ) -class ModelWrapperTest(RayUnittestBaseAysnc): +class ModelWrapperTest(RayUnittestBaseAsync): def setUp(self): # configure the model self.config = get_template_config() @@ -233,7 +233,7 @@ async def test_generate(self): (20, 5, 15), ], ) -class TestModelLen(RayUnittestBaseAysnc): +class TestModelLen(RayUnittestBaseAsync): def setUp(self): self.config = get_template_config() self.config.mode = "explore" @@ -302,7 +302,7 @@ def _check_experience(exp): ) -class TestModelLenWithoutPromptTruncation(RayUnittestBaseAysnc): +class TestModelLenWithoutPromptTruncation(RayUnittestBaseAsync): def setUp(self): self.config = get_template_config() self.config.mode = "explore" @@ -351,7 +351,7 @@ async 
def test_model_len(self): ) -class TestAPIServer(RayUnittestBaseAysnc): +class TestAPIServer(RayUnittestBaseAsync): def setUp(self): self.config = get_template_config() self.config.mode = "explore" @@ -482,7 +482,7 @@ async def test_api(self): """ -class TestLogprobs(RayUnittestBaseAysnc): +class TestLogprobs(RayUnittestBaseAsync): def setUp(self): self.config = get_template_config() self.config.mode = "explore" @@ -669,7 +669,7 @@ async def test_logprobs_api(self): ) -class TestAsyncAPIServer(RayUnittestBaseAysnc): +class TestAsyncAPIServer(RayUnittestBaseAsync): def setUp(self): self.config = get_template_config() self.config.mode = "explore" @@ -880,7 +880,7 @@ def test_action_mask_with_tools(self): (False, None), ], ) -class TestAPIServerToolCall(RayUnittestBaseAysnc): +class TestAPIServerToolCall(RayUnittestBaseAsync): def setUp(self): self.config = get_template_config() self.config.mode = "explore" @@ -1161,7 +1161,7 @@ async def test_api_tool_calls(self): ) -class TestSuperLongGeneration(RayUnittestBaseAysnc): +class TestSuperLongGeneration(RayUnittestBaseAsync): def setUp(self): self.config = get_template_config() self.config.mode = "explore" @@ -1217,7 +1217,7 @@ async def test_generate(self): self.assertGreater(response.logprobs.shape[0], 1000) -class TestTinkerAPI(RayUnittestBaseAysnc): +class TestTinkerAPI(RayUnittestBaseAsync): """Test the Tinker API integration with the vLLM engine.""" def setUp(self): diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 299bbf66b8..b17bf7709b 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -12,7 +12,7 @@ from tests.tools import ( RayUnittestBase, - RayUnittestBaseAysnc, + RayUnittestBaseAsync, TensorBoardParser, get_api_model_path, get_checkpoint_path, @@ -180,7 +180,7 @@ def run_agent(proxy_url, model_path: str): return response.choices[0].message.content -class ServeTest(RayUnittestBaseAysnc): +class ServeTest(RayUnittestBaseAsync): def setUp(self): self.config = get_template_config() self.config.mode = "serve" diff --git a/tests/manager/synchronizer_test.py b/tests/manager/synchronizer_test.py index 6ab488986e..f0d5415863 100644 --- a/tests/manager/synchronizer_test.py +++ b/tests/manager/synchronizer_test.py @@ -82,12 +82,14 @@ def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None: ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) trainer_monkey_patch(config, max_steps, intervals) train(config) + ray.shutdown() def run_explorer(config: Config, max_steps: int, intervals: List[int]) -> None: ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) explorer_monkey_patch(config, max_steps, intervals) explore(config) + ray.shutdown() def run_both( @@ -97,17 +99,26 @@ def run_both( trainer_monkey_patch(config, max_steps, trainer_intervals) explorer_monkey_patch(config, max_steps, explorer_intervals) both(config) + ray.shutdown() class BaseTestSynchronizer(unittest.TestCase): def setUp(self): if multiprocessing.get_start_method(allow_none=True) != "spawn": multiprocessing.set_start_method("spawn", force=True) + self.process_list = [] def tearDown(self): - checkpoint_path = get_checkpoint_path() ray.shutdown(_exiting_interpreter=True) - shutil.rmtree(os.path.join(checkpoint_path, "unittest")) + if os.path.exists(CHECKPOINT_ROOT_DIR): + shutil.rmtree(CHECKPOINT_ROOT_DIR, ignore_errors=True) + for process in self.process_list: + if process.is_alive(): + process.terminate() + process.join(timeout=10) + if process.is_alive(): + 
process.kill() + process.join() class TestSynchronizerExit(BaseTestSynchronizer): @@ -151,6 +162,8 @@ def test_synchronizer(self): target=run_trainer, args=(trainer_config, 8, [2, 1, 2, 1, 2, 1, 2, 1]) ) trainer_process.start() + self.process_list.append(trainer_process) + ray.init(ignore_reinit_error=True) while True: try: @@ -164,6 +177,7 @@ def test_synchronizer(self): args=(explorer1_config, 8, [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5]), ) explorer_process_1.start() + self.process_list.append(explorer_process_1) self.assertEqual( synchronizer, ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace) @@ -176,14 +190,13 @@ def test_synchronizer(self): except ValueError: print("waiting for explorer1 to start.") time.sleep(5) - trainer_process.terminate() - trainer_process.join() + + trainer_process.join(timeout=200) self.assertEqual( synchronizer, ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace) ) - explorer_process_1.terminate() - explorer_process_1.join() + explorer_process_1.join(timeout=200) time.sleep(6) with self.assertRaises(ValueError): ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace) @@ -278,6 +291,8 @@ def test_synchronizer(self): target=run_trainer, args=(trainer_config, self.max_steps, self.trainer_intervals) ) trainer_process.start() + self.process_list.append(trainer_process) + ray.init(ignore_reinit_error=True) while True: try: @@ -291,10 +306,12 @@ def test_synchronizer(self): args=(explorer1_config, self.max_steps, self.explorer1_intervals), ) explorer_process_1.start() + self.process_list.append(explorer_process_1) explorer_process_2 = multiprocessing.Process( target=run_explorer, args=(explorer2_config, self.max_steps, self.explorer2_intervals) ) explorer_process_2.start() + self.process_list.append(explorer_process_2) explorer_process_1.join(timeout=200) explorer_process_2.join(timeout=200) @@ -364,6 +381,7 @@ def test_synchronizer(self): args=(config, self.max_steps, self.trainer_intervals, self.explorer_intervals), ) both_process.start() + self.process_list.append(both_process) both_process.join(timeout=200) # check the tensorboard @@ -375,7 +393,3 @@ def test_synchronizer(self): ) rollout_metrics = parser.metric_list("rollout") self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) - - def tearDown(self): - if os.path.exists(CHECKPOINT_ROOT_DIR): - shutil.rmtree(CHECKPOINT_ROOT_DIR) diff --git a/tests/service/data_juicer_test.py b/tests/service/data_juicer_test.py index 60440e0d6e..2a966b5546 100644 --- a/tests/service/data_juicer_test.py +++ b/tests/service/data_juicer_test.py @@ -219,7 +219,7 @@ async def test_data_juicer_operators(self): class TestDataJuicerTaskPipeline(RayUnittestBase): def setUp(self): if os.path.exists(TASKSET_OUTPUT_DIR): - shutil.rmtree(TASKSET_OUTPUT_DIR) + shutil.rmtree(TASKSET_OUTPUT_DIR, ignore_errors=True) def test_data_juicer_task_pipeline(self): config = get_template_config() diff --git a/tests/tools.py b/tests/tools.py index 7c8713319e..4b6055f8a4 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -253,7 +253,7 @@ def tearDownClass(cls): ray.shutdown(_exiting_interpreter=True) -class RayUnittestBaseAysnc(unittest.IsolatedAsyncioTestCase): +class RayUnittestBaseAsync(unittest.IsolatedAsyncioTestCase): @classmethod def setUpClass(cls): ray.init(ignore_reinit_error=True, namespace="trinity_unittest") diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index c37c05559c..2a54b9a494 100644 --- a/tests/trainer/trainer_test.py +++ 
b/tests/trainer/trainer_test.py @@ -16,7 +16,7 @@ from tests.tools import ( RayUnittestBase, - RayUnittestBaseAysnc, + RayUnittestBaseAsync, TensorBoardParser, get_checkpoint_path, get_lora_config, @@ -109,22 +109,22 @@ def test_trainer(self): both(self.config) parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") - self.assertTrue(len(rollout_metrics) > 0) + self.assertGreater(len(rollout_metrics), 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) eval_metrics = parser.metric_list("eval") - self.assertTrue(len(eval_metrics) > 0) + self.assertGreater(len(eval_metrics), 0) self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8) actor_metrics = parser.metric_list("actor") - self.assertTrue(len(actor_metrics) > 0) + self.assertGreater(len(actor_metrics), 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) actor_kl_metrics = parser.metric_list("actor/kl") - self.assertTrue(len(actor_kl_metrics) > 0) + self.assertGreater(len(actor_kl_metrics), 0) actor_kl_loss = parser.metric_values("actor/kl_loss") self.assertEqual(actor_kl_loss[0], 0.0) critic_kl_metrics = parser.metric_list("critic/kl") - self.assertTrue(len(critic_kl_metrics) > 0) + self.assertGreater(len(critic_kl_metrics), 0) response_metrics = parser.metric_list("response_length") - self.assertTrue(len(response_metrics) > 0) + self.assertGreater(len(response_metrics), 0) self.assertEqual(parser.metric_max_step(response_metrics[0]), 8) ray.shutdown(_exiting_interpreter=True) # check checkpoint @@ -138,13 +138,13 @@ def test_trainer(self): checkpoint_root_path=self.config.checkpoint_job_dir, trainer_type=self.config.trainer.trainer_type, ) - self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_4, "actor"))) > 0) - self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))) > 0) - self.assertTrue( - len(os.listdir(os.path.join(checkpoint_step_4, "actor", "huggingface"))) > 0 + self.assertGreater(len(os.listdir(os.path.join(checkpoint_step_4, "actor"))), 0) + self.assertGreater(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))), 0) + self.assertGreater( + len(os.listdir(os.path.join(checkpoint_step_4, "actor", "huggingface"))), 0 ) - self.assertTrue( - len(os.listdir(os.path.join(checkpoint_step_8, "actor", "huggingface"))) > 0 + self.assertGreater( + len(os.listdir(os.path.join(checkpoint_step_8, "actor", "huggingface"))), 0 ) self.assertEqual(step_num, 8) ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace) @@ -161,7 +161,7 @@ def test_trainer(self): for prefix in ["eval", "bench"]: for taskset_name in ["countdown", "copy_countdown"]: metrics = parser.metric_list(f"{prefix}/{taskset_name}") - self.assertTrue(len(metrics) > 0) + self.assertGreater(len(metrics), 0, f"{prefix}/{taskset_name} metrics not found") for eval_stats in ["mean", "best", "worst"]: for k in [2, 4]: for stats in ["mean", "std"]: @@ -171,7 +171,7 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) class TestStepAheadAsyncRL(BaseTrainerCase): @@ -200,17 +200,17 @@ def test_trainer(self): both(self.config) parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") - self.assertTrue(len(rollout_metrics) > 0) + self.assertGreater(len(rollout_metrics), 0) 
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) actor_metrics = parser.metric_list("actor") - self.assertTrue(len(actor_metrics) > 0) + self.assertGreater(len(actor_metrics), 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) actor_kl_metrics = parser.metric_list("actor/kl") - self.assertTrue(len(actor_kl_metrics) > 0) + self.assertGreater(len(actor_kl_metrics), 0) critic_kl_metrics = parser.metric_list("critic/kl") - self.assertTrue(len(critic_kl_metrics) > 0) + self.assertGreater(len(critic_kl_metrics), 0) response_metrics = parser.metric_list("response_length") - self.assertTrue(len(response_metrics) > 0) + self.assertGreater(len(response_metrics), 0) self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) ray.shutdown(_exiting_interpreter=True) # check checkpoint @@ -224,7 +224,7 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) @parameterized_class( @@ -266,15 +266,15 @@ def test_trainer(self): both(self.config) parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") - self.assertTrue(len(rollout_metrics) > 0) + self.assertGreater(len(rollout_metrics), 0) pipeline_metrics = parser.metric_list("experience_pipeline") - self.assertTrue(len(pipeline_metrics) > 0) + self.assertGreater(len(pipeline_metrics), 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) actor_metrics = parser.metric_list("actor") - self.assertTrue(len(actor_metrics) > 0) + self.assertGreater(len(actor_metrics), 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) response_metrics = parser.metric_list("response_length") - self.assertTrue(len(response_metrics) > 0) + self.assertGreater(len(response_metrics), 0) self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) # TODO: used for real testing # rewards = parser.metric_values("critic/rewards/mean") @@ -285,7 +285,7 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) class TestTrainerSFTWarmupGSM8K(BaseTrainerCase): @@ -345,12 +345,12 @@ def test_trainer(self, mock_load): sft_config = stage_configs[0] parser = TensorBoardParser(os.path.join(sft_config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") - self.assertTrue(len(rollout_metrics) == 0) + self.assertEqual(len(rollout_metrics), 0) sft_metrics = parser.metric_list("actor/sft") - self.assertTrue(len(sft_metrics) > 0) + self.assertGreater(len(sft_metrics), 0) self.assertEqual(parser.metric_max_step(sft_metrics[0]), 3) response_metrics = parser.metric_list("response_length") - self.assertTrue(len(response_metrics) > 0) + self.assertGreater(len(response_metrics), 0) self.assertEqual(parser.metric_min_step(response_metrics[0]), 1) self.assertEqual(parser.metric_max_step(response_metrics[0]), 3) @@ -362,14 +362,14 @@ def test_trainer(self, mock_load): grpo_config = stage_configs[1] parser = TensorBoardParser(os.path.join(grpo_config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") - self.assertTrue(len(rollout_metrics) > 0) + self.assertGreater(len(rollout_metrics), 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) actor_metrics = parser.metric_list("actor") - 
self.assertTrue(len(actor_metrics) > 0) + self.assertGreater(len(actor_metrics), 0) sft_metrics = parser.metric_list("actor/sft") - self.assertTrue(len(sft_metrics) == 0) + self.assertEqual(len(sft_metrics), 0) response_metrics = parser.metric_list("response_length") - self.assertTrue(len(response_metrics) > 0) + self.assertGreater(len(response_metrics), 0) self.assertEqual(parser.metric_min_step(response_metrics[0]), 1) self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) # test save checkpoint when sft finish @@ -385,11 +385,11 @@ def test_trainer(self, mock_load): trainer_type="verl", ) self.assertEqual(step_num, 4) - self.assertTrue(len(os.listdir(os.path.join(checkpoint_dir, "actor"))) > 0) + self.assertGreater(len(os.listdir(os.path.join(checkpoint_dir, "actor"))), 0) def tearDown(self): # TODO: remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) class TestTrainerDPO(BaseTrainerCase): @@ -411,12 +411,12 @@ def test_trainer(self): train(self.config) parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) actor_metrics = parser.metric_list("actor") - self.assertTrue(len(actor_metrics) > 0) + self.assertGreater(len(actor_metrics), 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) def tearDown(self): # remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) class TestTrainerSFT(BaseTrainerCase): @@ -439,12 +439,12 @@ def test_trainer(self): train(self.config) parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) actor_metrics = parser.metric_list("actor") - self.assertTrue(len(actor_metrics) > 0) + self.assertGreater(len(actor_metrics), 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) def tearDown(self): # remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) class TestTrainerToolsSFT(BaseTrainerCase): @@ -467,15 +467,15 @@ def test_trainer_tools(self): train(self.config) parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) actor_metrics = parser.metric_list("actor") - self.assertTrue(len(actor_metrics) > 0) + self.assertGreater(len(actor_metrics), 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) def tearDown(self): # remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) -def run_trainer(config: Config) -> None: +def run_trainer(config: Config, stop_event=None) -> None: ray.init( namespace=config.ray_namespace, runtime_env={ @@ -485,10 +485,15 @@ def run_trainer(config: Config) -> None: } }, ) - train(config) + try: + train(config) + finally: + if stop_event: + stop_event.set() + ray.shutdown() -def run_explorer(config: Config) -> None: +def run_explorer(config: Config, stop_event=None) -> None: ray.init( namespace=config.ray_namespace, runtime_env={ @@ -498,10 +503,15 @@ def run_explorer(config: Config) -> None: } }, ) - explore(config) + try: + explore(config) + finally: + if stop_event: + stop_event.set() + ray.shutdown() -def run_both(config: Config) -> None: +def run_both(config: Config, stop_event=None) -> None: ray.init( namespace=config.ray_namespace, runtime_env={ @@ -511,10 +521,15 @@ def run_both(config: Config) -> None: } }, ) - 
both(config) + try: + both(config) + finally: + if stop_event: + stop_event.set() + ray.shutdown() -def run_serve(config: Config) -> None: +def run_serve(config: Config, stop_event=None) -> None: ray.init( namespace=config.ray_namespace, runtime_env={ @@ -524,7 +539,12 @@ def run_serve(config: Config) -> None: } }, ) - serve(config) + try: + serve(config) + finally: + if stop_event: + stop_event.set() + ray.shutdown() @parameterized_class( @@ -535,6 +555,7 @@ class TestFullyAsyncMode(unittest.TestCase): def setUp(self): if multiprocessing.get_start_method(allow_none=True) != "spawn": multiprocessing.set_start_method("spawn", force=True) + self.process_list = [] def test_fully_async_mode(self): config = get_template_config() @@ -585,8 +606,12 @@ def test_fully_async_mode(self): explorer2_config.trainer = deepcopy(trainer_config.trainer) explorer1_config.check_and_update() - trainer_process = multiprocessing.Process(target=run_trainer, args=(trainer_config,)) + trainer_stop_event = multiprocessing.Event() + trainer_process = multiprocessing.Process( + target=run_trainer, args=(trainer_config, trainer_stop_event) + ) trainer_process.start() + self.process_list.append(trainer_process) ray.init(ignore_reinit_error=True) while True: @@ -597,20 +622,34 @@ def test_fully_async_mode(self): print("waiting for trainer to start.") time.sleep(5) - explorer_process_1 = multiprocessing.Process(target=run_explorer, args=(explorer1_config,)) + explorer1_stop_event = multiprocessing.Event() + explorer_process_1 = multiprocessing.Process( + target=run_explorer, args=(explorer1_config, explorer1_stop_event) + ) explorer_process_1.start() + self.process_list.append(explorer_process_1) time.sleep(5) explorer2_config.explorer.name = "explorer2" explorer2_config.check_and_update() - explorer_process_2 = multiprocessing.Process(target=run_explorer, args=(explorer2_config,)) + explorer2_stop_event = multiprocessing.Event() + explorer_process_2 = multiprocessing.Process( + target=run_explorer, args=(explorer2_config, explorer2_stop_event) + ) explorer_process_2.start() + self.process_list.append(explorer_process_2) - explorer_process_1.join() - explorer_process_2.join() + explorer_process_1.join(timeout=300) + if explorer_process_1.is_alive(): + self.fail("explorer1 process is still alive") + explorer_process_2.join(timeout=300) + if explorer_process_2.is_alive(): + self.fail("explorer2 process is still alive") # wait for trainer process to finish. 
trainer_process.join(timeout=200) + if trainer_process.is_alive(): + self.fail("trainer process is still alive") # check the tensorboard parser = TensorBoardParser( @@ -668,26 +707,33 @@ def test_fully_async_mode(self): 8, ) log_files = os.listdir(os.path.join(explorer1_config.checkpoint_job_dir, "log")) - self.assertTrue("trainer.log" in log_files) - self.assertTrue("synchronizer.log" in log_files) - self.assertTrue("explorer1.log" in log_files) - self.assertTrue("explorer2.log" in log_files) - self.assertTrue("explorer1_runner_0.log" in log_files) - self.assertTrue("explorer1_runner_7.log" in log_files) - self.assertTrue("explorer2_runner_0.log" in log_files) - self.assertTrue("explorer2_runner_7.log" in log_files) - self.assertTrue("explorer1_experience_pipeline.log" in log_files) - self.assertTrue("explorer2_experience_pipeline.log" in log_files) + self.assertIn("trainer.log", log_files) + self.assertIn("synchronizer.log", log_files) + self.assertIn("explorer1.log", log_files) + self.assertIn("explorer2.log", log_files) + self.assertIn("explorer1_runner_0.log", log_files) + self.assertIn("explorer1_runner_7.log", log_files) + self.assertIn("explorer2_runner_0.log", log_files) + self.assertIn("explorer2_runner_7.log", log_files) + self.assertIn("explorer1_experience_pipeline.log", log_files) + self.assertIn("explorer2_experience_pipeline.log", log_files) files_to_check = ["trainer.log", "synchronizer.log", "explorer1.log", "explorer2.log"] for file_name in files_to_check: with open(os.path.join(explorer1_config.checkpoint_job_dir, "log", file_name)) as f: lines = f.readlines() - self.assertTrue(len(lines) > 0) + self.assertGreater(len(lines), 0, f"{file_name} is empty") ray.shutdown() def tearDown(self): checkpoint_path = get_checkpoint_path() - shutil.rmtree(os.path.join(checkpoint_path, "unittest")) + shutil.rmtree(os.path.join(checkpoint_path, "unittest"), ignore_errors=True) + for process in self.process_list: + if process.is_alive(): + process.terminate() + process.join(timeout=10) + if process.is_alive(): + process.kill() + process.join() @parameterized_class( @@ -719,6 +765,7 @@ def setUp(self): self.config.trainer.save_hf_checkpoint = "last" self.config.trainer.trainer_strategy = self.strategy self.config.check_and_update() + self.process_list = [] def test_trainer(self): """Test the checkpoint saving.""" @@ -730,8 +777,10 @@ def test_trainer(self): _trainer_config.trainer.max_actor_ckpt_to_keep = 2 _trainer_config.trainer.max_critic_ckpt_to_keep = 2 - trainer_process = multiprocessing.Process(target=run_both, args=(self.config,)) + stop_event = multiprocessing.Event() + trainer_process = multiprocessing.Process(target=run_both, args=(self.config, stop_event)) trainer_process.start() + self.process_list.append(trainer_process) default_local_dir = _trainer_config.trainer.default_local_dir state_dict_iteration = checkpoint_iteration = 0 @@ -751,7 +800,10 @@ def test_trainer(self): "__1_1.distcp", "__0_0.distcp", } - while state_dict_iteration < 4 and checkpoint_iteration < 4: + start_time = time.time() + while not stop_event.is_set() and time.time() - start_time < 60 * 10: + time.sleep(10) + if os.path.exists(state_dict_iteration_file): try: with open(state_dict_iteration_file, "r") as f: @@ -783,10 +835,10 @@ def test_trainer(self): items = os.listdir(huggingface_dir) self.assertIn("config.json", items) self.assertIn("generation_config.json", items) - print(f"State dict check at {state_dict_iteration} iteration passed.") + # print(f"State dict check at {state_dict_iteration} 
iteration passed.") # for debug if checkpoint_iteration > 0: - for sub_dir_name in ["actor", "critic"]: + for sub_dir_name in ["critic", "actor"]: iteration_dir = os.path.join( default_local_dir, f"global_step_{checkpoint_iteration}", sub_dir_name ) @@ -811,8 +863,10 @@ def test_trainer(self): megatron_dist_ckpt_items, ) huggingface_dir = os.path.join(iteration_dir, "huggingface") + huggingface_dir_files = os.listdir(huggingface_dir) self.assertEqual( - set(os.listdir(huggingface_dir)) - {"generation_config.json"}, + set(huggingface_dir_files) + - {"generation_config.json", "model.safetensors"}, { "vocab.json", "merges.txt", @@ -824,14 +878,22 @@ def test_trainer(self): "special_tokens_map.json", }, ) - print(f"Checkpoint check at {checkpoint_iteration} iteration passed.") - - time.sleep(1) - trainer_process.join() + # print(f"Checkpoint check at {checkpoint_iteration} iteration passed.") # for debug + if not stop_event.is_set(): + self.fail("Training process failed to stop.") + trainer_process.join(timeout=10) + self.assertIn("model.safetensors", huggingface_dir_files) def tearDown(self): # remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) + for process in self.process_list: + if process.is_alive(): + process.terminate() + process.join(timeout=10) + if process.is_alive(): + process.kill() + process.join() class TestTrainerMIX(BaseTrainerCase): @@ -890,20 +952,20 @@ def test_trainer(self): # test rollout metrics rollout_metrics = parser.metric_list("rollout") - self.assertTrue(len(rollout_metrics) > 0) + self.assertGreater(len(rollout_metrics), 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) self.assertEqual( parser.metric_values("experience_pipeline/experience_count")[1], 16 ) # 16 rft experiences # test actor metrics actor_metrics = parser.metric_list("actor") - self.assertTrue(len(actor_metrics) > 0) + self.assertGreater(len(actor_metrics), 0) expert_metrics = parser.metric_list("actor/expert/") self.assertEqual(parser.metric_max_step(expert_metrics[0]), 4) # SFT usual_metrics = parser.metric_list("actor/usual/") self.assertEqual(parser.metric_max_step(usual_metrics[0]), 4) # RFT response_metrics = parser.metric_list("response_length") - self.assertTrue(len(response_metrics) > 0) + self.assertGreater(len(response_metrics), 0) self.assertEqual(parser.metric_min_step(response_metrics[0]), 1) self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) # test save checkpoint at last step @@ -912,10 +974,10 @@ def test_trainer(self): trainer_type="verl", ) self.assertEqual(step_num, 4) - self.assertTrue(len(os.listdir(os.path.join(checkpoint_dir, "actor"))) > 0) + self.assertGreater(len(os.listdir(os.path.join(checkpoint_dir, "actor"))), 0) def tearDown(self): - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) async def run_math_workflow(serve_url: str, task: dict): @@ -950,14 +1012,13 @@ async def run_math_workflow(serve_url: str, task: dict): await proxy_client.feedback_async(sum(reward.values()), [response.id]) -class TestServeWithTrainer(RayUnittestBaseAysnc): +class TestServeWithTrainer(RayUnittestBaseAsync): def setUp(self): if multiprocessing.get_start_method(allow_none=True) != "spawn": multiprocessing.set_start_method("spawn", force=True) checkpoint_path = get_checkpoint_path() shutil.rmtree(os.path.join(checkpoint_path, "unittest"), ignore_errors=True) - async def 
test_serve_with_trainer(self): # noqa: C901 config = get_template_config() config.project = "unittest" config.name = f"serve_with_trainer_{datetime.now().strftime('%Y%m%d%H%M%S')}" @@ -983,14 +1044,18 @@ async def test_serve_with_trainer(self): # noqa: C901 config.explorer.rollout_model.enable_openai_api = True config.explorer.rollout_model.tensor_parallel_size = 1 config.explorer.service_status_check_interval = 5 + self.config = config + self.process_list = [] - trainer_config = deepcopy(config) + async def test_serve_with_trainer(self): # noqa: C901 + trainer_config = deepcopy(self.config) trainer_config.mode = "train" trainer_config.check_and_update() trainer_config.trainer.max_actor_ckpt_to_keep = 10 trainer_process = multiprocessing.Process(target=run_trainer, args=(trainer_config,)) trainer_process.start() + self.process_list.append(trainer_process) ray.init(ignore_reinit_error=True) while True: @@ -1000,11 +1065,12 @@ async def test_serve_with_trainer(self): # noqa: C901 except ValueError: print("waiting for trainer to start.") await asyncio.sleep(5) - serve_config = deepcopy(config) + serve_config = deepcopy(self.config) serve_config.mode = "serve" serve_config.check_and_update() serve_process = multiprocessing.Process(target=run_serve, args=(serve_config,)) serve_process.start() + self.process_list.append(serve_process) state_manager = StateManager( path=serve_config.checkpoint_job_dir, @@ -1030,60 +1096,64 @@ async def test_serve_with_trainer(self): # noqa: C901 break await asyncio.sleep(2) - config.buffer.explorer_input.taskset.batch_size = 4 - reader = get_buffer_reader(config.buffer.explorer_input.taskset) - - try: - for i in range(3): - tasks = reader.read() - await asyncio.gather( - *(run_math_workflow(server_url, task.raw_task) for task in tasks) + self.config.buffer.explorer_input.taskset.batch_size = 4 + reader = get_buffer_reader(self.config.buffer.explorer_input.taskset) + + for i in range(3): + tasks = reader.read() + await asyncio.gather(*(run_math_workflow(server_url, task.raw_task) for task in tasks)) + await proxy_client.commit_async() + # wait for synchronizer started + end_time = time.time() + find_checkpoint = False + while time.time() - end_time < 100: + _, step_num = get_checkpoint_dir_with_step_num( + checkpoint_root_path=serve_config.checkpoint_job_dir, + raise_error=False, ) - await proxy_client.commit_async() - # wait for synchronizer started - end_time = time.time() - find_checkpoint = False - while time.time() - end_time < 100: - _, step_num = get_checkpoint_dir_with_step_num( - checkpoint_root_path=serve_config.checkpoint_job_dir, - raise_error=False, - ) - if step_num >= i + 1: # checkpoint has been generated - find_checkpoint = True - break - await asyncio.sleep(1) - self.assertTrue(find_checkpoint, f"Checkpoint at step {i + 1} not found in time.") - metrics = await proxy_client.get_metrics_async() - self.assertTrue(metrics["rollout/total_experience_count"] == 4 * (i + 1)) - self.assertTrue(metrics["rollout/ready_experience_count"] == 4 * (i + 1)) - self.assertTrue(metrics["rollout/model_0/total_request_count"] > 0) - self.assertTrue(metrics["rollout/model_1/total_request_count"] > 0) - if i > 1: - self.assertTrue(metrics["rollout/model_0/model_version"] > 0) - self.assertTrue(metrics["rollout/model_1/model_version"] > 0) + if step_num >= i + 1: # checkpoint has been generated + find_checkpoint = True + break + await asyncio.sleep(1) + self.assertTrue(find_checkpoint, f"Checkpoint at step {i + 1} not found in time.") metrics = await 
proxy_client.get_metrics_async() - self.assertEqual(metrics["rollout/total_experience_count"], 12) - self.assertEqual(metrics["rollout/ready_experience_count"], 12) - self.assertTrue(metrics["rollout/model_0/total_request_count"] > 0) - self.assertTrue(metrics["rollout/model_1/total_request_count"] > 0) - self.assertEqual( - metrics["rollout/model_0/total_request_count"] - + metrics["rollout/model_1/total_request_count"], - metrics["rollout/total_experience_count"], - ) - # at least updated to version 2 - self.assertTrue(metrics["rollout/model_0/model_version"] >= 2) - self.assertTrue(metrics["rollout/model_1/model_version"] >= 2) - # check final checkpoint - _, step_num = get_checkpoint_dir_with_step_num( - checkpoint_root_path=serve_config.checkpoint_job_dir, - step_num=3, - ) - finally: - serve_process.terminate() - trainer_process.terminate() - serve_process.join() - trainer_process.join() + self.assertEqual(metrics["rollout/total_experience_count"], 4 * (i + 1)) + self.assertEqual(metrics["rollout/ready_experience_count"], 4 * (i + 1)) + self.assertGreater(metrics["rollout/model_0/total_request_count"], 0) + self.assertGreater(metrics["rollout/model_1/total_request_count"], 0) + if i > 1: + self.assertGreater(metrics["rollout/model_0/model_version"], 0) + self.assertGreater(metrics["rollout/model_1/model_version"], 0) + metrics = await proxy_client.get_metrics_async() + self.assertEqual(metrics["rollout/total_experience_count"], 12) + self.assertEqual(metrics["rollout/ready_experience_count"], 12) + self.assertGreater(metrics["rollout/model_0/total_request_count"], 0) + self.assertGreater(metrics["rollout/model_1/total_request_count"], 0) + self.assertEqual( + metrics["rollout/model_0/total_request_count"] + + metrics["rollout/model_1/total_request_count"], + metrics["rollout/total_experience_count"], + ) + # at least updated to version 2 + await asyncio.sleep(5) # wait for model version update + self.assertGreaterEqual(metrics["rollout/model_0/model_version"], 2) + self.assertGreaterEqual(metrics["rollout/model_1/model_version"], 2) + # check final checkpoint + _, step_num = get_checkpoint_dir_with_step_num( + checkpoint_root_path=serve_config.checkpoint_job_dir, + step_num=3, + ) + + def tearDown(self): + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) + for process in self.process_list: + if process.is_alive(): + process.terminate() + process.join(timeout=10) + if process.is_alive(): + process.kill() + process.join() + super().tearDown() class TestMultiModalGRPO(BaseTrainerCase): @@ -1106,25 +1176,25 @@ def test_trainer(self): # check metrics are available parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") - self.assertTrue(len(rollout_metrics) > 0) + self.assertGreater(len(rollout_metrics), 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2) actor_metrics = parser.metric_list("actor") - self.assertTrue(len(actor_metrics) > 0) + self.assertGreater(len(actor_metrics), 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) response_metrics = parser.metric_list("response_length") - self.assertTrue(len(response_metrics) > 0) + self.assertGreater(len(response_metrics), 0) self.assertEqual(parser.metric_max_step(response_metrics[0]), 2) # check save lastest checkpoint checkpoint_step_2, step_num = get_checkpoint_dir_with_step_num( checkpoint_root_path=self.config.checkpoint_job_dir, trainer_type=self.config.trainer.trainer_type, ) - 
self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_2, "actor"))) > 0) + self.assertGreater(len(os.listdir(os.path.join(checkpoint_step_2, "actor"))), 0) self.assertEqual(step_num, 2) def tearDown(self): # remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) class TestMultiModalSFT(BaseTrainerCase): @@ -1149,22 +1219,22 @@ def test_trainer(self): # check metrics are available parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) actor_metrics = parser.metric_list("actor") - self.assertTrue(len(actor_metrics) > 0) + self.assertGreater(len(actor_metrics), 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) response_metrics = parser.metric_list("response_length") - self.assertTrue(len(response_metrics) > 0) + self.assertGreater(len(response_metrics), 0) self.assertEqual(parser.metric_max_step(response_metrics[0]), 2) # check save lastest checkpoint checkpoint_step_2, step_num = get_checkpoint_dir_with_step_num( checkpoint_root_path=self.config.checkpoint_job_dir, trainer_type=self.config.trainer.trainer_type, ) - self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_2, "actor"))) > 0) + self.assertGreater(len(os.listdir(os.path.join(checkpoint_step_2, "actor"))), 0) self.assertEqual(step_num, 2) def tearDown(self): # remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) class TestTrainerLoRA(BaseTrainerCase): @@ -1194,13 +1264,13 @@ def test_trainer(self): # check metrics are available parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") - self.assertTrue(len(rollout_metrics) > 0) + self.assertGreater(len(rollout_metrics), 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2) actor_metrics = parser.metric_list("actor") - self.assertTrue(len(actor_metrics) > 0) + self.assertGreater(len(actor_metrics), 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) response_metrics = parser.metric_list("response_length") - self.assertTrue(len(response_metrics) > 0) + self.assertGreater(len(response_metrics), 0) self.assertEqual(parser.metric_max_step(response_metrics[0]), 2) ray.shutdown(_exiting_interpreter=True) # check save lastest checkpoint @@ -1208,9 +1278,9 @@ def test_trainer(self): checkpoint_root_path=self.config.checkpoint_job_dir, trainer_type=self.config.trainer.trainer_type, ) - self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_2, "actor"))) > 0) - self.assertTrue( - len(os.listdir(os.path.join(checkpoint_step_2, "actor", "lora_adapter"))) > 0 + self.assertGreater(len(os.listdir(os.path.join(checkpoint_step_2, "actor"))), 0) + self.assertGreater( + len(os.listdir(os.path.join(checkpoint_step_2, "actor", "lora_adapter"))), 0 ) self.assertEqual(step_num, 2) @@ -1224,7 +1294,7 @@ def test_trainer(self): parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) for prefix in ["eval", "bench"]: gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k") - self.assertTrue(len(gsm8k_metrics) > 0) + self.assertGreater(len(gsm8k_metrics), 0, f"{prefix}/gsm8k metrics not found") for eval_stats in ["mean", "best", "worst"]: for k in [2, 4, 8]: for stats in ["mean", "std"]: @@ -1233,7 +1303,7 @@ def test_trainer(self): self.assertEqual(metric_steps, [0, 2]) def tearDown(self): - 
shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) class TestOverRollout(BaseTrainerCase): @@ -1264,29 +1334,29 @@ def test_trainer(self): both(self.config) parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") - self.assertTrue(len(rollout_metrics) > 0) + self.assertGreater(len(rollout_metrics), 0) eval_metrics = parser.metric_list("eval") - self.assertTrue(len(eval_metrics) > 0) + self.assertGreater(len(eval_metrics), 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2) self.assertTrue(parser.metric_exist("experience_pipeline/experience_count")) experience_counts = parser.metric_values("experience_pipeline/experience_count") - self.assertTrue(len(experience_counts) == 2) + self.assertEqual(len(experience_counts), 2) for count in experience_counts: - self.assertTrue( - count >= 2 * 4 + self.assertGreaterEqual( + count, 2 * 4 ) # at least process 2 tasks in each step, repeat_times is 4 pg_loss = parser.metric_values("actor/pg_loss") - self.assertTrue(len(pg_loss) >= 1) # trainer only has at least 1 step + self.assertGreaterEqual(len(pg_loss), 1) # trainer only has at least 1 step exp_save_path = self.config.buffer.trainer_input.experience_buffer.path with open(exp_save_path, "r", encoding="utf-8") as f: lines = f.readlines() - self.assertTrue( - len(lines) >= 2 * 4 * 2 + self.assertGreaterEqual( + len(lines), 2 * 4 * 2 ) # at least contain total_steps * repeat_times * batch_size * min_waited_tasks def tearDown(self): # remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) class TestTrainerPromptTruncation(BaseTrainerCase): @@ -1307,10 +1377,10 @@ def test_trainer(self): parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") - self.assertTrue(len(rollout_metrics) > 0) + self.assertGreater(len(rollout_metrics), 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2) actor_metrics = parser.metric_list("actor") - self.assertTrue(len(actor_metrics) > 0) + self.assertGreater(len(actor_metrics), 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) max_prompt_length = parser.metric_values("prompt_length/max") self.assertEqual(max(max_prompt_length), 5) @@ -1327,7 +1397,7 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) class TestTinkerTrainer(BaseTrainerCase): @@ -1349,17 +1419,17 @@ def test_trainer(self): both(self.config) parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") - self.assertTrue(len(rollout_metrics) > 0) + self.assertGreater(len(rollout_metrics), 0) pipeline_metrics = parser.metric_list("experience_pipeline") - self.assertTrue(len(pipeline_metrics) > 0) + self.assertGreater(len(pipeline_metrics), 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) actor_metrics = parser.metric_list("actor") - self.assertTrue(len(actor_metrics) > 0) + self.assertGreater(len(actor_metrics), 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) response_metrics = parser.metric_list("response_length") - self.assertTrue(len(response_metrics) > 0) + self.assertGreater(len(response_metrics), 0) 
self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) def tearDown(self): # remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index c168270366..84eadfe2b2 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -72,7 +72,11 @@ def __init__(self, *args, ray_namespace: str = "", **kwargs): self._optimizer_state_dict_thread = None self._extra_state_dict_thread = None self._save_model_thread = None - self.previous_state_dict_step = None + self.latest_model_save_step = None + self.latest_optimizer_save_step = None + self.latest_extra_state_save_step = None + self.latest_hf_model_save_step = None + self.latest_tokenizer_save_step = None def _upload_state_dict(self, state_dict: Union[dict, None], global_step: int): """ @@ -133,21 +137,65 @@ def _save(): thread.start() setattr(self, thread_name, thread) - def _save_model(self, local_path, global_step): + def _save_model(self, local_path, global_step) -> bool: + """ + Save the model state dict to the specified local path. + + Args: + local_path (str): The local path where the model state dict should be saved. + global_step (int): The current training step number. + + Returns: + bool: True if the model save operation was initiated, False if a save for + this global_step has already been performed. + """ + if self.latest_model_save_step == global_step: + return False + model_state_dict = self.model.state_dict() self._save_with_thread( model_state_dict, local_path, "model", "_model_state_dict_thread", global_step, True ) + self.latest_model_save_step = global_step + return True + + def _save_optimizer(self, local_path, global_step) -> bool: + """ + Save the optimizer state dict to the specified local path. - self.previous_state_dict_step = global_step + Args: + local_path (str): The local path where the optimizer state dict should be saved. + global_step (int): The current training step number. + + Returns: + bool: True if the optimizer save operation was initiated, False if a save for + this global_step has already been performed. + """ + if self.latest_optimizer_save_step == global_step: + return False - def _save_optimizer(self, local_path, global_step): optimizer_state_dict = self.optimizer.state_dict() self._save_with_thread( optimizer_state_dict, local_path, "optim", "_optimizer_state_dict_thread", global_step ) + self.latest_optimizer_save_step = global_step + return True + + def _save_extra_state(self, local_path, global_step) -> bool: + """ + Save the extra state dict to the specified local path. + + Args: + local_path (str): The local path where the extra state dict should be saved. + global_step (int): The current training step number. + + Returns: + bool: True if the extra state dict save operation was initiated, False if a save for + this global_step has already been performed. 
+ """ + if self.latest_extra_state_save_step == global_step: + return False - def _save_extra_state(self, local_path, global_step): lr_scheduler_state_dict = ( self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None ) @@ -158,18 +206,176 @@ def _save_extra_state(self, local_path, global_step): self._save_with_thread( extra_state_dict, local_path, "extra_state", "_extra_state_dict_thread", global_step ) + self.latest_extra_state_save_step = global_step + return True + + def _get_model_config(self): + if fsdp_version(self.model) == 1: + unwrap_model = self.model._fsdp_wrapped_module + else: + unwrap_model = self.model + + model_config = unwrap_model.config + if ( + unwrap_model.can_generate() + and hasattr(model_config, "name_or_path") + and model_config.name_or_path + ): + # Some model's name_or_path is empty if not initialized from pretrained, + # in this cases, we don't save generation config. + generation_config = GenerationConfig.from_pretrained(model_config.name_or_path) + else: + generation_config = None + + return model_config, generation_config + + def _save_tokenizer(self, local_path, global_step): + """ + Save the tokenizer class to the specified local path. + + Args: + local_path (str): The local path where the tokenizer class should be saved. + global_step (int): The current training step number. + + Returns: + bool: True if the tokenizer save operation was initiated, False if a save for + this global_step has already been performed. + """ + if self.latest_tokenizer_save_step == global_step: + return False + + if self.rank == 0: + # Save HF tokenizer/processor and model config on rank 0 to huggingface/ directory, no matter whether + # huggingface model is requested to be saved or not. + + hf_config_tokenizer_path = os.path.join(local_path, "huggingface") + local_mkdir_safe(hf_config_tokenizer_path) + + model_config, generation_config = self._get_model_config() + model_config.save_pretrained(hf_config_tokenizer_path) + if generation_config is not None: + generation_config.save_pretrained(hf_config_tokenizer_path) + + self.processing_class.save_pretrained(hf_config_tokenizer_path) + log_with_rank( + f"Saved model config and tokenizer class to {os.path.abspath(hf_config_tokenizer_path)}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + + # Also save runtime FSDP config + fsdp_config_path = os.path.join(local_path, "fsdp_config.json") + fsdp_config = FSDPConfig( + FSDP_version=fsdp_version(self.model), + world_size=self.world_size, + ) + with open(fsdp_config_path, "w") as f: + json.dump(asdict(fsdp_config), f, indent=4) + + # wait for everyone to dump to local + torch.distributed.barrier() + self.latest_tokenizer_save_step = global_step + + return self.rank == 0 + + def _save_hf_model(self, local_path, global_step) -> bool: + """ + Save the HuggingFace model to the specified local path. + + Args: + local_path (str): The local path where the HuggingFace model should be saved. + global_step (int): The current training step number. + + Returns: + bool: True if the HuggingFace model save operation was initiated, False if a save for + this global_step has already been performed. 
+ """ + + if self.latest_hf_model_save_step == global_step: + return False + + # Only rank 0 will save hf model and, + # offload to cpu to save LLMs which may be too large to fit in one GPU + state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True) - def save_state_dict( # noqa: C901 + if self.rank == 0: + hf_local_path = os.path.join(local_path, "huggingface") + os.makedirs(hf_local_path, exist_ok=True) + + model_config, generation_config = self._get_model_config() + + if "ForTokenClassification" in model_config.architectures[0]: + from transformers import AutoModelForTokenClassification + + auto_model_cls = AutoModelForTokenClassification + elif "ForCausalLM" in model_config.architectures[0]: + from transformers import AutoModelForCausalLM + + auto_model_cls = AutoModelForCausalLM + elif "ForConditionalGeneration" in model_config.architectures[0]: + from transformers import AutoModelForVision2Seq + + auto_model_cls = AutoModelForVision2Seq + else: + raise NotImplementedError(f"Unknown architecture {model_config['architectures']}") + + with init_empty_weights(): + save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16) + save_model.to_empty(device="cpu") + + if save_model.can_generate(): + if generation_config is not None: + save_model.generation_config = generation_config + else: + logger.warning( + f"{self.__class__.__name__}.save_checkpoint: Generation config file not found in, " + "using a generation config created from the model config when saving hf_model." + ) + + if self._save_model_thread is not None: + self._save_model_thread.join() + + state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()} + + def _save_hf_model_thread_target(): + runtime_context = ray.get_runtime_context() + node_id = runtime_context.get_node_id() + job_id = runtime_context.get_job_id() + ray.get( + self.checkpoint_monitor.notify_started.remote(node_id=node_id, job_id=job_id) + ) + save_model.save_pretrained(hf_local_path, state_dict=state_dict) + log_with_rank( + f"Saved hf_model to {os.path.abspath(hf_local_path)}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + ray.get(self.checkpoint_monitor.notify_finished.remote(global_step)) + + self._save_model_thread = threading.Thread( + target=_save_hf_model_thread_target, + ) + self._save_model_thread.start() + + # wait for rank0 to dump hf_model to local + torch.distributed.barrier() + self.latest_hf_model_save_step = global_step + + return self.rank == 0 + + def save_state_dict( self, local_path: str, global_step: int = 0, ): - if self.previous_state_dict_step is None: + if self.latest_model_save_step is None: # First sync in trainer.prepare - self.previous_state_dict_step = global_step + self.latest_model_save_step = global_step self._upload_state_dict(None, global_step) return - elif self.previous_state_dict_step == global_step: + elif self.latest_model_save_step == global_step: # No need to save for sync again return if local_path is None: @@ -192,7 +398,7 @@ def save_state_dict( # noqa: C901 ) ) - def save_checkpoint( # noqa: C901 + def save_checkpoint( self, local_path: str, global_step: int = 0, @@ -221,6 +427,7 @@ def save_checkpoint( # noqa: C901 # record the previous global step self.previous_global_step = global_step + local_path = local_mkdir_safe(local_path) # remove previous local_path, only rank 0 should do this if ( @@ -229,12 +436,12 @@ def save_checkpoint( # noqa: C901 and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and 
len(self.previous_saved_paths) >= max_ckpt_to_keep # type: ignore - ): + and local_path != self.previous_saved_paths[-1] # type: ignore + ): # last step may save twice keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 # type: ignore self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) # type: ignore self.previous_saved_paths = self.previous_saved_paths[keep_start:] # type: ignore - local_path = local_mkdir_safe(local_path) torch.distributed.barrier() # check if the checkpoint_save_contents is valid @@ -259,134 +466,18 @@ def save_checkpoint( # noqa: C901 self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg ): if self.should_save_model: - if self.previous_state_dict_step != global_step: - state_dict_thread_count += 1 - self._save_model(local_path, global_step) + state_dict_thread_count += self._save_model(local_path, global_step) if self.should_save_optimizer: - checkpoint_thread_count += 1 - self._save_optimizer(local_path, global_step) + checkpoint_thread_count += self._save_optimizer(local_path, global_step) if self.should_save_extra: - checkpoint_thread_count += 1 - self._save_extra_state(local_path, global_step) - - if self.rank == 0: - # Save HF tokenizer/processor and model config on rank 0 to huggingface/ directory, no matter whether - # huggingface model is requested to be saved or not. - - if fsdp_version(self.model) == 1: - unwrap_model = self.model._fsdp_wrapped_module - else: - unwrap_model = self.model - - hf_config_tokenizer_path = os.path.join(local_path, "huggingface") - local_mkdir_safe(hf_config_tokenizer_path) - model_config = unwrap_model.config - if ( - unwrap_model.can_generate() - and hasattr(model_config, "name_or_path") - and model_config.name_or_path - ): - # Some model's name_or_path is empty if not initialized from pretrained, - # in this cases, we don't save generation config. 
- generation_config = GenerationConfig.from_pretrained(model_config.name_or_path) - generation_config.save_pretrained(hf_config_tokenizer_path) - else: - generation_config = None - - model_config.save_pretrained(hf_config_tokenizer_path) - self.processing_class.save_pretrained(hf_config_tokenizer_path) - log_with_rank( - f"Saved model config and tokenizer class to {os.path.abspath(hf_config_tokenizer_path)}", - rank=self.rank, - logger=logger, - log_only_rank_0=True, - ) - - # Also save runtime FSDP config - fsdp_config_path = os.path.join(local_path, "fsdp_config.json") - fsdp_config = FSDPConfig( - FSDP_version=fsdp_version(self.model), - world_size=self.world_size, - ) - with open(fsdp_config_path, "w") as f: - json.dump(asdict(fsdp_config), f, indent=4) + checkpoint_thread_count += self._save_extra_state(local_path, global_step) - # wait for everyone to dump to local - torch.distributed.barrier() + self._save_tokenizer(local_path, global_step) if self.should_save_hf_model or save_as_hf: - # Only rank 0 will save hf model and, - # offload to cpu to save LLMs which may be too large to fit in one GPU - state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True) - - if self.rank == 0: - checkpoint_thread_count += 1 - hf_local_path = os.path.join(local_path, "huggingface") - os.makedirs(hf_local_path, exist_ok=True) - - if "ForTokenClassification" in model_config.architectures[0]: - from transformers import AutoModelForTokenClassification - - auto_model_cls = AutoModelForTokenClassification - elif "ForCausalLM" in model_config.architectures[0]: - from transformers import AutoModelForCausalLM - - auto_model_cls = AutoModelForCausalLM - elif "ForConditionalGeneration" in model_config.architectures[0]: - from transformers import AutoModelForVision2Seq - - auto_model_cls = AutoModelForVision2Seq - else: - raise NotImplementedError( - f"Unknown architecture {model_config['architectures']}" - ) - - with init_empty_weights(): - save_model = auto_model_cls.from_config( - model_config, torch_dtype=torch.bfloat16 - ) - save_model.to_empty(device="cpu") - - if save_model.can_generate(): - if generation_config is not None: - save_model.generation_config = generation_config - else: - print( - f"Warning: {self.__class__.__name__}.save_checkpoint: Generation config file not found in, using a generation config created from the model config when saving hf_model." 
- ) - - if self._save_model_thread is not None: - self._save_model_thread.join() - - state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()} - - def _save_model(): - runtime_context = ray.get_runtime_context() - node_id = runtime_context.get_node_id() - job_id = runtime_context.get_job_id() - ray.get( - self.checkpoint_monitor.notify_started.remote( - node_id=node_id, job_id=job_id - ) - ) - save_model.save_pretrained(hf_local_path, state_dict=state_dict) - log_with_rank( - f"Saved hf_model to {os.path.abspath(hf_local_path)}", - rank=self.rank, - logger=logger, - log_only_rank_0=True, - ) - ray.get(self.checkpoint_monitor.notify_finished.remote(global_step)) - - self._save_model_thread = threading.Thread( - target=_save_model, - ) - self._save_model_thread.start() - - # wait for rank0 to dump hf_model to local - torch.distributed.barrier() + checkpoint_thread_count += self._save_hf_model(local_path, global_step) ray.get( self.checkpoint_monitor.register_thread_count.remote( @@ -395,7 +486,10 @@ def _save_model(): checkpoint_thread_count=checkpoint_thread_count, ) ) - self.previous_saved_paths.append(local_path) + if ( + len(self.previous_saved_paths) == 0 or local_path != self.previous_saved_paths[-1] + ): # last step may save twice + self.previous_saved_paths.append(local_path) def wait_on_save_thread(self) -> None: """ diff --git a/trinity/trainer/verl/megatron_checkpoint_manager.py b/trinity/trainer/verl/megatron_checkpoint_manager.py index 8674934452..9e53f1964c 100644 --- a/trinity/trainer/verl/megatron_checkpoint_manager.py +++ b/trinity/trainer/verl/megatron_checkpoint_manager.py @@ -63,9 +63,26 @@ def __init__( self.checkpoint_monitor = CheckpointMonitor.get_actor( namespace=ray_namespace, ) - self.previous_state_dict_step = None + self.latest_model_save_step = None + self.latest_tokenizer_save_step = None + self.latest_extra_state_save_step = None + self.latest_hf_model_save_step = None + + def _save_state_dict(self, local_path, global_step) -> bool: + """ + Save the model state dict to the specified local path. + + Args: + local_path (str): The local path where the model state dict should be saved. + global_step (int): The current training step number. + + Returns: + bool: True if the model save operation was initiated, False if a save for + this global_step has already been performed. + """ + if self.latest_model_save_step == global_step: + return False - def _save_state_dict(self, local_path, global_step): dist_checkpoint_path = get_dist_checkpoint_path(local_path) hf_ckpt_path = get_hf_model_checkpoint_path(local_path) @@ -144,18 +161,154 @@ def finalize_save_fn(): else: finalize_save_fn() - self.previous_state_dict_step = global_step + self.latest_model_save_step = global_step + return True + + def _save_tokenizer(self, local_path, global_step) -> bool: + """ + Save the tokenizer class to the specified local path. + + Args: + local_path (str): The local path where the tokenizer class should be saved. + global_step (int): The current training step number. + + Returns: + bool: True if the tokenizer save operation was initiated, False if a save for + this global_step has already been performed. 
+ """ + if self.latest_tokenizer_save_step == global_step: + return False + + # Only rank 0 saves the hf config and tokenizer to huggingface path + # No matter whether we save hf model or not + if self.rank == 0: + # Save tokenizer + hf_config_tokenizer_path = get_hf_model_checkpoint_path(local_path) + self.processing_class.save_pretrained(hf_config_tokenizer_path) + log_with_rank( + f"Saved Huggingface tokenizer to {hf_config_tokenizer_path}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + + self.latest_tokenizer_save_step = global_step + return self.rank == 0 + + def _save_extra_state(self, local_path, global_step) -> bool: + """ + Save the extra state dict to the specified local path. + + Args: + local_path (str): The local path where the extra state dict should be saved. + global_step (int): The current training step number. + + Returns: + bool: True if the extra state dict save operation was initiated, False if a save for + this global_step has already been performed. + """ + if self.latest_extra_state_save_step == global_step: + return False + + if self.rank == 0: + # Save transformer config + log_with_rank( + f"Transformer config: {self.transformer_config}", rank=self.rank, logger=logger + ) + transformer_config_dict = asdict(self.transformer_config) + to_convert_types = {torch.dtype: str, AttnBackend: str} + ignore_types = [Callable] + pop_keys = [] + for key, value in transformer_config_dict.items(): + if type(value) in to_convert_types: + transformer_config_dict[key] = to_convert_types[type(value)](value) + if type(value) in ignore_types: + pop_keys.append(key) + if callable(value): + pop_keys.append(key) + for key in pop_keys: + transformer_config_dict.pop(key) + transformer_config_path = get_transformer_config_checkpoint_path(local_path) + with open(transformer_config_path, "w") as f: + json.dump(transformer_config_dict, f, indent=2) + + return self.rank == 0 + + def _save_hf_model(self, local_path, global_step) -> bool: + """ + Save the Huggingface model to the specified local path. + + Args: + local_path (str): The local path where the Huggingface model should be saved. + global_step (int): The current training step number. + + Returns: + bool: True if the Huggingface model save operation was initiated, False if a save for + this global_step has already been performed. 
+ """ + if self.latest_hf_model_save_step == global_step: + return False + + try: + # wait for everyone to dump to local + state_dict = self.weight_saver( + self.model, + self.hf_config, + dtype=self.param_dtype, + is_value_model=self.is_value_model, + tie_word_embeddings=self.share_embeddings_and_output_weights, + ) + + torch.distributed.barrier() + if self.rank == 0: + # TODO: async save or use mbridge to save hf model + hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) + import warnings + + from accelerate import init_empty_weights + + with init_empty_weights(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + if "mistral7b-rm" in self.config.model.path: + from transformers import MistralForSequenceClassification + + model = MistralForSequenceClassification.from_pretrained( + self.config.model.path, torch_dtype=torch.bfloat16 + ) # use score head instead of lm_head + state_dict["score.weight"] = state_dict["score.weight"] + else: + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained( + self.config.model.path, torch_dtype=torch.bfloat16 + ) + state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()} + model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) + log_with_rank( + f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + except Exception: + logger.error( + f"Failed to save Huggingface model to {local_path}, you can try to set `use_mbridge=true` to save it.", + exc_info=True, + ) + + self.latest_hf_model_save_step = global_step + return self.rank == 0 def save_state_dict( # noqa: C901 self, local_path: str, global_step: int = 0, ): - if self.previous_state_dict_step is None: + if self.latest_model_save_step is None: # First sync in trainer.prepare - self.previous_state_dict_step = global_step + self.latest_model_save_step = global_step return - elif self.previous_state_dict_step == global_step: + elif self.latest_model_save_step == global_step: # No need to save for sync again return @@ -167,7 +320,7 @@ def save_state_dict( # noqa: C901 ) ) - def save_checkpoint( # noqa: C901 + def save_checkpoint( self, local_path: str, global_step: int = 0, @@ -176,6 +329,7 @@ def save_checkpoint( # noqa: C901 ): # record the previous global step self.previous_global_step = global_step + local_path = local_mkdir_safe(local_path) # remove previous local_path if ( @@ -183,106 +337,32 @@ def save_checkpoint( # noqa: C901 and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(self.previous_saved_paths) >= max_ckpt_to_keep # type: ignore - ): + and local_path != self.previous_saved_paths[-1] # type: ignore + ): # last step may save twice keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 # type: ignore self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) # type: ignore self.previous_saved_paths = self.previous_saved_paths[keep_start:] # type: ignore - local_path = local_mkdir_safe(local_path) + torch.distributed.barrier() state_dict_thread_count = 0 if self.should_save_model: - if self.previous_state_dict_step != global_step: - self._save_state_dict(local_path, global_step) - state_dict_thread_count += 1 + state_dict_thread_count += self._save_state_dict(local_path, global_step) - # Only rank 0 saves the hf config and tokenizer to huggingface path - # No matter whether we save hf model or not - if self.rank == 0: - # Save tokenizer - hf_config_tokenizer_path = 
get_hf_model_checkpoint_path(local_path) - self.processing_class.save_pretrained(hf_config_tokenizer_path) - log_with_rank( - f"Saved Huggingface tokenizer to {hf_config_tokenizer_path}", - rank=self.rank, - logger=logger, - log_only_rank_0=True, - ) + self._save_tokenizer(local_path, global_step) if self.should_save_extra: - if self.rank == 0: - # Save transformer config - log_with_rank( - f"Transformer config: {self.transformer_config}", rank=self.rank, logger=logger - ) - transformer_config_dict = asdict(self.transformer_config) - to_convert_types = {torch.dtype: str, AttnBackend: str} - ignore_types = [Callable] - pop_keys = [] - for key, value in transformer_config_dict.items(): - if type(value) in to_convert_types: - transformer_config_dict[key] = to_convert_types[type(value)](value) - if type(value) in ignore_types: - pop_keys.append(key) - if callable(value): - pop_keys.append(key) - for key in pop_keys: - transformer_config_dict.pop(key) - transformer_config_path = get_transformer_config_checkpoint_path(local_path) - with open(transformer_config_path, "w") as f: - json.dump(transformer_config_dict, f, indent=2) + self._save_extra_state(local_path, global_step) if self.should_save_hf_model or save_as_hf: - try: - # wait for everyone to dump to local - state_dict = self.weight_saver( - self.model, - self.hf_config, - dtype=self.param_dtype, - is_value_model=self.is_value_model, - tie_word_embeddings=self.share_embeddings_and_output_weights, - ) - - torch.distributed.barrier() - if self.rank == 0: - # TODO: async save or use mbridge to save hf model - hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) - import warnings - - from accelerate import init_empty_weights - - with init_empty_weights(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - if "mistral7b-rm" in self.config.model.path: - from transformers import MistralForSequenceClassification - - model = MistralForSequenceClassification.from_pretrained( - self.config.model.path, torch_dtype=torch.bfloat16 - ) # use score head instead of lm_head - state_dict["score.weight"] = state_dict["score.weight"] - else: - from transformers import AutoModelForCausalLM - - model = AutoModelForCausalLM.from_pretrained( - self.config.model.path, torch_dtype=torch.bfloat16 - ) - state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()} - model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) - log_with_rank( - f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}", - rank=self.rank, - logger=logger, - log_only_rank_0=True, - ) - except Exception: - logger.error( - f"Failed to save Huggingface model to {local_path}, you can try to set `use_mbridge=true` to save it.", - exc_info=True, - ) + self._save_hf_model(local_path, global_step) ray.get( self.checkpoint_monitor.register_thread_count.remote( global_step, state_dict_thread_count=state_dict_thread_count ) ) - self.previous_saved_paths.append(local_path) + if ( + len(self.previous_saved_paths) == 0 or local_path != self.previous_saved_paths[-1] + ): # last step may save twice + self.previous_saved_paths.append(local_path) diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 9c52af3d66..8583fd3363 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -271,7 +271,6 @@ def __init__( processor=processor, ) self.init_workers() - self.last_full_save_step = None def _validate_config(self): # TODO algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type) @@ -483,9 +482,7 @@ 
async def train_step(self, batch_exps: List[Experience]) -> Dict: # noqa C901 return metrics def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = False) -> None: - if self.last_full_save_step != self.global_steps: - self.last_full_save_step = self.global_steps - self._save_checkpoint(save_as_hf=save_as_hf) + self._save_checkpoint(save_as_hf=save_as_hf) if block_until_saved: self.actor_rollout_wg.wait_on_save_thread() if self.algorithm and self.algorithm.use_critic:
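
Taken together, the three files above converge on one pattern: the old previous_state_dict_step / last_full_save_step bookkeeping is replaced by per-artifact latest_*_save_step trackers, each _save_* helper returns True only when it actually starts work for the given global_step, and save_checkpoint sums those booleans to build the thread counts it registers with the CheckpointMonitor, so a step that gets saved twice (for example, the final step) is a no-op the second time. The sketch below illustrates only that guard-and-count pattern; the class name, the print-based stand-ins for the real save threads, and the paths are invented for illustration and are not part of this patch.

# Illustrative only: a stripped-down stand-in for the checkpoint managers above.
class TinyCheckpointManager:
    def __init__(self):
        # One tracker per artifact, mirroring latest_model_save_step / latest_extra_state_save_step.
        self.latest_model_save_step = None
        self.latest_extra_state_save_step = None

    def _save_model(self, local_path: str, global_step: int) -> bool:
        # Skip if this step has already been saved (e.g., the last step saving twice).
        if self.latest_model_save_step == global_step:
            return False
        print(f"writing model shards for step {global_step} to {local_path}")
        self.latest_model_save_step = global_step
        return True

    def _save_extra_state(self, local_path: str, global_step: int) -> bool:
        if self.latest_extra_state_save_step == global_step:
            return False
        print(f"writing extra state for step {global_step} to {local_path}")
        self.latest_extra_state_save_step = global_step
        return True

    def save_checkpoint(self, local_path: str, global_step: int) -> int:
        # The boolean returns add up to the number of save operations actually
        # started, which is what gets registered with the monitor in the diff.
        thread_count = 0
        thread_count += self._save_model(local_path, global_step)
        thread_count += self._save_extra_state(local_path, global_step)
        return thread_count


if __name__ == "__main__":
    mgr = TinyCheckpointManager()
    print(mgr.save_checkpoint("/tmp/ckpt/global_step_10", 10))  # 2: both artifacts saved
    print(mgr.save_checkpoint("/tmp/ckpt/global_step_10", 10))  # 0: duplicate call is a no-op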