import weakref
from concurrent.futures import CancelledError, Future

# ids of failed Future objects that have already been reported, so each failed job is logged once
# and not on every scheduler pass (batched_futures is re-evaluated many times until a batch fills).
# Entries are removed again when the corresponding future is garbage collected, because CPython
# recycles id() values and a stale entry would otherwise silence the report for an unrelated future.
_logged_failed_ids: set = set()


def _report_failed_future(future: Future, error: BaseException) -> None:
    """Print the exclusion notice for a failed future the first time it is seen."""
    key = id(future)
    if key in _logged_failed_ids:
        return
    _logged_failed_ids.add(key)
    try:
        # Forget the id once the future dies: keeps the set bounded and avoids id-reuse collisions.
        finalizer = weakref.finalize(future, _logged_failed_ids.discard, key)
        finalizer.atexit = False  # nothing to clean up at interpreter shutdown
    except TypeError:
        pass  # object does not support weak references; keep the id for the process lifetime
    print(
        f"[batched_futures] EXCLUDING failed future from batch: "
        f"{type(error).__name__}: {error}",
        flush=True,
    )


def batched_futures(lst: list[Future], skip_lst: list[list], n: int) -> list:
    """
    Collect the results of up to n completed future objects which are not part of a previous batch.

    The futures in lst are scanned in order. Results whose object identity is already contained in
    skip_lst are ignored. As soon as n new results are found they are returned as batch. If fewer
    than n new results are available and at least one future is still running, an empty list is
    returned to signal that the batch is not ready yet.

    Futures that completed with an exception or were cancelled are EXCLUDED from the batch rather
    than re-raised: calling ``.result()`` on them would raise inside this function and the caller
    would have to fail the whole batch for a single bad input. Each excluded future is reported once
    on stdout. When every future in lst is resolved but a full batch of n cannot be formed (because
    some futures failed), the partial remainder is returned so the caller does not wait forever for
    a batch that can never fill.

    Note: an empty list is also returned when all futures are resolved and every successful result
    is already covered by skip_lst; callers that treat ``[]`` as "try again later" need their own
    termination condition for that case.

    Args:
        lst (list): list of all future objects
        skip_lst (list): nested list of individual results already assigned to previous batches.
                         For backwards compatibility an entry may also be a completed future object
                         which resolves to such a list (the former ``nested_skip_lst`` form).
        n (int): batch size

    Returns:
        list: results of the batched futures
    """
    # Results are matched by object identity, not equality, to distinguish equal-valued results.
    skipped_ids = {
        id(item)
        for items in skip_lst
        for item in (items.result() if isinstance(items, Future) else items)
    }

    done_lst: list = []
    all_resolved = True
    for future in lst:
        if not future.done():
            all_resolved = False
            continue
        # Future.exception() itself raises CancelledError for cancelled futures, so check first.
        if future.cancelled():
            error = CancelledError("future was cancelled")
        else:
            error = future.exception()
        if error is not None:
            # failed future: exclude instead of re-raising (which would poison all dependents)
            _report_failed_future(future, error)
            continue
        result = future.result()
        if id(result) not in skipped_ids:
            done_lst.append(result)
            if len(done_lst) == n:
                return done_lst
    if all_resolved and done_lst:
        return done_lst  # end of input reached; emit final (possibly short) batch
    return []
Do NOT run get_future_objects_from_input + # on kwargs -- `lst` can be 100k+ futures, making ingestion (and every wait-list pass) O(N) + # per collector and stalling the scheduler. Track only the small skip_lst as future_lst. + task_dict is not None + and task_dict.get("fn") == "batched" + and "future" in task_dict + ): + task_dict["future_lst"] = task_dict["kwargs"]["skip_lst"] + wait_lst.append(task_dict) + future_queue.task_done() elif ( # handle function submitted to the executor task_dict is not None and "fn" in task_dict and "future" in task_dict ): @@ -343,11 +354,18 @@ def _update_waiting_task( elif task_wait_dict["fn"] == "batched" and all( future.done() for future in task_wait_dict["kwargs"]["skip_lst"] ): - done_lst = batched_futures( - lst=task_wait_dict["kwargs"]["lst"], - n=task_wait_dict["kwargs"]["n"], - nested_skip_lst=task_wait_dict["kwargs"]["skip_lst"], - ) + try: + done_lst = batched_futures( + lst=task_wait_dict["kwargs"]["lst"], + n=task_wait_dict["kwargs"]["n"], + skip_lst=[f.result() for f in task_wait_dict["kwargs"]["skip_lst"]], + ) + except Exception as exc: + # A future in `lst` (or skip_lst) raised. Propagate to the batch future instead of + # crashing the scheduler thread. (We no longer scan all of `lst` for exceptions via + # future_lst for performance, so batched_futures' .result() is where they surface.) 
def test_batched_with_failed_upstream_future(self):
    """Batches built from successful and failed inputs must only carry the successful results."""
    # Six inputs are submitted, one of which raises. With n=3 the test expects two batch
    # futures: a full one and a short one holding whatever successful results remain.
    with SingleNodeExecutor() as exe:
        cloudpickle_register(ind=1)
        inputs = [exe.submit(return_input_dict, i) for i in range(5)]
        inputs.append(exe.submit(raise_error, parameter=0))
        # .result() would re-raise here if the upstream failure had cascaded into a batch.
        batches = [batch.result() for batch in exe.batched(inputs, n=3)]
        self.assertEqual(len(batches), 2)
        # Order inside the batches is not fixed, so compare the union of all batched items.
        collected = {item for batch in batches for item in batch}
        self.assertEqual(collected, {0, 1, 2, 3, 4})
