Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/spikeinterface/core/core_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,15 @@ def add_suffix(file_path, possible_suffix):
return file_path


def make_shared_array(shape, dtype):
def make_shared_array(shape, dtype, mp_context=None):
import multiprocessing as mp
from multiprocessing.shared_memory import SharedMemory

# we need to set the mp context before creating the shared memory, to avoid
# SemLock errors
if mp_context is not None:
mp.set_start_method(mp_context, force=True)

dtype = np.dtype(dtype)
shape = tuple(int(x) for x in shape) # We need to be sure that shape comes in int instead of numpy scalars
nbytes = prod(shape) * dtype.itemsize
Expand Down
37 changes: 29 additions & 8 deletions src/spikeinterface/core/waveform_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,16 @@ def extract_waveforms_to_buffers(
dtype = np.dtype(dtype)

waveforms_by_units, arrays_info = allocate_waveforms_buffers(
recording, spikes, unit_ids, nbefore, nafter, mode=mode, folder=folder, dtype=dtype, sparsity_mask=sparsity_mask
recording,
spikes,
unit_ids,
nbefore,
nafter,
mode=mode,
folder=folder,
dtype=dtype,
sparsity_mask=sparsity_mask,
**job_kwargs,
)

distribute_waveforms_to_buffers(
Expand Down Expand Up @@ -145,7 +154,16 @@ def extract_waveforms_to_buffers(


def allocate_waveforms_buffers(
recording, spikes, unit_ids, nbefore, nafter, mode="memmap", folder=None, dtype=None, sparsity_mask=None
recording,
spikes,
unit_ids,
nbefore,
nafter,
mode="memmap",
folder=None,
dtype=None,
sparsity_mask=None,
**job_kwargs,
):
"""
Allocate memmap or shared memory buffers before snippet extraction.
Expand Down Expand Up @@ -208,12 +226,13 @@ def allocate_waveforms_buffers(
waveforms_by_units[unit_id] = arr
arrays_info[unit_id] = filename
elif mode == "shared_memory":
mp_context = job_kwargs.get("mp_context", None)
if n_spikes == 0 or num_chans == 0:
arr = np.zeros(shape, dtype=dtype)
shm = None
shm_name = None
else:
arr, shm = make_shared_array(shape, dtype)
arr, shm = make_shared_array(shape, dtype, mp_context=mp_context)
shm_name = shm.name
waveforms_by_units[unit_id] = arr
arrays_info[unit_id] = (shm, shm_name, dtype.str, shape)
Expand Down Expand Up @@ -503,6 +522,9 @@ def extract_waveforms_to_single_buffer(
)
return_in_uV = return_scaled

job_kwargs = fix_job_kwargs(job_kwargs)
mp_context = job_kwargs.get("mp_context", None)

n_samples = nbefore + nafter

if dtype is None:
Expand Down Expand Up @@ -532,15 +554,13 @@ def extract_waveforms_to_single_buffer(
shm = None
shm_name = None
else:
all_waveforms, shm = make_shared_array(shape, dtype)
all_waveforms, shm = make_shared_array(shape, dtype, mp_context=mp_context)
shm_name = shm.name
# wf_array_info = (shm, shm_name, dtype.str, shape)
wf_array_info = dict(shm=shm, shm_name=shm_name, dtype=dtype.str, shape=shape)
else:
raise ValueError("allocate_waveforms_buffers bad mode")

job_kwargs = fix_job_kwargs(job_kwargs)

if num_spikes > 0 and num_chans > 0:
# and run
func = _worker_distribute_single_buffer
Expand Down Expand Up @@ -906,6 +926,7 @@ def estimate_templates_with_accumulator(

job_kwargs = fix_job_kwargs(job_kwargs)
num_worker = job_kwargs["n_jobs"]
mp_context = job_kwargs["mp_context"]

if sparsity_mask is None:
num_chans = int(recording.get_num_channels())
Expand All @@ -916,10 +937,10 @@ def estimate_templates_with_accumulator(
shape = (num_worker, num_units, nbefore + nafter, num_chans)

dtype = np.dtype("float32")
waveform_accumulator_per_worker, shm = make_shared_array(shape, dtype)
waveform_accumulator_per_worker, shm = make_shared_array(shape, dtype, mp_context=mp_context)
shm_name = shm.name
if return_std:
waveform_squared_accumulator_per_worker, shm_squared = make_shared_array(shape, dtype)
waveform_squared_accumulator_per_worker, shm_squared = make_shared_array(shape, dtype, mp_context=mp_context)
shm_squared_name = shm_squared.name
else:
waveform_squared_accumulator_per_worker = None
Expand Down
15 changes: 13 additions & 2 deletions src/spikeinterface/sortingcomponents/matching/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ class CircusOMPPeeler(BaseTemplateMatching):
shared_memory : bool, default True
If True, the overlaps are stored in shared memory, which is more efficient when
using numerous cores
mp_context : multiprocessing context, default None
The multiprocessing context to use for shared memory arrays
"""

_more_output_keys = [
Expand Down Expand Up @@ -175,6 +177,7 @@ def __init__(
engine="numpy",
shared_memory=True,
torch_device="cpu",
mp_context=None,
):

BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output)
Expand Down Expand Up @@ -217,11 +220,19 @@ def __init__(
setattr(self, key, precomputed[key])

if self.shared_memory:
from spikeinterface.core.globals import get_global_job_kwargs
from spikeinterface.core.core_tools import make_shared_array

# if None, use global context
if mp_context is None:
mp_context = get_global_job_kwargs()["mp_context"]

self.max_overlaps = max([len(o) for o in self.overlaps])
num_samples = len(self.overlaps[0][0])
from spikeinterface.core.core_tools import make_shared_array

arr, shm = make_shared_array((self.num_templates, self.max_overlaps, num_samples), dtype=np.float32)
arr, shm = make_shared_array(
(self.num_templates, self.max_overlaps, num_samples), dtype=np.float32, mp_context=mp_context
)
for i in range(self.num_templates):
n_overlaps = len(self.unit_overlaps_indices[i])
arr[i, :n_overlaps] = self.overlaps[i]
Expand Down
15 changes: 13 additions & 2 deletions src/spikeinterface/sortingcomponents/matching/wobble.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ class WobbleMatch(BaseTemplateMatching):
shared_memory : bool, default True
If True, the overlaps are stored in shared memory, which is more efficient when
using numerous cores
mp_context : multiprocessing context, default None
The multiprocessing context to use for shared memory arrays.
"""

def __init__(
Expand All @@ -390,6 +392,7 @@ def __init__(
engine="numpy",
torch_device="cpu",
shared_memory=True,
mp_context=None,
):

BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output)
Expand Down Expand Up @@ -431,13 +434,21 @@ def __init__(
self.shared_memory = shared_memory

if self.shared_memory:
from spikeinterface.core.globals import get_global_job_kwargs
from spikeinterface.core.core_tools import make_shared_array

# if None, use global context
if mp_context is None:
mp_context = get_global_job_kwargs()["mp_context"]

self.max_overlaps = max([len(o) for o in pairwise_convolution])
num_samples = len(pairwise_convolution[0][0])
num_templates = len(templates_array)
num_jittered = num_templates * params.jitter_factor
from spikeinterface.core.core_tools import make_shared_array

arr, shm = make_shared_array((num_jittered, self.max_overlaps, num_samples), dtype=np.float32)
arr, shm = make_shared_array(
(num_jittered, self.max_overlaps, num_samples), dtype=np.float32, mp_context=mp_context
)
for jittered_index in range(num_jittered):
units_are_overlapping = sparsity.unit_overlap[jittered_index, :]
overlapping_units = np.where(units_are_overlapping)[0]
Expand Down
Loading