From 5f2486ab544ff216aa53836a8f10a12094da7a93 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 19 Jan 2026 14:45:50 +0100 Subject: [PATCH] make_shared_array spawn safe --- src/spikeinterface/core/core_tools.py | 8 +++- src/spikeinterface/core/waveform_tools.py | 37 +++++++++++++++---- .../sortingcomponents/matching/circus.py | 15 +++++++- .../sortingcomponents/matching/wobble.py | 15 +++++++- 4 files changed, 62 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 7640168cb7..6990ca1e7a 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -214,9 +214,15 @@ def add_suffix(file_path, possible_suffix): return file_path -def make_shared_array(shape, dtype): +def make_shared_array(shape, dtype, mp_context=None): + import multiprocessing as mp from multiprocessing.shared_memory import SharedMemory + # we need to set the mp context before creating the shared memory, to avoid + # SemLock errors + if mp_context is not None: + mp.set_start_method(mp_context, force=True) + dtype = np.dtype(dtype) shape = tuple(int(x) for x in shape) # We need to be sure that shape comes in int instead of numpy scalars nbytes = prod(shape) * dtype.itemsize diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 71bd794f18..c2486ec562 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -109,7 +109,16 @@ def extract_waveforms_to_buffers( dtype = np.dtype(dtype) waveforms_by_units, arrays_info = allocate_waveforms_buffers( - recording, spikes, unit_ids, nbefore, nafter, mode=mode, folder=folder, dtype=dtype, sparsity_mask=sparsity_mask + recording, + spikes, + unit_ids, + nbefore, + nafter, + mode=mode, + folder=folder, + dtype=dtype, + sparsity_mask=sparsity_mask, + **job_kwargs, ) distribute_waveforms_to_buffers( @@ -145,7 +154,16 @@ def extract_waveforms_to_buffers( def allocate_waveforms_buffers( - recording, spikes, unit_ids, nbefore, nafter, mode="memmap", folder=None, dtype=None, sparsity_mask=None + recording, + spikes, + unit_ids, + nbefore, + nafter, + mode="memmap", + folder=None, + dtype=None, + sparsity_mask=None, + **job_kwargs, ): """ Allocate memmap or shared memory buffers before snippet extraction. @@ -208,12 +226,13 @@ def allocate_waveforms_buffers( waveforms_by_units[unit_id] = arr arrays_info[unit_id] = filename elif mode == "shared_memory": + mp_context = job_kwargs.get("mp_context", None) if n_spikes == 0 or num_chans == 0: arr = np.zeros(shape, dtype=dtype) shm = None shm_name = None else: - arr, shm = make_shared_array(shape, dtype) + arr, shm = make_shared_array(shape, dtype, mp_context=mp_context) shm_name = shm.name waveforms_by_units[unit_id] = arr arrays_info[unit_id] = (shm, shm_name, dtype.str, shape) @@ -503,6 +522,9 @@ def extract_waveforms_to_single_buffer( ) return_in_uV = return_scaled + job_kwargs = fix_job_kwargs(job_kwargs) + mp_context = job_kwargs.get("mp_context", None) + n_samples = nbefore + nafter if dtype is None: @@ -532,15 +554,13 @@ def extract_waveforms_to_single_buffer( shm = None shm_name = None else: - all_waveforms, shm = make_shared_array(shape, dtype) + all_waveforms, shm = make_shared_array(shape, dtype, mp_context=mp_context) shm_name = shm.name # wf_array_info = (shm, shm_name, dtype.str, shape) wf_array_info = dict(shm=shm, shm_name=shm_name, dtype=dtype.str, shape=shape) else: raise ValueError("allocate_waveforms_buffers bad mode") - job_kwargs = fix_job_kwargs(job_kwargs) - if num_spikes > 0 and num_chans > 0: # and run func = _worker_distribute_single_buffer @@ -906,6 +926,7 @@ def estimate_templates_with_accumulator( job_kwargs = fix_job_kwargs(job_kwargs) num_worker = job_kwargs["n_jobs"] + mp_context = job_kwargs["mp_context"] if sparsity_mask is None: num_chans = int(recording.get_num_channels()) @@ -916,10 +937,10 @@ def estimate_templates_with_accumulator( shape = (num_worker, num_units, nbefore + nafter, num_chans) dtype = np.dtype("float32") - waveform_accumulator_per_worker, shm = make_shared_array(shape, dtype) + waveform_accumulator_per_worker, shm = make_shared_array(shape, dtype, mp_context=mp_context) shm_name = shm.name if return_std: - waveform_squared_accumulator_per_worker, shm_squared = make_shared_array(shape, dtype) + waveform_squared_accumulator_per_worker, shm_squared = make_shared_array(shape, dtype, mp_context=mp_context) shm_squared_name = shm_squared.name else: waveform_squared_accumulator_per_worker = None diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 302fe0a328..113aaf88ed 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -145,6 +145,8 @@ class CircusOMPPeeler(BaseTemplateMatching): shared_memory : bool, default True If True, the overlaps are stored in shared memory, which is more efficient when using numerous cores + mp_context : multiprocessing context, default None + The multiprocessing context to use for shared memory arrays """ _more_output_keys = [ @@ -175,6 +177,7 @@ def __init__( engine="numpy", shared_memory=True, torch_device="cpu", + mp_context=None, ): BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output) @@ -217,11 +220,19 @@ def __init__( setattr(self, key, precomputed[key]) if self.shared_memory: + from spikeinterface.core.globals import get_global_job_kwargs + from spikeinterface.core.core_tools import make_shared_array + + # if None, use global context + if mp_context is None: + mp_context = get_global_job_kwargs()["mp_context"] + self.max_overlaps = max([len(o) for o in self.overlaps]) num_samples = len(self.overlaps[0][0]) - from spikeinterface.core.core_tools import make_shared_array - arr, shm = make_shared_array((self.num_templates, self.max_overlaps, num_samples), dtype=np.float32) + arr, shm = make_shared_array( + (self.num_templates, self.max_overlaps, num_samples), dtype=np.float32, mp_context=mp_context + ) for i in range(self.num_templates): n_overlaps = len(self.unit_overlaps_indices[i]) arr[i, :n_overlaps] = self.overlaps[i] diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index bd5071f2b9..002eb2c50b 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -379,6 +379,8 @@ class WobbleMatch(BaseTemplateMatching): shared_memory : bool, default True If True, the overlaps are stored in shared memory, which is more efficient when using numerous cores + mp_context : multiprocessing context, default None + The multiprocessing context to use for shared memory arrays. """ def __init__( @@ -390,6 +392,7 @@ def __init__( engine="numpy", torch_device="cpu", shared_memory=True, + mp_context=None, ): BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output) @@ -431,13 +434,21 @@ def __init__( self.shared_memory = shared_memory if self.shared_memory: + from spikeinterface.core.globals import get_global_job_kwargs + from spikeinterface.core.core_tools import make_shared_array + + # if None, use global context + if mp_context is None: + mp_context = get_global_job_kwargs()["mp_context"] + self.max_overlaps = max([len(o) for o in pairwise_convolution]) num_samples = len(pairwise_convolution[0][0]) num_templates = len(templates_array) num_jittered = num_templates * params.jitter_factor - from spikeinterface.core.core_tools import make_shared_array - arr, shm = make_shared_array((num_jittered, self.max_overlaps, num_samples), dtype=np.float32) + arr, shm = make_shared_array( + (num_jittered, self.max_overlaps, num_samples), dtype=np.float32, mp_context=mp_context + ) for jittered_index in range(num_jittered): units_are_overlapping = sparsity.unit_overlap[jittered_index, :] overlapping_units = np.where(units_are_overlapping)[0]