+ mock_skip_future = MagicMock() + mock_skip_future.done.return_value = True + mock_skip_future.exception.return_value = None + mock_skip_future.result.side_effect = RuntimeError("unexpected skip error") + + task_dict = { + "fn": "batched", + "args": (), + "kwargs": { + "lst": [], + "n": 3, + "skip_lst": [mock_skip_future], + }, + "future": batch_future, + "future_lst": [mock_skip_future], + "resource_dict": {}, + } + + result_lst = _update_waiting_task( + wait_lst=[task_dict], + executor_queue=executor_queue, + refresh_rate=0.0, + ) + + # The batch future must have the exception propagated (not crashed the scheduler) + self.assertTrue(batch_future.done()) + self.assertIsInstance(batch_future.exception(), RuntimeError) + # The failed task is consumed (not re-queued in the wait list) + self.assertEqual(len(result_lst), 0) + class TestExecutorErrors(unittest.TestCase): def test_block_allocation_false_one_worker(self): diff --git a/tests/unit/standalone/test_batched.py b/tests/unit/standalone/test_batched.py index 31e3d578..5b4e18ec 100644 --- a/tests/unit/standalone/test_batched.py +++ b/tests/unit/standalone/test_batched.py @@ -1,9 +1,15 @@ from unittest import TestCase from concurrent.futures import Future -from executorlib.standalone.batched import batched_futures +from executorlib.standalone.batched import batched_futures, _logged_failed_ids class TestBatched(TestCase): + def setUp(self): + _logged_failed_ids.clear() + + def tearDown(self): + _logged_failed_ids.clear() + def test_batched_futures(self): lst = [] for i in list(range(10)): @@ -24,4 +30,46 @@ def test_batched_futures_not_finished(self): for _ in list(range(10)): f = Future() lst.append(f) - self.assertEqual(batched_futures(lst=lst, n=3, nested_skip_lst=set()), []) + self.assertEqual(batched_futures(lst=lst, n=3, skip_lst=[]), []) + + def test_batched_futures_with_failed_future(self): + """Failed futures are excluded from the batch rather than raising.""" + lst = [] + for i in range(5): + f = Future() 
def test_batched_futures_failed_future_logged_once(self):
    """A failed future is reported on the first call only, not on repeated calls."""
    # Local imports keep this test independent of the module's import block.
    import contextlib
    import io

    f_failed = Future()
    f_failed.set_exception(RuntimeError("task failed"))
    lst = [f_failed]

    # First pass: the failure is excluded from the batch and reported exactly once.
    first_out = io.StringIO()
    with contextlib.redirect_stdout(first_out):
        self.assertEqual(batched_futures(lst=lst, n=1, skip_lst=[]), [])
    self.assertIn(id(f_failed), _logged_failed_ids)
    self.assertEqual(first_out.getvalue().count("EXCLUDING failed future"), 1)
    size_after_first_call = len(_logged_failed_ids)

    # Second pass: nothing may be printed again. Checking only the size of the id set would
    # not prove this, because adding a duplicate to a set never changes its size.
    second_out = io.StringIO()
    with contextlib.redirect_stdout(second_out):
        self.assertEqual(batched_futures(lst=lst, n=1, skip_lst=[]), [])
    self.assertEqual(second_out.getvalue(), "")
    self.assertEqual(len(_logged_failed_ids), size_after_first_call)