Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
ad8f2bc
first commit to create branch
Dec 9, 2025
251d1a2
Merge branch 'SpikeInterface:main' into random_spike_selection_new_me…
tayheau Dec 9, 2025
e6c9c2b
temporal bin, rate cap and percentage sampling
tayheau Dec 9, 2025
f3f0b2c
Merge branch 'SpikeInterface:main' into random_spike_selection_new_me…
tayheau Dec 18, 2025
bef42f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 18, 2025
b037441
lazy loading get_segment_duration
tayheau Dec 18, 2025
625bbef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 18, 2025
acfcb66
removed unused var
Dec 18, 2025
7c603ad
Merge branch 'main' into random_spike_selection_new_methods
tayheau Dec 23, 2025
2f00b1b
small changes
samuelgarcia Dec 23, 2025
5aebf66
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2025
88f32f2
Merge branch 'main' into random_spike_selection_new_methods
samuelgarcia Dec 23, 2025
cc440d0
Merge branch 'main' into random_spike_selection_new_methods
tayheau Dec 28, 2025
5949696
Merge branch 'main' into random_spike_selection_new_methods
alejoe91 Feb 20, 2026
fa3bf6f
removed rnd selection method
tayheau Mar 6, 2026
d49f9d4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2026
d373856
propagated args to `ComputeRandomSpikes`
tayheau Mar 6, 2026
f1315d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2026
f000f2d
better description
tayheau Mar 6, 2026
dd44155
Merge branch 'main' into random_spike_selection_new_methods
alejoe91 Mar 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,18 @@ class ComputeRandomSpikes(AnalyzerExtension):

Parameters
----------
method : "uniform" | "all", default: "uniform"
The method to select the spikes
method: "uniform" | "percentage" | "maximum_rate" | "all" , default: "uniform"
Method to select spikes: "uniform" randomly up to max_spikes_per_unit, "percentage" selects a fraction of spikes, and "maximum_rate" limits selection by spike rate over time.
max_spikes_per_unit : int, default: 500
The maximum number of spikes per unit, ignored if method="all"
margin_size : int, default: None
A margin on each border of segments to avoid border spikes, ignored if method="all"
seed : int or None, default: None
A seed for the random generator, ignored if method="all"
percentage: float | None, default: None
    Used with the "percentage" method: the fraction (0 < percentage <= 1) of each unit's spikes to select.
maximum_rate: float | None, default: None
    Used with the "maximum_rate" method: the maximum firing rate (spikes/s) used to cap the number of selected spikes per unit.

Returns
-------
Expand All @@ -64,7 +68,9 @@ def _run(self, verbose=False):
**self.params,
)

def _set_params(self, method="uniform", max_spikes_per_unit=500, margin_size=None, seed=None):
def _set_params(
self, method="uniform", max_spikes_per_unit=500, margin_size=None, seed=None, percentage=None, maximum_rate=None
):
params = dict(method=method, max_spikes_per_unit=max_spikes_per_unit, margin_size=margin_size, seed=seed)
return params

Expand Down
68 changes: 51 additions & 17 deletions src/spikeinterface/core/sorting_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
import importlib.util

from typing import Literal

import numpy as np

from spikeinterface.core.base import BaseExtractor, unit_period_dtype
Expand Down Expand Up @@ -146,14 +148,16 @@ def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units):
return vector_to_list_of_spiketrain_numba


