Skip to content

Commit 9d3ee57

Browse files
committed
limit val_max_num_task
1 parent fe501a2 commit 9d3ee57

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

ajet/backbone/trainer_verl.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,5 +1042,24 @@ def get_val_dataset(self):
10421042
self.config.ajet.task_reader,
10431043
)
10441044
tasks = task_reader.get_validation_tasks()
1045+
1046+
# Clip the validation tasks when val_max_num_task_each_validation is set and the dataset holds more tasks than that limit
1047+
val_max_num_task = self.config.ajet.trainer_common.val_max_num_task_each_validation
1048+
if val_max_num_task is not None and len(tasks) > val_max_num_task:
1049+
original_size = len(tasks)
1050+
clip_method = self.config.ajet.trainer_common.val_max_num_task_clip_method
1051+
if clip_method == "fix_seed_random_n":
1052+
rng = np.random.RandomState(seed=42)
1053+
indices = rng.choice(len(tasks), val_max_num_task, replace=False)
1054+
tasks = [tasks[i] for i in sorted(indices)]
1055+
elif clip_method == "random_n":
1056+
indices = np.random.choice(len(tasks), val_max_num_task, replace=False)
1057+
tasks = [tasks[i] for i in sorted(indices)]
1058+
elif clip_method == "first_n":
1059+
tasks = tasks[:val_max_num_task]
1060+
else:
1061+
raise ValueError(f"Unknown val_max_num_task_clip_method: {clip_method}, expected 'fix_seed_random_n', 'random_n', or 'first_n'")
1062+
logger.info(f"Clipped validation dataset from {original_size} to {val_max_num_task} tasks using '{clip_method}'")
1063+
10451064
self.main_val_dataset = tasks
10461065
return self.main_val_dataset

ajet/default_config/ajet_default.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ ajet:
235235
val_before_train: False
236236
val_pass_n: 4
237237
val_only: False
238+
val_max_num_task_each_validation: null
239+
val_max_num_task_clip_method: fix_seed_random_n # allowed values: fix_seed_random_n (random subset, fixed seed 42), random_n (random subset, unseeded), first_n (leading tasks)
238240
val_print_to_markdown_file_path: null
239241
train_print_to_markdown_file_path: null
240242

ajet/utils/launch_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ def start_ray_service(args, env, cluster=False):
219219
# Get the current Python interpreter directory
220220
python_dir = os.path.dirname(sys.executable)
221221
ray_path = os.path.join(python_dir, "ray")
222+
# If CUDA_VISIBLE_DEVICES is set, delete it from `env` (this mutates the caller's mapping in place) before the Ray launch below
223+
if "CUDA_VISIBLE_DEVICES" in env:
224+
del env["CUDA_VISIBLE_DEVICES"]
222225
if not cluster:
223226
companion = LaunchCommandWhenAbsent(
224227
full_argument_list=[f"{ray_path} start --head --block"],
@@ -252,7 +255,7 @@ def start_ray_service(args, env, cluster=False):
252255
tag="ray_service_worker",
253256
use_pty=True,
254257
)
255-
launch_wait_time = 9999999999
258+
launch_wait_time = 9999999999999
256259
# success_std_string = "Connected to Ray cluster"
257260
success_std_string = "Just wait here forever"
258261
companion.launch(

0 commit comments

Comments
 (0)