Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions src/executorlib/standalone/batched.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from concurrent.futures import Future

# ``id()`` values of Future objects that have already been reported as failed -- so each failed job
# is logged once, not on every scheduler pass (batched_futures is re-evaluated many times until a
# batch fills).
# NOTE(review): no removal of entries is visible in this module, and CPython may hand out the same
# ``id()`` again once the original Future has been garbage collected -- confirm that unbounded growth
# and a possibly suppressed log line are acceptable for long-running schedulers.
_logged_failed_ids: set[int] = set()


def batched_futures(
lst: list[Future], nested_skip_lst: list[Future[list]], n: int
Expand All @@ -9,6 +13,14 @@ def batched_futures(
not reached yet, then an empty list is returned. If n future objects that are not included in the skip_set are done,
then their results are returned as a batch.

Futures that completed with an EXCEPTION (e.g. a labeling job that failed on a degenerate config, or a dead worker)
are EXCLUDED from the batch rather than re-raised. Calling ``.result()`` on a failed future re-raises its exception;
in the dependency scheduler that turns into ``set_exception`` on this batch future, which then cascades to every
downstream task (combine_b / featurize / fit / cost / pareto) depending on the batch -- i.e. a single bad config
silently kills the whole pipeline. Each failed future is logged once. When the entire input is resolved but a full
batch of n cannot be formed (because some futures failed), the partial remainder is returned so the pipeline does
not stall forever waiting for a batch that can never fill. If no successful result remains at all, an empty list is
returned, which callers cannot distinguish from a batch that is not ready yet.

Args:
lst (list): list of all future objects
nested_skip_lst (list): nested list of individual results already assigned to previous batches
Expand All @@ -17,13 +29,28 @@ def batched_futures(
Returns:
list: results of the batched futures
"""
skip_set = {id(item) for f in nested_skip_lst for item in f.result()}
skipped_ids = {id(item) for items in skip_lst for item in items}

done_lst = []
n_expected = min(n, len(lst) - len(skip_set))
done_lst: list = []
all_resolved = True
for v in lst:
if v.done() and id(v.result()) not in skip_set:
done_lst.append(v.result())
if len(done_lst) == n_expected:
return done_lst
if v.done():
if v.exception() is not None:
if id(v) not in _logged_failed_ids:
_logged_failed_ids.add(id(v))
print(
f"[batched_futures] EXCLUDING failed future from batch: "
f"{type(v.exception()).__name__}: {v.exception()}",
flush=True,
)
continue # failed future: exclude instead of re-raising (which would poison all dependents)
result = v.result()
if id(result) not in skipped_ids:
done_lst.append(result)
if len(done_lst) == n:
return done_lst
else:
all_resolved = False
if all_resolved and len(done_lst) > 0:
return done_lst # end of input reached; emit final (possibly short) batch
return []
28 changes: 23 additions & 5 deletions src/executorlib/task_scheduler/interactive/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,17 @@ def _execute_tasks_with_dependencies(
task_dict["future"].set_result(False)
else:
task_dict["future"].set_result(True)
elif ( # batched collector: readiness is its skip_lst + batched_futures (which scans `lst`
# once, when ready, preserving completion-order). Do NOT run get_future_objects_from_input
# on kwargs -- `lst` can be 100k+ futures, making ingestion (and every wait-list pass) O(N)
# per collector and stalling the scheduler. Track only the small skip_lst as future_lst.
task_dict is not None
and task_dict.get("fn") == "batched"
and "future" in task_dict
):
task_dict["future_lst"] = task_dict["kwargs"]["skip_lst"]
wait_lst.append(task_dict)
future_queue.task_done()
elif ( # handle function submitted to the executor
task_dict is not None and "fn" in task_dict and "future" in task_dict
):
Expand Down Expand Up @@ -343,11 +354,18 @@ def _update_waiting_task(
elif task_wait_dict["fn"] == "batched" and all(
future.done() for future in task_wait_dict["kwargs"]["skip_lst"]
):
done_lst = batched_futures(
lst=task_wait_dict["kwargs"]["lst"],
n=task_wait_dict["kwargs"]["n"],
nested_skip_lst=task_wait_dict["kwargs"]["skip_lst"],
)
try:
done_lst = batched_futures(
lst=task_wait_dict["kwargs"]["lst"],
n=task_wait_dict["kwargs"]["n"],
skip_lst=[f.result() for f in task_wait_dict["kwargs"]["skip_lst"]],
)
except Exception as exc:
# A future in `lst` (or skip_lst) raised. Propagate to the batch future instead of
# crashing the scheduler thread. (We no longer scan all of `lst` for exceptions via
# future_lst for performance, so batched_futures' .result() is where they surface.)
task_wait_dict["future"].set_exception(exc)
continue
if len(done_lst) == 0:
wait_tmp_lst.append(task_wait_dict)
else:
Expand Down
60 changes: 59 additions & 1 deletion tests/unit/executor/test_single_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
from time import sleep, time
from queue import Queue
from threading import Thread
from unittest.mock import MagicMock

from executorlib import SingleNodeExecutor
from executorlib.executor.single import create_single_node_executor
from executorlib.task_scheduler.interactive.dependency import _execute_tasks_with_dependencies
from executorlib.task_scheduler.interactive.dependency import (
_execute_tasks_with_dependencies,
_update_waiting_task,
)
from executorlib.standalone.serialize import cloudpickle_register
from executorlib.standalone.interactive.spawner import MpiExecSpawner

Expand Down Expand Up @@ -82,6 +86,23 @@ def test_batched(self):
self.assertEqual(len(result_lst), 4)
self.assertTrue(t3-t2 > t2-t1)

def test_batched_with_failed_upstream_future(self):
    """Batch futures built from a list containing one failed job must resolve without error.

    Five jobs are submitted with the inputs 0-4 and a sixth job raises. With n=3 the test
    expects two batch futures whose combined results are exactly {0, 1, 2, 3, 4}.
    """
    with SingleNodeExecutor() as exe:
        cloudpickle_register(ind=1)
        upstream_futures = [
            exe.submit(return_input_dict, value) for value in range(5)
        ]
        upstream_futures.append(exe.submit(raise_error, parameter=0))
        batch_futures = exe.batched(upstream_futures, n=3)
        # Resolving every batch future raises here if the upstream failure cascaded.
        batches = [batch_future.result() for batch_future in batch_futures]
        self.assertEqual(len(batches), 2)
        # Flatten the batches: the failed job contributes nothing, the others appear once.
        collected = {item for batch in batches for item in batch}
        self.assertEqual(collected, {0, 1, 2, 3, 4})

def test_batched_error(self):
with self.assertRaises(TypeError):
with SingleNodeExecutor() as exe:
Expand Down Expand Up @@ -283,6 +304,43 @@ def test_future_input_dict(self):
)
self.assertEqual(fs.result()["a"], 4)

def test_update_waiting_task_batched_exception(self):
    """An error raised while assembling a "batched" task is stored on its batch future."""
    # Stand-in for a finished skip_lst entry: done() is True and exception() is None,
    # yet result() raises, so the failure only surfaces while the batch is assembled.
    failing_skip_future = MagicMock()
    failing_skip_future.done.return_value = True
    failing_skip_future.exception.return_value = None
    failing_skip_future.result.side_effect = RuntimeError("unexpected skip error")

    batch_future = Future()
    batched_kwargs = {
        "lst": [],
        "n": 3,
        "skip_lst": [failing_skip_future],
    }
    waiting_task = {
        "fn": "batched",
        "args": (),
        "kwargs": batched_kwargs,
        "future": batch_future,
        "future_lst": [failing_skip_future],
        "resource_dict": {},
    }

    remaining_tasks = _update_waiting_task(
        wait_lst=[waiting_task],
        executor_queue=Queue(),
        refresh_rate=0.0,
    )

    # The error must be delivered through the batch future instead of escaping the call.
    self.assertTrue(batch_future.done())
    self.assertIsInstance(batch_future.exception(), RuntimeError)
    # A task that ended in an error is not handed back for another waiting round.
    self.assertEqual(len(remaining_tasks), 0)


class TestExecutorErrors(unittest.TestCase):
def test_block_allocation_false_one_worker(self):
Expand Down
52 changes: 50 additions & 2 deletions tests/unit/standalone/test_batched.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from unittest import TestCase
from concurrent.futures import Future
from executorlib.standalone.batched import batched_futures
from executorlib.standalone.batched import batched_futures, _logged_failed_ids


class TestBatched(TestCase):
def setUp(self):
    """Empty the module-level set of already-logged failed futures before each test."""
    _logged_failed_ids.clear()

def tearDown(self):
    """Empty the module-level set of already-logged failed futures after each test."""
    _logged_failed_ids.clear()

def test_batched_futures(self):
lst = []
for i in list(range(10)):
Expand All @@ -24,4 +30,46 @@ def test_batched_futures_not_finished(self):
for _ in list(range(10)):
f = Future()
lst.append(f)
self.assertEqual(batched_futures(lst=lst, n=3, nested_skip_lst=set()), [])
self.assertEqual(batched_futures(lst=lst, n=3, skip_lst=[]), [])

def test_batched_futures_with_failed_future(self):
    """A future that finished with an exception is left out of the batch instead of raising."""
    succeeded = []
    for value in range(5):
        done_future = Future()
        done_future.set_result(value)
        succeeded.append(done_future)
    failed_future = Future()
    failed_future.set_exception(RuntimeError("task failed"))
    # Place the failure at index 2, giving the order [0, 1, FAILED, 2, 3, 4].
    futures = succeeded[:2] + [failed_future] + succeeded[2:]

    batch = batched_futures(lst=futures, n=3, skip_lst=[])

    # The failure is stepped over, so the batch holds the first three successful results.
    self.assertEqual(batch, [0, 1, 2])
    # Its id is remembered, which is what keeps it from being logged a second time.
    self.assertIn(id(failed_future), _logged_failed_ids)

def test_batched_futures_failed_future_logged_once(self):
    """A failed future is reported on stdout exactly once, even across multiple calls.

    Checking only the size of ``_logged_failed_ids`` is not sufficient: adding the same id
    to a set twice never changes its size, so that check also passes when the message is
    printed on every call. The printed output is therefore captured and counted as well.
    """
    # Local imports keep this test self-contained; both modules are standard library.
    from contextlib import redirect_stdout
    from io import StringIO

    f_failed = Future()
    f_failed.set_exception(RuntimeError("task failed"))
    lst = [f_failed]
    captured = StringIO()
    with redirect_stdout(captured):
        batched_futures(lst=lst, n=1, skip_lst=[])
        self.assertIn(id(f_failed), _logged_failed_ids)
        size_after_first_call = len(_logged_failed_ids)
        # Second call must neither add the id again nor print the message again
        batched_futures(lst=lst, n=1, skip_lst=[])
    self.assertEqual(len(_logged_failed_ids), size_after_first_call)
    self.assertEqual(
        captured.getvalue().count("EXCLUDING failed future from batch"), 1
    )

def test_batched_futures_partial_batch_due_to_failures(self):
    """With every input resolved and too few successes to reach n, a short batch is returned."""
    futures = []
    for value in (0, 1):
        done_future = Future()
        done_future.set_result(value)
        futures.append(done_future)
    failed_future = Future()
    failed_future.set_exception(RuntimeError("task failed"))
    futures.append(failed_future)

    # Nothing is pending and only two results succeeded, so n=3 can never be reached;
    # the expected outcome is the partial batch [0, 1].
    batch = batched_futures(lst=futures, n=3, skip_lst=[])
    self.assertEqual(batch, [0, 1])
Loading