def random_spikes_selection(
    sorting: BaseSorting,
    num_samples: list[int] | None = None,
    method: Literal["uniform", "all", "percentage", "maximum_rate"] = "uniform",
    max_spikes_per_unit: int = 500,
    margin_size: int | None = None,
    seed: int | None = None,
    percentage: float | None = None,
    maximum_rate: float | None = None,
):
    """
    Select spike indices from a sorting, per unit.

    This replaces `select_random_spikes_uniformly()`.

    Parameters
    ----------
    sorting : BaseSorting
        The sorting object
    num_samples : list[int] | None, default: None
        The number of samples per segment.
        Can be retrieved from recording with
        num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())]
        Required when margin_size is used.
    method : "uniform" | "percentage" | "maximum_rate" | "all", default: "uniform"
        Method to select spikes: "uniform" randomly selects up to max_spikes_per_unit,
        "percentage" selects a fraction of each unit's spikes, and "maximum_rate"
        caps the selection by spike rate over the total duration. "all" keeps every spike.
    max_spikes_per_unit : int, default: 500
        The maximum number of spikes per unit, also applied as a cap by the
        "percentage" and "maximum_rate" methods.
    margin_size : None | int, default: None
        A margin (in samples) on each border of segments to avoid border spikes.
    seed : None | int, default: None
        A seed for the random generator.
    percentage : float | None, default: None
        Used with the "percentage" method: the fraction (0 < percentage <= 1)
        of each unit's spikes to select.
    maximum_rate : float | None, default: None
        Used with the "maximum_rate" method: the maximum rate (spikes/s) used
        to cap the number of selected spikes per unit.

    Returns
    -------
    random_spikes_indices : np.array
        Selected spike indices corresponding to the sorting spike vector.
    """
    rng_methods = ("uniform", "percentage", "maximum_rate")

    if method == "all":
        spikes = sorting.to_spike_vector()
        random_spikes_indices = np.arange(spikes.size)

    elif method in rng_methods:
        from spikeinterface.widgets.utils import get_segment_durations

        # fail fast: validate method-specific parameters once, before the
        # per-unit loop, instead of raising mid-way through the selection
        if method == "percentage":
            if percentage is None or not (0 < percentage <= 1):
                raise ValueError("percentage must be in the interval (0, 1]")
        elif method == "maximum_rate":
            if maximum_rate is None:
                raise ValueError("maximum_rate must be defined")
            # the total duration is unit-independent: compute it once here
            # rather than once per unit inside the loop
            t_duration = float(np.sum(get_segment_durations(sorting)))

        rng = np.random.default_rng(seed=seed)

        # non-concatenated form:
        # spikes = [ [ (sample_index, unit_index, segment_index), (), ... ], [ (), ... ]]
        spikes = sorting.to_spike_vector(concatenated=False)
        cum_sizes = np.cumsum([0] + [s.size for s in spikes])

        # this is fast when numba is installed
        spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids, absolute_index=False)

        random_spikes_indices = []
        for unit_index, unit_id in enumerate(sorting.unit_ids):
            all_unit_indices = []
            for segment_index in range(sorting.get_num_segments()):
                # this is a local (per-segment) index
                inds_in_seg = spike_indices[segment_index][unit_id]
                if margin_size is not None:
                    if num_samples is None:
                        raise ValueError("num_samples must be provided when margin_size is used")
                    local_spikes = spikes[segment_index][inds_in_seg]
                    mask = (local_spikes["sample_index"] >= margin_size) & (
                        local_spikes["sample_index"] < (num_samples[segment_index] - margin_size)
                    )
                    inds_in_seg = inds_in_seg[mask]
                # shift to absolute index in the concatenated spike vector
                inds_in_seg_abs = inds_in_seg + cum_sizes[segment_index]
                all_unit_indices.append(inds_in_seg_abs)
            all_unit_indices = np.concatenate(all_unit_indices)

            if method == "uniform":
                rng_size = min(max_spikes_per_unit, all_unit_indices.size)
            elif method == "percentage":
                rng_size = min(max_spikes_per_unit, int(all_unit_indices.size * percentage))
            else:  # method == "maximum_rate"
                rng_size = min(int(t_duration * maximum_rate), max_spikes_per_unit, all_unit_indices.size)

            selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False)
            random_spikes_indices.append(selected_unit_indices)

        random_spikes_indices = np.concatenate(random_spikes_indices)
        random_spikes_indices = np.sort(random_spikes_indices)

    else:
        raise ValueError(f"random_spikes_selection(): method must be 'all' or any in {', '.join(rng_methods)}")

    return random_spikes_indices

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def fit(
model_folder_path: str,
detect_peaks_params: dict,
peak_selection_params: dict,
job_kwargs: dict = None,
job_kwargs: dict | None = None,
ms_before: float = 1.0,
ms_after: float = 1.0,
whiten: bool = True,
radius_um: float = None,
radius_um: float | None = None,
) -> "IncrementalPCA":
"""
Train a pca model using the data in the recording object and the parameters provided.
Expand Down
8 changes: 7 additions & 1 deletion src/spikeinterface/widgets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSor
return segment_indices


def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> list[float]:
def get_segment_durations(sorting: BaseSorting, segment_indices: list[int] = None) -> list[float]:
"""
Calculate the duration of each segment in a sorting object.

Expand All @@ -408,11 +408,17 @@ def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> l
sorting : BaseSorting
The sorting object containing spike data

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

I am a master reviewer.

segment_indices : list[int] | None
List of the segment indices to process. Default to None.

Returns
-------
list[float]
List of segment durations in seconds
"""
if segment_indices is None:
segment_indices = range(sorting.get_num_segments())

spikes = sorting.to_spike_vector()

segment_boundaries = [
Expand Down
Loading