diff --git a/src/executorlib/standalone/interactive/arguments.py b/src/executorlib/standalone/interactive/arguments.py index f51f0b02..7f0f7d89 100644 --- a/src/executorlib/standalone/interactive/arguments.py +++ b/src/executorlib/standalone/interactive/arguments.py @@ -27,10 +27,21 @@ def find_future_in_list(lst): find_future_in_list(lst=args) find_future_in_list(lst=kwargs.values()) - boolean_flag = len([future for future in future_lst if future.done()]) == len( - future_lst - ) - return future_lst, boolean_flag + + return future_lst + + +def check_list_of_futures_is_done(future_lst: list[Future]) -> bool: + """ + Check if all future objects in the list of future objects are done + + Args: + future_lst (list): list of future objects + + Returns: + bool: True if all future objects in the list of future objects are done, False otherwise + """ + return len([future for future in future_lst if future.done()]) == len(future_lst) def get_exception_lst(future_lst: list[Future]) -> list: diff --git a/src/executorlib/task_scheduler/interactive/dependency.py b/src/executorlib/task_scheduler/interactive/dependency.py index e4d96fc4..e781195d 100644 --- a/src/executorlib/task_scheduler/interactive/dependency.py +++ b/src/executorlib/task_scheduler/interactive/dependency.py @@ -7,6 +7,7 @@ from executorlib.standalone.batched import batched_futures from executorlib.standalone.interactive.arguments import ( check_exception_was_raised, + check_list_of_futures_is_done, get_exception_lst, get_future_objects_from_input, update_futures_in_input, @@ -185,6 +186,7 @@ def batched( "args": (), "kwargs": {"lst": iterable, "n": n, "skip_lst": skip_lst}, "future": f, + "future_lst": iterable, "future_skip": f_skip, "resource_dict": {}, } @@ -249,7 +251,7 @@ def _execute_tasks_with_dependencies( executor (TaskSchedulerBase): Executor to execute the tasks with after the dependencies are resolved. refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked. """ - wait_lst: list = [] + future_dependency_lst: list = [] while True: try: task_dict = future_queue.get_nowait() @@ -258,10 +260,10 @@ def _execute_tasks_with_dependencies( if ( # shutdown the executor task_dict is not None and "shutdown" in task_dict and task_dict["shutdown"] ): - while len(wait_lst) > 0: + while len(future_dependency_lst) > 0: # Check functions in the wait list and execute them if all future objects are now ready - wait_lst = _update_waiting_task( - wait_lst=wait_lst, + future_dependency_lst = _handle_future_dependencies( + future_dependency_lst=future_dependency_lst, executor_queue=executor_queue, refresh_rate=refresh_rate, ) @@ -283,12 +285,24 @@ def _execute_tasks_with_dependencies( task_dict["future"].set_result(False) else: task_dict["future"].set_result(True) + elif ( # handle batched function submitted to the executor + task_dict is not None + and "fn" in task_dict + and task_dict["fn"] == "batched" + and "future" in task_dict + ): + future_dependency_lst.append(task_dict) + future_queue.task_done() elif ( # handle function submitted to the executor - task_dict is not None and "fn" in task_dict and "future" in task_dict + task_dict is not None + and "fn" in task_dict + and task_dict["fn"] != "batched" + and "future" in task_dict ): - future_lst, ready_flag = get_future_objects_from_input( + future_lst = get_future_objects_from_input( args=task_dict["args"], kwargs=task_dict["kwargs"] ) + ready_flag = check_list_of_futures_is_done(future_lst=future_lst) exception_lst = get_exception_lst(future_lst=future_lst) if not check_exception_was_raised(future_obj=task_dict["future"]): if len(exception_lst) > 0: @@ -301,12 +315,12 @@ def _execute_tasks_with_dependencies( executor_queue.put(task_dict) else: # Otherwise add the function to the wait list task_dict["future_lst"] = future_lst - wait_lst.append(task_dict) + future_dependency_lst.append(task_dict) future_queue.task_done() - elif len(wait_lst) > 0: + elif len(future_dependency_lst) > 0: # Check functions in the wait list and execute them if all future objects are now ready - wait_lst = _update_waiting_task( - wait_lst=wait_lst, + future_dependency_lst = _handle_future_dependencies( + future_dependency_lst=future_dependency_lst, executor_queue=executor_queue, refresh_rate=refresh_rate, ) @@ -315,14 +329,16 @@ def _execute_tasks_with_dependencies( sleep(refresh_rate) -def _update_waiting_task( - wait_lst: list[dict], executor_queue: queue.Queue, refresh_rate: float = 0.01 +def _handle_future_dependencies( + future_dependency_lst: list[dict], + executor_queue: queue.Queue, + refresh_rate: float = 0.01, ) -> list: """ Submit the waiting tasks, which future inputs have been completed, to the executor Args: - wait_lst (list): List of waiting tasks + future_dependency_lst (list): List of waiting tasks executor_queue (Queue): Queue of the internal executor refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked. @@ -330,7 +346,7 @@ def _update_waiting_task( list: list tasks which future inputs have not been completed """ wait_tmp_lst = [] - for task_wait_dict in wait_lst: + for task_wait_dict in future_dependency_lst: exception_lst = get_exception_lst(future_lst=task_wait_dict["future_lst"]) if len(exception_lst) > 0 and task_wait_dict["fn"] != "batched": task_wait_dict["future"].set_exception(exception_lst[0]) @@ -360,6 +376,6 @@ def _update_waiting_task( task_wait_dict["future_skip"].set_result([id(f) for f in done_lst]) else: wait_tmp_lst.append(task_wait_dict) - if len(wait_lst) == len(wait_tmp_lst): + if len(future_dependency_lst) == len(wait_tmp_lst): sleep(refresh_rate) return wait_tmp_lst diff --git a/tests/unit/standalone/interactive/test_arguments.py b/tests/unit/standalone/interactive/test_arguments.py index 2e86e9eb..40076b2e 100644 --- a/tests/unit/standalone/interactive/test_arguments.py +++ b/tests/unit/standalone/interactive/test_arguments.py @@ -3,6 +3,7 @@ from executorlib.standalone.interactive.arguments import ( check_exception_was_raised, + check_list_of_futures_is_done, get_exception_lst, get_future_objects_from_input, update_futures_in_input, @@ -13,14 +14,16 @@ class TestSerial(unittest.TestCase): def test_get_future_objects_from_input_with_future(self): input_args = (1, 2, Future(), [Future()], {3: Future()}) input_kwargs = {"a": 1, "b": [Future()], "c": {"d": Future()}, "e": Future()} - future_lst, boolean_flag = get_future_objects_from_input(args=input_args, kwargs=input_kwargs) + future_lst = get_future_objects_from_input(args=input_args, kwargs=input_kwargs) + boolean_flag = check_list_of_futures_is_done(future_lst=future_lst) self.assertEqual(len(future_lst), 6) self.assertFalse(boolean_flag) def test_get_future_objects_from_input_without_future(self): input_args = (1, 2) input_kwargs = {"a": 1} - future_lst, boolean_flag = get_future_objects_from_input(args=input_args, kwargs=input_kwargs) + future_lst = get_future_objects_from_input(args=input_args, kwargs=input_kwargs) + boolean_flag = check_list_of_futures_is_done(future_lst=future_lst) self.assertEqual(len(future_lst), 0) self.assertTrue(boolean_flag)