67 changes: 48 additions & 19 deletions src/spikeinterface/core/job_tools.py
@@ -419,11 +419,24 @@ def __init__(

if pool_engine == "process":
if mp_context is None:
mp_context = recording.get_preferred_mp_context()
if mp_context is not None and platform.system() == "Windows":
assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
elif mp_context == "fork" and platform.system() == "Darwin":
warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
# auto choice
if platform.system() == "Windows":
mp_context = "spawn"
elif platform.system() == "Linux":
mp_context = "fork"
elif platform.system() == "Darwin":
# Forcing "spawn" on macOS is unfortunate, but in some cases "fork" on macOS
# is very unstable and leads to crashes.
mp_context = "spawn"
else:
mp_context = "spawn"

preferred_mp_context = recording.get_preferred_mp_context()
if preferred_mp_context is not None and preferred_mp_context != mp_context:
warnings.warn(
f"Your processing chain cannot use pool_engine='process' with mp_context='{mp_context}'. "
f"Using mp_context='{preferred_mp_context}' instead.")
mp_context = preferred_mp_context

self.mp_context = mp_context
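
The branch above resolves the start method in two steps: pick a platform default ("fork" only on Linux, "spawn" everywhere else), then let the recording's preferred context override it with a warning. A minimal sketch of that selection logic, with a hypothetical resolve_mp_context helper and a preferred argument standing in for recording.get_preferred_mp_context():

import platform
import warnings


def resolve_mp_context(mp_context=None, preferred=None):
    # hypothetical helper mirroring the hunk above (a sketch, not the spikeinterface API)
    if mp_context is None:
        # platform default: "fork" is only kept on Linux, everything else gets "spawn"
        mp_context = "fork" if platform.system() == "Linux" else "spawn"
    if preferred is not None and preferred != mp_context:
        # the recording's preferred context wins, and the user is warned about the switch
        warnings.warn(
            f"pool_engine='process' cannot use mp_context='{mp_context}' here; "
            f"using mp_context='{preferred}' instead."
        )
        mp_context = preferred
    return mp_context


print(resolve_mp_context())                   # "fork" on Linux, "spawn" on Windows/macOS
print(resolve_mp_context(preferred="spawn"))  # always "spawn"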

@@ -486,9 +499,14 @@ def run(self, recording_slices=None):
recording_slices, desc=f"{self.job_name} (no parallelization)", total=len(recording_slices)
)

worker_dict = self.init_func(*self.init_args)
init_args = self.init_args
if self.need_worker_index:
worker_dict["worker_index"] = 0
worker_index = 0
init_args = init_args + (worker_index, )

worker_dict = self.init_func(*init_args)
if self.need_worker_index:
worker_dict["worker_index"] = worker_index

for segment_index, frame_start, frame_stop in recording_slices:
res = self.func(segment_index, frame_start, frame_stop, worker_dict)
@@ -503,6 +521,8 @@ def run(self, recording_slices=None):
if self.pool_engine == "process":

if self.need_worker_index:

multiprocessing.set_start_method(self.mp_context, force=True)
lock = multiprocessing.Lock()
array_pid = multiprocessing.Array("i", n_jobs)
for i in range(n_jobs):
@@ -530,7 +550,7 @@

if self.progress_bar:
results = tqdm(
results, desc=f"{self.job_name} (workers: {n_jobs} processes)", total=len(recording_slices)
results, desc=f"{self.job_name} (workers: {n_jobs} processes {self.mp_context})", total=len(recording_slices)
)

for res in results:
@@ -619,11 +639,6 @@ def __call__(self, args):

def process_worker_initializer(func, init_func, init_args, max_threads_per_worker, need_worker_index, lock, array_pid):
global _process_func_wrapper
if max_threads_per_worker is None:
worker_dict = init_func(*init_args)
else:
with threadpool_limits(limits=max_threads_per_worker):
worker_dict = init_func(*init_args)

if need_worker_index:
child_process = multiprocessing.current_process()
@@ -634,9 +649,19 @@ def process_worker_initializer(func, init_func, init_args, max_threads_per_worke
worker_index = i
array_pid[i] = child_process.ident
break
worker_dict["worker_index"] = worker_index
lock.release()

init_args = init_args + (worker_index, )

if max_threads_per_worker is None:
worker_dict = init_func(*init_args)
else:
with threadpool_limits(limits=max_threads_per_worker):
worker_dict = init_func(*init_args)

if need_worker_index:
worker_dict["worker_index"] = worker_index

_process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker)


@@ -652,19 +677,23 @@ def process_function_wrapper(args):
def thread_worker_initializer(
func, init_func, init_args, max_threads_per_worker, thread_local_data, need_worker_index, lock
):

if need_worker_index:
lock.acquire()
global _thread_started
worker_index = _thread_started
_thread_started += 1
lock.release()
init_args = init_args + (worker_index, )

if max_threads_per_worker is None:
worker_dict = init_func(*init_args)
else:
with threadpool_limits(limits=max_threads_per_worker):
worker_dict = init_func(*init_args)

if need_worker_index:
lock.acquire()
global _thread_started
worker_index = _thread_started
_thread_started += 1
worker_dict["worker_index"] = worker_index
lock.release()

thread_local_data.func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker)

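For the process engine, each worker claims a stable worker_index by writing its PID into a shared multiprocessing.Array under a Lock: the first zero slot it finds becomes its index, which is then appended to the init args. The same registration pattern in isolation, assuming a plain ProcessPoolExecutor rather than the ChunkRecordingExecutor machinery:

import os
import multiprocessing
from concurrent.futures import ProcessPoolExecutor

_worker_index = None


def register_worker(lock, array_pid):
    # claim the first free slot of the shared PID array; the slot number is the worker index
    global _worker_index
    with lock:
        for i in range(len(array_pid)):
            if array_pid[i] == 0:
                array_pid[i] = os.getpid()
                _worker_index = i
                break


def which_worker(_):
    return _worker_index


if __name__ == "__main__":
    n_jobs = 4
    ctx = multiprocessing.get_context("spawn")
    lock = ctx.Lock()
    array_pid = ctx.Array("i", n_jobs)  # zero-initialized, one slot per worker
    with ProcessPoolExecutor(
        max_workers=n_jobs,
        mp_context=ctx,
        initializer=register_worker,
        initargs=(lock, array_pid),
    ) as pool:
        # each value in the output comes from the worker that handled that task
        print(sorted(set(pool.map(which_worker, range(100)))))  # subset of [0, 1, 2, 3]
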
27 changes: 25 additions & 2 deletions src/spikeinterface/core/node_pipeline.py
@@ -4,6 +4,7 @@
from typing import Optional, Type

import struct
import copy

from pathlib import Path

@@ -71,6 +72,11 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar

class PeakSource(PipelineNode):

# this is an important hack: it forces a node.compute() before the machinery is started,
# which triggers any numba jit compilation up front and avoids compilation races
# between processes or threads
need_first_call_before_pipeline = False

def get_trace_margin(self):
raise NotImplementedError

@@ -86,6 +92,12 @@ def get_peak_slice(
# not needed for PeakDetector
raise NotImplementedError

def _first_call_before_pipeline(self):
# see need_first_call_before_pipeline = True
margin = self.get_trace_margin()
traces = self.recording.get_traces(start_frame=0, end_frame=margin * 2 + 1, segment_index=0)
self.compute(traces, 0, margin * 2 + 1, 0, margin)


# this is used in sorting components
class PeakDetector(PeakSource):
@@ -601,7 +613,16 @@ def run_node_pipeline(
else:
raise ValueError(f"wrong gather_mode : {gather_mode}")

init_args = (recording, nodes, skip_after_n_peaks_per_worker)
node0 = nodes[0]
if isinstance(node0, PeakSource) and node0.need_first_call_before_pipeline:
# See need_first_call_before_pipeline: this triggers numba compilation before the run
node0._first_call_before_pipeline()
Review comment (Member):
Shouldn't this be a common function (by default a pass) for all nodes? I assume any node using a numba kernel would benefit from this.

Reply (Member, author):
Other nodes are more complicated because we would need to run the entire chain. This helps for now with peak_detection. Let's see how we can optimize this.

Reply (Member, author):
And at the moment this is implemented only in peak source nodes.

if job_kwargs["n_jobs"] != 1 and job_kwargs["pool_engine"] == "thread":
need_shallow_copy = True
else:
need_shallow_copy = False
init_args = (recording, nodes, need_shallow_copy, skip_after_n_peaks_per_worker)

processor = ChunkRecordingExecutor(
recording,
Expand All @@ -620,10 +641,12 @@ def run_node_pipeline(
return outs


def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker):
def _init_peak_pipeline(recording, nodes, need_shallow_copy, skip_after_n_peaks_per_worker):
# create a local dict per worker
worker_ctx = {}
worker_ctx["recording"] = recording
if need_shallow_copy:
nodes = [copy.copy(node) for node in nodes]
worker_ctx["nodes"] = nodes
worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)
worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker
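need_first_call_before_pipeline exists to pay the numba JIT cost once, in the parent, before any worker starts; otherwise every thread (or forked process) would hit an uncompiled kernel at the same time. A standalone sketch of that warm-up idea, with a hypothetical kernel standing in for a real peak detector; the read-only flag mimics the chunks workers receive:

import numpy as np
import numba


@numba.jit(nopython=True, nogil=True)
def count_threshold_crossings(traces, threshold):
    # hypothetical kernel: count samples below -threshold across all channels
    n = 0
    for i in range(traces.shape[0]):
        for j in range(traces.shape[1]):
            if traces[i, j] < -threshold:
                n += 1
    return n


def warm_up(kernel, num_channels=4, margin=10):
    # run the kernel once on a tiny, representative chunk so compilation happens here
    tiny = np.zeros((2 * margin + 1, num_channels), dtype="float32")
    tiny.flags.writeable = False  # same flags as the real chunks, so the signature matches
    kernel(tiny, 5.0)


warm_up(count_threshold_crossings)
# threads and forked workers started after this point reuse the compiled kernel
# instead of all triggering compilation at once
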
2 changes: 2 additions & 0 deletions src/spikeinterface/core/numpyextractors.py
@@ -157,6 +157,8 @@ def __init__(
shm = SharedMemory(shm_name, create=False)
self.shms.append(shm)
traces = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf)
# Force read only
traces.flags.writeable = False
traces_list.append(traces)

if channel_ids is None:
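Marking the shared-memory view read-only means consumers can read the traces but cannot silently corrupt the buffer that every process maps. A standalone sketch of the pattern (shape and dtype are arbitrary here):

import numpy as np
from multiprocessing.shared_memory import SharedMemory

shape, dtype = (1000, 4), "float32"
shm = SharedMemory(create=True, size=int(np.prod(shape)) * np.dtype(dtype).itemsize)

traces = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
traces.flags.writeable = False  # force read-only, as in the hunk above

try:
    traces[0, 0] = 1.0
except ValueError as err:
    print(err)  # assignment destination is read-only

del traces  # drop the view before closing, otherwise close() refuses to unmap
shm.close()
shm.unlink()
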
3 changes: 3 additions & 0 deletions src/spikeinterface/core/recording_tools.py
@@ -902,6 +902,9 @@ def get_chunk_with_margin(
taper = taper[:, np.newaxis]
traces_chunk2[:margin] *= taper
traces_chunk2[-margin:] *= taper[::-1]
# propagate the writeable flag from the original chunk
# (this helps numba keep a single signature and avoid compiling twice)
traces_chunk2.flags.writeable = traces_chunk.flags.writeable
traces_chunk = traces_chunk2
elif add_reflect_padding:
# in this case, we don't want to taper
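Copying the writeable flag matters because numba specializes on array mutability: a writable float32 chunk and a read-only one are two distinct signatures, so mixing them makes the same kernel compile twice. A small sketch showing the two specializations:

import numpy as np
import numba


@numba.jit(nopython=True)
def total_energy(traces):
    # sum of squared samples over the whole chunk
    total = 0.0
    for i in range(traces.shape[0]):
        for j in range(traces.shape[1]):
            total += traces[i, j] * traces[i, j]
    return total


writable = np.random.randn(1000, 4).astype("float32")
readonly = writable.copy()
readonly.flags.writeable = False

total_energy(writable)
total_energy(readonly)

# one compilation for the writable array and another for the read-only one,
# which is exactly the duplication that keeping the flag consistent avoids
print(len(total_energy.signatures))  # 2
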
8 changes: 5 additions & 3 deletions src/spikeinterface/core/tests/test_job_tools.py
@@ -242,8 +242,9 @@ def func2(segment_index, start_frame, end_frame, worker_dict):
return worker_dict["worker_index"]


def init_func2():
def init_func2(worker_index):
# this leaves time for other threads/processes to start
# print('in init_func2 with worker_index', worker_index)
time.sleep(0.010)
worker_dict = {}
return worker_dict
@@ -256,6 +257,7 @@ def test_worker_index():
for i in range(2):
# making this 2 times ensure to test that global variables are correctly reset
for pool_engine in ("process", "thread"):
# print(pool_engine)
processor = ChunkRecordingExecutor(
recording,
func2,
@@ -323,7 +325,7 @@ def test_get_best_job_kwargs():
# test_ChunkRecordingExecutor()
# test_fix_job_kwargs()
# test_split_job_kwargs()
# test_worker_index()
test_get_best_job_kwargs()
test_worker_index()
# test_get_best_job_kwargs()

# quick_becnhmark()
1 change: 1 addition & 0 deletions src/spikeinterface/core/waveform_tools.py
@@ -996,6 +996,7 @@ def _init_worker_estimate_templates(
nafter,
return_in_uV,
sparsity_mask,
worker_index,
):
worker_dict = {}
worker_dict["recording"] = recording
24 changes: 12 additions & 12 deletions src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py
@@ -47,7 +47,7 @@
##########################
# isocut zone

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def jisotonic5(x, weights):
N = x.shape[0]

@@ -100,7 +100,7 @@ def jisotonic5(x, weights):

return y, MSE

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def updown_arange(num_bins, dtype=np.int_):
num_bins_1 = int(np.ceil(num_bins / 2))
num_bins_2 = num_bins - num_bins_1
@@ -111,7 +111,7 @@ def updown_arange(num_bins, dtype=np.int_):
)
)

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def compute_ks4(counts1, counts2):
c1s = counts1.sum()
c2s = counts2.sum()
@@ -123,7 +123,7 @@ def compute_ks4(counts1, counts2):
ks *= np.sqrt((c1s + c2s) / 2)
return ks

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def compute_ks5(counts1, counts2):
best_ks = -np.inf
length = counts1.size
@@ -138,7 +138,7 @@ def compute_ks5(counts1, counts2):

return best_ks, best_length

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def up_down_isotonic_regression(x, weights=None):
# determine switch point
_, mse1 = jisotonic5(x, weights)
@@ -153,14 +153,14 @@ def up_down_isotonic_regression(x, weights=None):

return np.hstack((y1, y2))

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def down_up_isotonic_regression(x, weights=None):
return -up_down_isotonic_regression(-x, weights=weights)

# num_bins_factor = 1
float_0 = np.array([0.0])

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def isocut(samples): # , sample_weights=None isosplit6 not handle weight anymore
"""
Compute a dip-test to check if 1-d samples are unimodal or not.
@@ -464,7 +464,7 @@ def ensure_continuous_labels(labels):

if HAVE_NUMBA:

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def compute_centroids_and_covmats(X, centroids, covmats, labels, label_set, to_compute_mask):
## manual loop with numba to be faster

Expand Down Expand Up @@ -498,7 +498,7 @@ def compute_centroids_and_covmats(X, centroids, covmats, labels, label_set, to_c
if to_compute_mask[i] and count[i] > 0:
covmats[i, :, :] /= count[i]

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def get_pairs_to_compare(centroids, comparisons_made, active_labels_mask):
n = centroids.shape[0]

Expand Down Expand Up @@ -526,7 +526,7 @@ def get_pairs_to_compare(centroids, comparisons_made, active_labels_mask):

return pairs

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def compute_distances(centroids, comparisons_made, active_labels_mask):
n = centroids.shape[0]
dists = np.zeros((n, n), dtype=centroids.dtype)
@@ -548,7 +548,7 @@ def compute_distances(centroids, comparisons_made, active_labels_mask):

return dists

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def merge_test(X1, X2, centroid1, centroid2, covmat1, covmat2, isocut_threshold):

if X1.size == 0 or X2.size == 0:
@@ -584,7 +584,7 @@ def merge_test(X1, X2, centroid1, centroid2, covmat1, covmat2, isocut_threshold)

return do_merge, L12

@numba.jit(nopython=True)
@numba.jit(nopython=True, nogil=True)
def compare_pairs(X, labels, pairs, centroids, covmats, min_cluster_size, isocut_threshold):

clusters_changed_mask = np.zeros(centroids.shape[0], dtype="bool")
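Adding nogil=True to all of these kernels lets the compiled code release the GIL while it runs, which is what allows the "thread" pool engine to execute them in parallel; without it, numba calls from different threads still serialize on the interpreter lock. A minimal illustration of the flag on a CPU-bound kernel driven by a thread pool:

import numpy as np
import numba
from concurrent.futures import ThreadPoolExecutor


@numba.jit(nopython=True, nogil=True)
def sum_of_squares(x):
    # CPU-bound loop; with nogil=True the compiled code runs outside the GIL
    total = 0.0
    for i in range(x.shape[0]):
        total += x[i] * x[i]
    return total


chunks = [np.random.randn(2_000_000) for _ in range(8)]
sum_of_squares(chunks[0])  # compile once before threading

with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(sum_of_squares, chunks))

print(len(results))  # 8; the four threads can run the kernel concurrently
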
3 changes: 3 additions & 0 deletions src/spikeinterface/sortingcomponents/matching/nearest.py
@@ -14,6 +14,9 @@ class NearestTemplatesPeeler(BaseTemplateMatching):

name = "nearest"
need_noise_levels = True
# this is needed because of numba jit compilation (see need_first_call_before_pipeline)
need_first_call_before_pipeline = True

params_doc = """
peak_sign : 'neg' | 'pos' | 'both'
The peak sign to use for detection