Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 51 additions & 17 deletions src/executorlib/backend/interactive_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,50 @@
)


def _execute_init_dict(
    input_dict: dict,
    memory: dict,
    socket: Optional[zmq.Socket],
    mpi_rank_zero: bool,
    mpi_size_larger_one: bool,
) -> None:
    """
    Execute an init-function message and update the in-process memory store.

    Calls ``call_funct`` with ``funct=None`` on the calling rank and merges its
    return value into ``memory``. When more than one rank is in use, the
    per-rank error (or ``None``) is gathered to rank 0 so that a failure on a
    non-zero rank is not silently swallowed. Rank 0 then replies over the
    socket with either the first observed error or ``{"result": True}``.

    Note that ``memory`` is only updated on ranks where the call succeeded; a
    rank whose call raised keeps its previous memory content.

    Args:
        input_dict (dict): Message dict describing the init call (the original
            documentation lists the keys "init", "fn", "args", "kwargs").
        memory (dict): Per-rank memory store; updated in-place with the return
            value of the init function on success.
        socket (zmq.Socket | None): Socket used by rank 0 to reply; it is not
            touched on other ranks.
        mpi_rank_zero (bool): True on the rank that is responsible for replying.
        mpi_size_larger_one (bool): True when errors have to be gathered from
            more than one rank.
    """
    init_error = None
    try:
        memory.update(call_funct(input_dict=input_dict, funct=None, memory=memory))
    except Exception as error:
        # Keep the error instead of raising so that every rank still reaches
        # the collective gather below.
        init_error = error

    if mpi_size_larger_one:
        # Imported lazily: MPI is only needed for the collective gather, so the
        # single-rank path does not depend on mpi4py being importable.
        from mpi4py import MPI

        all_errors = MPI.COMM_WORLD.gather(init_error, root=0)
    else:
        all_errors = [init_error]

    if not mpi_rank_zero:
        # Only rank 0 inspects the gathered errors and talks to the socket.
        return

    first_error = next((e for e in all_errors if e is not None), None)
    if first_error is None:
        interface_send(socket=socket, result_dict={"result": True})
        return

    interface_send(socket=socket, result_dict={"error": first_error})
    backend_write_error_file(error=first_error, apply_dict=input_dict)


def main() -> None:
"""
Entry point of the program.
Expand Down Expand Up @@ -97,23 +141,13 @@ def main() -> None:
and "args" in input_dict
and "kwargs" in input_dict
):
try:
memory.update(
call_funct(input_dict=input_dict, funct=None, memory=memory)
)
except Exception as error:
if mpi_rank_zero:
interface_send(
socket=socket,
result_dict={"error": error},
)
backend_write_error_file(
error=error,
apply_dict=input_dict,
)
else:
if mpi_rank_zero:
interface_send(socket=socket, result_dict={"result": True})
_execute_init_dict(
input_dict=input_dict,
memory=memory,
socket=socket,
mpi_rank_zero=mpi_rank_zero,
mpi_size_larger_one=mpi_size_larger_one,
)


if __name__ == "__main__":
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/standalone/interactive/test_spawner.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,25 @@ def test_execute_task(self):
self.assertEqual(f.result(), np.array([5]))
q.join()

@unittest.skipIf(
    skip_mpi4py_test, "mpi4py is not installed, so the mpi4py tests are skipped."
)
def test_internal_memory_mpi(self):
    # Start a single worker with two cores through the mpiexec spawner, using
    # set_global as init function, and submit get_global to it.
    # The result is expected to hold two entries, each equal to np.array([5])
    # (presumably one entry per MPI rank -- confirm against the backend).
    worker_options = {
        "cores": 2,
        "init_function": set_global,
    }
    with BlockAllocationTaskScheduler(
        max_workers=1,
        executor_kwargs=worker_options,
        spawner=MpiExecSpawner,
    ) as scheduler:
        cloudpickle_register(ind=1)
        future = scheduler.submit(get_global)
        values = future.result()
        self.assertEqual(len(values), 2)
        for index in (0, 1):
            np.testing.assert_array_equal(values[index], np.array([5]))


class TestBlockAllocationTaskScheduler(unittest.TestCase):
def test_submit_tracks_future_state(self):
Expand Down
Loading