Skip to content
Merged
1 change: 1 addition & 0 deletions doc/changes/dev/13777.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid some unnecessary computations when ``n_jobs=None`` is equivalent to ``n_jobs=1``, by `Simon Kern`_.
12 changes: 11 additions & 1 deletion mne/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def parallel_func(
if n_jobs is not None:
warn("joblib not installed. Cannot run in parallel.")
n_jobs = 1
if n_jobs == 1:
if (n_jobs == 1) or (n_jobs is None and not _running_in_joblib_context()):
n_jobs = 1
my_func = func
parallel = list
Expand Down Expand Up @@ -154,6 +154,16 @@ def parallel_progress(op_iter):
return parallel_out, my_func, n_jobs


def _running_in_joblib_context():
"""Check if we are running in a joblib.parallel_config context manager."""
try:
from joblib.parallel import get_active_backend
except ImportError:
return False
_, n_jobs = get_active_backend()
return n_jobs is not None


def _check_n_jobs(n_jobs):
n_jobs = _ensure_int(n_jobs, "n_jobs", must_be="an int or None")
if os.getenv("MNE_FORCE_SERIAL", "").lower() in ("true", "1") and n_jobs != 1:
Expand Down
27 changes: 27 additions & 0 deletions mne/tests/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import multiprocessing
import os
import sys
from contextlib import nullcontext

import pytest
Expand Down Expand Up @@ -50,3 +51,29 @@ def fun(x):
with ctx:
parallel, p_fun, got_jobs = parallel_func(fun, n_jobs, verbose="debug")
assert got_jobs == want_jobs


def test_parallel_func_n_jobs_none():
"""Test n_jobs=None is same as n_jobs=1."""
joblib = pytest.importorskip("joblib")

def fun(x):
return x * 2

# test that n_jobs=None (outside context) behaves identically to n_jobs=1.
parallel_none, p_fun_none, n_jobs_none = parallel_func(fun, n_jobs=None)
parallel_one, p_fun_one, n_jobs_one = parallel_func(fun, n_jobs=1)

assert parallel_none is parallel_one is list
assert n_jobs_none == n_jobs_one == 1
assert p_fun_none is p_fun_one is fun, "fun should not be wrapped but is"

# TODO: test does not work on windows somehow, fix
if sys.platform != "win32":
# Test that n_jobs=None inside a joblib context uses Parallel.
with joblib.parallel_config(backend="loky", n_jobs=2):
parallel, p_fun, n_jobs = parallel_func(fun, n_jobs=None)

assert n_jobs == 2
assert parallel is not list
assert fun is not p_fun, "fun should be wrapped but is not"
Loading