diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index f87229b958..53fe7be1f2 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -35,14 +35,18 @@ class ComputeRandomSpikes(AnalyzerExtension): Parameters ---------- - method : "uniform" | "all", default: "uniform" - The method to select the spikes + method : "uniform" | "percentage" | "maximum_rate" | "all", default: "uniform" + Method to select spikes: "uniform" randomly up to max_spikes_per_unit, "percentage" selects a fraction of spikes, and "maximum_rate" limits selection by spike rate over time. max_spikes_per_unit : int, default: 500 The maximum number of spikes per unit, ignored if method="all" margin_size : int, default: None A margin on each border of segments to avoid border spikes, ignored if method="all" seed : int or None, default: None A seed for the random generator, ignored if method="all" + percentage : float | None, default: None + Used by the "percentage" method: the proportion of spikes to select per unit, in the interval (0, 1]. + maximum_rate : float | None, default: None + Used by the "maximum_rate" method: the maximum spike rate (spikes/s) allowed per unit. 
Returns ------- @@ -64,7 +68,12 @@ def _run(self, verbose=False): **self.params, ) - def _set_params(self, method="uniform", max_spikes_per_unit=500, margin_size=None, seed=None): + def _set_params( + self, method="uniform", max_spikes_per_unit=500, margin_size=None, seed=None, percentage=None, maximum_rate=None + ): - params = dict(method=method, max_spikes_per_unit=max_spikes_per_unit, margin_size=margin_size, seed=seed) + params = dict( + method=method, max_spikes_per_unit=max_spikes_per_unit, margin_size=margin_size, seed=seed, + percentage=percentage, maximum_rate=maximum_rate + ) return params diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index fc3adba3e4..cbad29c806 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -1,6 +1,8 @@ import warnings import importlib.util +from typing import Literal + import numpy as np from spikeinterface.core.base import BaseExtractor, unit_period_dtype @@ -146,14 +148,16 @@ def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units): return vector_to_list_of_spiketrain_numba -# TODO later : implement other method like "maximum_rate", "by_percent", ... +# TODO later: stratified sampling (isi / amplitude / pca distance ?) def random_spikes_selection( sorting: BaseSorting, - num_samples: int | None = None, - method: str = "uniform", + num_samples: list[int] | None = None, + method: Literal["uniform", "all", "percentage", "maximum_rate"] = "uniform", max_spikes_per_unit: int = 500, margin_size: int | None = None, seed: int | None = None, + percentage: float | None = None, + maximum_rate: float | None = None, ): """ This replaces `select_random_spikes_uniformly()`. @@ -165,41 +169,57 @@ def random_spikes_selection( ---------- sorting: BaseSorting The sorting object - num_samples: list of int + num_samples: list[int] | None, default: None The number of samples per segment. Can be retrieved from recording with num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - method: "uniform" | "all", default: "uniform" - The method to use. 
Only "uniform" is implemented for now + method: "uniform" | "percentage" | "maximum_rate" | "all", default: "uniform" + Method to select spikes: "uniform" randomly up to max_spikes_per_unit, "percentage" selects a fraction of spikes, and "maximum_rate" limits selection by spike rate over time. max_spikes_per_unit: int, default: 500 - The number of spikes per units + The maximum number of spikes per unit margin_size: None | int, default: None A margin on each border of segments to avoid border spikes seed: None | int, default: None A seed for random generator + percentage: float | None, default: None + Used by the "percentage" method: the proportion of spikes to select per unit, in the interval (0, 1]. + maximum_rate: float | None, default: None + Used by the "maximum_rate" method: the maximum spike rate (spikes/s) allowed per unit. Returns ------- random_spikes_indices: np.array Selected spike indices coresponding to the sorting spike vector. """ + rng_methods = ("uniform", "percentage", "maximum_rate") + + if method == "all": + spikes = sorting.to_spike_vector() + random_spikes_indices = np.arange(spikes.size) + + elif method in rng_methods: + from spikeinterface.widgets.utils import get_segment_durations - if method == "uniform": rng = np.random.default_rng(seed=seed) + # spikes are kept per segment (not concatenated): + # spikes = [ [ (sample_index, unit_index, segment_index), (), ... ], [ (), ... 
]] spikes = sorting.to_spike_vector(concatenated=False) cum_sizes = np.cumsum([0] + [s.size for s in spikes]) - # this fast when numba + # this is fast when numba is installed spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids, absolute_index=False) random_spikes_indices = [] for unit_index, unit_id in enumerate(sorting.unit_ids): all_unit_indices = [] for segment_index in range(sorting.get_num_segments()): - # this is local index + # this is local segment index inds_in_seg = spike_indices[segment_index][unit_id] if margin_size is not None: + if num_samples is None: + raise ValueError("num_samples must be provided when margin_size is used") + local_spikes = spikes[segment_index][inds_in_seg] mask = (local_spikes["sample_index"] >= margin_size) & ( local_spikes["sample_index"] < (num_samples[segment_index] - margin_size) @@ -209,19 +229,33 @@ def random_spikes_selection( inds_in_seg_abs = inds_in_seg + cum_sizes[segment_index] all_unit_indices.append(inds_in_seg_abs) all_unit_indices = np.concatenate(all_unit_indices) - selected_unit_indices = rng.choice( - all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), replace=False, shuffle=False - ) + + if method == "uniform": + rng_size = min(max_spikes_per_unit, all_unit_indices.size) + selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) + + elif method == "percentage": + if percentage is None or not (0 < percentage <= 1): + raise ValueError(f"percentage must be in the interval (0, 1], got {percentage}") + + rng_size = min(max_spikes_per_unit, int(all_unit_indices.size * percentage)) + selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) + + elif method == "maximum_rate": + if maximum_rate is None: + raise ValueError("maximum_rate must be defined when method='maximum_rate'") + + t_duration = np.sum(get_segment_durations(sorting)) + rng_size = min(int(t_duration * maximum_rate), max_spikes_per_unit, all_unit_indices.size) + 
selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) + random_spikes_indices.append(selected_unit_indices) random_spikes_indices = np.concatenate(random_spikes_indices) random_spikes_indices = np.sort(random_spikes_indices) - elif method == "all": - spikes = sorting.to_spike_vector() - random_spikes_indices = np.arange(spikes.size) else: - raise ValueError(f"random_spikes_selection(): method must be 'all' or 'uniform'") + raise ValueError(f"random_spikes_selection(): method must be 'all' or any in {', '.join(rng_methods)}") return random_spikes_indices diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index 1040d17428..0170038c96 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -94,11 +94,11 @@ def fit( model_folder_path: str, detect_peaks_params: dict, peak_selection_params: dict, - job_kwargs: dict = None, + job_kwargs: dict | None = None, ms_before: float = 1.0, ms_after: float = 1.0, whiten: bool = True, - radius_um: float = None, + radius_um: float | None = None, ) -> "IncrementalPCA": """ Train a pca model using the data in the recording object and the parameters provided. diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 11896686c4..000193bb43 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -399,7 +399,7 @@ def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSor return segment_indices -def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> list[float]: +def get_segment_durations(sorting: BaseSorting, segment_indices: list[int] | None = None) -> list[float]: """ Calculate the duration of each segment in a sorting object. 
@@ -408,11 +408,17 @@ def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> l sorting : BaseSorting The sorting object containing spike data + segment_indices : list[int] | None + List of the segment indices to process. If None, all segments are used. Default: None. + Returns ------- list[float] List of segment durations in seconds """ + if segment_indices is None: + segment_indices = range(sorting.get_num_segments()) + spikes = sorting.to_spike_vector() segment_boundaries = [