def _execute_init_dict(
    input_dict: dict,
    memory: dict,
    socket: Optional[zmq.Socket],
    mpi_rank_zero: bool,
    mpi_size_larger_one: bool,
) -> None:
    """
    Run an init-function message on this rank and report the outcome from rank 0.

    The callable described by input_dict is invoked through call_funct and its
    return value is merged into memory. Any exception raised while doing so is
    captured rather than propagated. When more than one rank is present, the
    captured exceptions (or None) of all ranks are gathered on rank 0. Rank 0
    then replies over the socket: {"result": True} if no rank reported an
    exception, otherwise {"error": <first non-None exception>}, in which case
    the same exception is also handed to backend_write_error_file.

    Args:
        input_dict (dict): Message dict with keys "init", "fn", "args", "kwargs".
        memory (dict): Per-rank memory store; updated in-place with the return
            value of the init function on success.
        socket (zmq.Socket | None): ZMQ socket rank 0 uses for its reply; it is
            not touched on other ranks.
        mpi_rank_zero (bool): True only on MPI rank 0.
        mpi_size_larger_one (bool): True when the communicator has more than one rank.
    """
    from mpi4py import MPI

    try:
        init_result = call_funct(input_dict=input_dict, funct=None, memory=memory)
        memory.update(init_result)
    except Exception as caught:
        local_failure = caught
    else:
        local_failure = None

    # The gather is a collective call: every rank has to reach it before any
    # rank-specific early exit below, otherwise the remaining ranks would block.
    gathered = (
        MPI.COMM_WORLD.gather(local_failure, root=0)
        if mpi_size_larger_one
        else [local_failure]
    )

    if not mpi_rank_zero:
        return

    # Pick the failure with the lowest position in the gathered list, if any.
    reported = None
    for candidate in gathered:
        if candidate is not None:
            reported = candidate
            break

    if reported is None:
        interface_send(socket=socket, result_dict={"result": True})
        return

    interface_send(socket=socket, result_dict={"error": reported})
    backend_write_error_file(error=reported, apply_dict=input_dict)
@unittest.skipIf(
    skip_mpi4py_test, "mpi4py is not installed, so the mpi4py tests are skipped."
)
def test_internal_memory_mpi(self):
    """Check that a two-core MPI worker returns the init-function state per rank.

    A scheduler with a single worker spanning two cores is started with
    set_global as its init function. Submitting get_global must then yield a
    result of length two in which both entries equal np.array([5]).
    """
    worker_kwargs = {
        "cores": 2,
        "init_function": set_global,
    }
    expected = np.array([5])
    with BlockAllocationTaskScheduler(
        max_workers=1,
        executor_kwargs=worker_kwargs,
        spawner=MpiExecSpawner,
    ) as scheduler:
        cloudpickle_register(ind=1)
        future = scheduler.submit(get_global)
        per_rank = future.result()
        # One entry is expected for each of the two cores requested above.
        self.assertEqual(len(per_rank), 2)
        for rank in (0, 1):
            np.testing.assert_array_equal(per_rank[rank], expected)