From 3dc57290dbde0aeaa5048f2301ee75015a93fe26 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Dec 2025 15:43:44 +0100 Subject: [PATCH 01/19] Test IBL extractors tests failing for PI update --- src/spikeinterface/extractors/tests/test_iblextractors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 972a8e7bb0..56d01e38cf 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -76,8 +76,8 @@ def test_offsets(self): def test_probe_representation(self): probe = self.recording.get_probe() - expected_probe_representation = "Probe - 384ch - 1shanks" - assert repr(probe) == expected_probe_representation + expected_probe_representation = "Probe - 384ch" + assert expected_probe_representation in repr(probe) def test_property_keys(self): expected_property_keys = [ From 845ea33e05a8f267b7e9a7a6ca6af1edbf3db52d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 24 Mar 2026 17:11:55 +0100 Subject: [PATCH 02/19] Use ProbeGroup object instead of contact_vector --- src/spikeinterface/core/base.py | 47 ++++ src/spikeinterface/core/baserecording.py | 18 +- .../core/baserecordingsnippets.py | 243 +++++++++++------- src/spikeinterface/core/basesnippets.py | 6 +- src/spikeinterface/core/binaryfolder.py | 2 + .../core/channelsaggregationrecording.py | 44 ++-- src/spikeinterface/core/channelslice.py | 16 +- src/spikeinterface/core/core_tools.py | 4 + src/spikeinterface/core/recording_tools.py | 2 - src/spikeinterface/core/sortinganalyzer.py | 1 - .../core/tests/test_baserecording.py | 18 +- src/spikeinterface/core/zarrextractors.py | 2 +- src/spikeinterface/sorters/external/hdsort.py | 4 +- 13 files changed, 264 insertions(+), 143 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 9dc270d38d..d8d37b2875 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -213,6 +213,14 @@ def id_to_index(self, id) -> int: return ind def annotate(self, **new_annotations) -> None: + """Adds annotations. + + Parameters + ---------- + **new_annotations : dict + Key-value pairs of annotations to add. If an annotation key already exists, + it will be overwritten. + """ self._annotations.update(new_annotations) def set_annotation(self, annotation_key: str, value: Any, overwrite=False) -> None: @@ -236,6 +244,24 @@ def set_annotation(self, annotation_key: str, value: Any, overwrite=False) -> No else: raise ValueError(f"{annotation_key} is already an annotation key. Use 'overwrite=True' to overwrite it") + def delete_annotation(self, annotation_key: str) -> None: + """Deletes existing annotation. + + Parameters + ---------- + annotation_key : str + The annotation key to delete + + Raises + ------ + ValueError + If the annotation key does not exist + """ + if annotation_key in self._annotations.keys(): + del self._annotations[annotation_key] + else: + raise ValueError(f"{annotation_key} is not an annotation key") + def get_preferred_mp_context(self): """ Get the preferred context for multiprocessing. @@ -434,6 +460,15 @@ def copy_metadata( if self._preferred_mp_context is not None: other._preferred_mp_context = self._preferred_mp_context + self._extra_metadata_copy(other) + + def _extra_metadata_copy(self, other: BaseExtractor): + """ + This is a hook to copy extra metadata that is not in the annotations/properties dict. + It is used for instance to copy the probe in the FrameSliceRecording. + """ + pass + def to_dict( self, include_annotations: bool = False, @@ -567,6 +602,8 @@ def to_dict( folder_metadata = Path(folder_metadata).resolve().absolute().relative_to(relative_to) dump_dict["folder_metadata"] = str(folder_metadata) + self._extra_metadata_to_dict(dump_dict) + return dump_dict @staticmethod @@ -855,6 +892,14 @@ def _extra_metadata_to_folder(self, folder): # This implemented in BaseRecording for probe pass + def _extra_metadata_from_dict(self, dump_dict): + # This implemented in BaseRecording for probe + pass + + def _extra_metadata_to_dict(self, dump_dict): + # This implemented in BaseRecording for probe + pass + def save(self, **kwargs) -> BaseExtractor: """ Save a SpikeInterface object. @@ -1154,6 +1199,8 @@ def _load_extractor_from_dict(dic) -> BaseExtractor: for k, v in dic["properties"].items(): extractor.set_property(k, v) + extractor._extra_metadata_from_dict(dic) + return extractor diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 75bd47597b..8dc27f56c0 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -21,7 +21,6 @@ class BaseRecording(BaseRecordingSnippets): _main_annotations = BaseRecordingSnippets._main_annotations + ["is_filtered"] _main_properties = [ "group", - "location", "gain_to_uV", "offset_to_uV", "gain_to_physical_unit", @@ -591,6 +590,9 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): kwargs, job_kwargs = split_job_kwargs(save_kwargs) if format == "binary": + from .binaryfolder import BinaryFolderRecording + from .binaryrecordingextractor import BinaryRecordingExtractor + folder = kwargs["folder"] file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] dtype = kwargs.get("dtype", None) or self.get_dtype() @@ -598,8 +600,6 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) - from .binaryrecordingextractor import BinaryRecordingExtractor - # This is created so it can be saved as json because the `BinaryFolderRecording` requires it loading # See the __init__ of `BinaryFolderRecording` binary_rec = BinaryRecordingExtractor( @@ -616,8 +616,9 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): offset_to_uV=self.get_channel_offsets(), ) binary_rec.dump(folder / "binary.json", relative_to=folder) - - from .binaryfolder import BinaryFolderRecording + if self.has_probe(): + probegroup = self.get_probegroup() + write_probeinterface(folder / "probe.json", probegroup) cached = BinaryFolderRecording(folder_path=folder) @@ -648,10 +649,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: - probegroup = self.get_probegroup() - cached.set_probegroup(probegroup) - + # TODO: write binary should save timestamps too for segment_index in range(self.get_num_segments()): if self.has_time_vector(segment_index): # the use of get_times is preferred since timestamps are converted to array @@ -676,7 +674,7 @@ def _extra_metadata_from_folder(self, folder): def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() write_probeinterface(folder / "probe.json", probegroup) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 58e91ec35c..ce3faa0c32 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,5 +1,5 @@ from pathlib import Path - +import warnings import numpy as np from probeinterface import Probe, ProbeGroup, write_probeinterface, read_probeinterface, select_axes @@ -19,6 +19,7 @@ def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = float(sampling_frequency) self._dtype = np.dtype(dtype) + self._probegroup = None @property def channel_ids(self): @@ -51,15 +52,31 @@ def has_scaleable_traces(self) -> bool: return True def has_probe(self) -> bool: - return "contact_vector" in self.get_property_keys() + if self._probegroup is None and self.get_property("contact_vector") is not None: + # if contact_vector is present we can reconstruct the probe + self._probegroup = self._build_probegroup_from_properties() + return self._probegroup is not None + + def has_3d_probe(self) -> bool: + if self.has_probe(): + probe = self.get_probegroup().probes[0] + return probe.ndim == 3 + else: + return False def has_channel_location(self) -> bool: - return self.has_probe() or "location" in self.get_property_keys() + return self.has_probe() def is_filtered(self): # the is_filtered is handle with annotation return self._annotations.get("is_filtered", False) + def reset_probe(self): + """ + Removes probe information + """ + self._probegroup = None + def set_probe(self, probe, group_mode="auto", in_place=False): """ Attach a list of Probe object to a recording. @@ -85,12 +102,43 @@ def set_probe(self, probe, group_mode="auto", in_place=False): probegroup.add_probe(probe) return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) - def set_probegroup(self, probegroup, group_mode="auto", in_place=False): - return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) + def set_probegroup(self, probegroup, group_mode="auto", in_place=False, raise_if_overlapping_probes=True): + """ + Attach a ProbeGroup to a recording. + For this ProbeGroup.get_global_device_channel_indices() is used to link contacts to recording channels. + If some contacts of the probe group are not connected (device_channel_indices=-1) + then the recording is "sliced" and only connected channel are kept. + + The probe group order is not kept. Channel ids are re-ordered to match the channel_ids of the recording. - def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): + Parameters + ---------- + probe_or_probegroup: Probe, list of Probe, or ProbeGroup + The probe(s) to be attached to the recording + group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" + How to add the "group" property. + "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. + in_place: bool + False by default. + Useful internally when extractor do self.set_probegroup(probe) + raise_if_overlapping_probes: bool + If True, raises an error if the probes overlap. If False, it will just warn + + Returns + ------- + sub_recording: BaseRecording + A view of the recording (ChannelSlice or clone or itself) """ - Attach a list of Probe objects to a recording. + return self._set_probes( + probegroup, + group_mode=group_mode, + in_place=in_place, + raise_if_overlapping_probes=raise_if_overlapping_probes, + ) + + def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False, raise_if_overlapping_probes=True): + """ + Attach a list of Probe objects or a ProbeGroup to a recording. For this Probe.device_channel_indices is used to link contacts to recording channels. If some contacts of the Probe are not connected (device_channel_indices=-1) then the recording is "sliced" and only connected channel are kept. @@ -100,7 +148,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): Parameters ---------- - probe_or_probegroup: Probe, list of Probe, or ProbeGroup + probe_or_probegroup: Probe, list of Probes, ProbeGroup, or dict The probe(s) to be attached to the recording group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" How to add the "group" property. @@ -108,6 +156,8 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): in_place: bool False by default. Useful internally when extractor do self.set_probegroup(probe) + raise_if_overlapping_probes: bool + If True, raises an error if the probes overlap. If False, it will just warn Returns ------- @@ -132,12 +182,14 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): probegroup = ProbeGroup() for probe in probe_or_probegroup: probegroup.add_probe(probe) + elif isinstance(probe_or_probegroup, dict): + probegroup = ProbeGroup.from_dict(probe_or_probegroup) else: raise ValueError("must give Probe or ProbeGroup or list of Probe") # check that the probe do not overlap num_probes = len(probegroup.probes) - if num_probes > 1: + if num_probes > 1 and raise_if_overlapping_probes: check_probe_do_not_overlap(probegroup.probes) # handle not connected channels @@ -145,36 +197,37 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): probe.device_channel_indices is not None for probe in probegroup.probes ), "Probe must have device_channel_indices" - # this is a vector with complex fileds (dataframe like) that handle all contact attr + # TODO: add get_slice for probegroup to handle not connected channels probe_as_numpy_array = probegroup.to_numpy(complete=True) - - # keep only connected contact ( != -1) - keep = probe_as_numpy_array["device_channel_indices"] >= 0 + device_channel_indices = probegroup.get_global_device_channel_indices()["device_channel_indices"] + keep = device_channel_indices >= 0 if np.any(~keep): warn("The given probes have unconnected contacts: they are removed") - + device_channel_indices = device_channel_indices[keep] probe_as_numpy_array = probe_as_numpy_array[keep] + if len(device_channel_indices) > 0: + probegroup = probegroup.get_slice(device_channel_indices) + order = np.argsort(device_channel_indices) + device_channel_indices = device_channel_indices[order] + probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices))) + + # check TODO: Where did this came from? + number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) + if number_of_device_channel_indices >= self.get_num_channels(): + error_msg = ( + f"The given Probe either has 'device_channel_indices' that does not match channel count \n" + f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" + f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" + f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" + f"device_channel_indices are the following: {device_channel_indices} \n" + f"recording channels are the following: {self.get_channel_ids()} \n" + ) + raise ValueError(error_msg) + else: + warn("No connected channel in the probe! The probe will be attached but no channel will be selected.") + probegroup = ProbeGroup() # empty probegroup - device_channel_indices = probe_as_numpy_array["device_channel_indices"] - order = np.argsort(device_channel_indices) - device_channel_indices = device_channel_indices[order] - - # check TODO: Where did this came from? - number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) - if number_of_device_channel_indices >= self.get_num_channels(): - error_msg = ( - f"The given Probe either has 'device_channel_indices' that does not match channel count \n" - f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" - f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" - f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" - f"device_channel_indices are the following: {device_channel_indices} \n" - f"recording channels are the following: {self.get_channel_ids()} \n" - ) - raise ValueError(error_msg) - - new_channel_ids = self.get_channel_ids()[device_channel_indices] - probe_as_numpy_array = probe_as_numpy_array[order] - probe_as_numpy_array["device_channel_indices"] = np.arange(probe_as_numpy_array.size, dtype="int64") + new_channel_ids = self.channel_ids[device_channel_indices] # create recording : channel slice or clone or self if in_place: @@ -187,21 +240,9 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): else: sub_recording = self.select_channels(new_channel_ids) - # create a vector that handle all contacts in property - sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None) - - # planar_contour is saved in annotations - for probe_index, probe in enumerate(probegroup.probes): - contour = probe.probe_planar_contour - if contour is not None: - sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True) - - # duplicate positions to "locations" property - ndim = probegroup.ndim - locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") - for i, dim in enumerate(["x", "y", "z"][:ndim]): - locations[:, i] = probe_as_numpy_array[dim] - sub_recording.set_property("location", locations, ids=None) + # # create a vector that handle all contacts in property + # sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None) + sub_recording._probegroup = probegroup # handle groups has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields @@ -232,17 +273,11 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): groups[mask] = group sub_recording.set_property("group", groups, ids=None) - # add probe annotations to recording - probes_info = [] - for probe in probegroup.probes: - probes_info.append(probe.annotations) - sub_recording.annotate(probes_info=probes_info) - return sub_recording def get_probe(self): probes = self.get_probes() - assert len(probes) == 1, "there are several probe use .get_probes() or get_probegroup()" + assert len(probes) == 1, "There are several probe use .get_probes() or get_probegroup()" return probes[0] def get_probes(self): @@ -250,11 +285,22 @@ def get_probes(self): return probegroup.probes def get_probegroup(self): + if self._probegroup is not None: + return self._probegroup + else: # Backward compatibility: if contact_vector is present we reconstruct the probe, otherwise we look for + probegroup = self._build_probegroup_from_properties() + if probegroup is None: + raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") + self._probegroup = probegroup + return probegroup + + def _build_probegroup_from_properties(self): + # location and create a dummy probe arr = self.get_property("contact_vector") if arr is None: positions = self.get_property("location") if positions is None: - raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") + return None else: warn("There is no Probe attached to this recording. Creating a dummy one with contact positions") probe = self.create_dummy_probe_from_locations(positions) @@ -273,8 +319,15 @@ def get_probegroup(self): contour = self.get_annotation(f"probe_{probe_index}_planar_contour") if contour is not None: probe.set_planar_contour(contour) + self.delete_annotation(f"probe_{probe_index}_planar_contour") + # delete contact_vector as it is not needed anymore + self.delete_property("contact_vector") return probegroup + def _extra_metadata_copy(self, other): + if self._probegroup is not None: + other._probegroup = self._probegroup.copy() + def _extra_metadata_from_folder(self, folder): # load probe folder = Path(folder) @@ -284,10 +337,22 @@ def _extra_metadata_from_folder(self, folder): def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() write_probeinterface(folder / "probe.json", probegroup) + def _extra_metadata_from_dict(self, dump_dict): + # load probe + if "probegroup" in dump_dict: + probegroup = dump_dict["probegroup"] + self.set_probegroup(probegroup, in_place=True) + + def _extra_metadata_to_dict(self, dump_dict): + # save probe + if self.has_probe(): + probegroup = self.get_probegroup() + dump_dict["probegroup"] = probegroup + def create_dummy_probe_from_locations(self, locations, shape="circle", shape_params={"radius": 1}, axes="xy"): """ Creates a "dummy" probe based on locations. @@ -330,51 +395,55 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params ---------- locations : np.array Array with channel locations (num_channels, ndim) [ndim can be 2 or 3] - shape : str, default: default: "circle" + shape : str, default: "circle" Electrode shapes shape_params : dict, default: {"radius": 1} Shape parameters axes : "xy" | "yz" | "xz", default: "xy" If ndim is 3, indicates the axes that define the plane of the electrodes """ - probe = self.create_dummy_probe_from_locations(locations, shape=shape, shape_params=shape_params, axes=axes) + probe = self.create_dummy_probe_from_locations( + np.array(locations), shape=shape, shape_params=shape_params, axes=axes + ) self.set_probe(probe, in_place=True) def set_channel_locations(self, locations, channel_ids=None): - if self.get_property("contact_vector") is not None: - raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") - self.set_property("location", locations, ids=channel_ids) + warnings.warn( + ( + "set_channel_locations() is deprecated and will be removed in version 0.106.0. " + "If you want to set probe information, use `set_dummy_probe_from_locations()`." + ), + DeprecationWarning, + stacklevel=2, + ) + self.set_dummy_probe_from_locations(locations, axes="xy") def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray: if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - # here we bypass the probe reconstruction so this works both for probe and probegroup - ndim = len(axes) - all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") - for i, dim in enumerate(axes): - all_positions[:, i] = contact_vector[dim] - positions = all_positions[channel_indices] - return positions - else: - locations = self.get_property("location") - if locations is None: - raise Exception("There are no channel locations") - locations = np.asarray(locations)[channel_indices] - return select_axes(locations, axes) + if not self.has_probe(): + raise ValueError("get_channel_locations(..) needs a probe to be attached to the recording") + probegroup = self.get_probegroup() + contact_positions = probegroup.get_global_contact_positions() + return select_axes(contact_positions, axes)[channel_indices] - def has_3d_locations(self) -> bool: - return self.get_property("location").shape[1] == 3 + def is_probe_3d(self) -> bool: + if not self.has_probe(): + raise ValueError("is_probe_3d() needs a probe to be attached to the recording") + probe = self.get_probegroup().probes[0] + return probe.ndim == 3 def clear_channel_locations(self, channel_ids=None): - if channel_ids is None: - n = self.get_num_channel() - else: - n = len(channel_ids) - locations = np.zeros((n, 2)) * np.nan - self.set_property("location", locations, ids=channel_ids) + warnings.warn( + ( + "clear_channel_locations() is deprecated and will be removed in version 0.106.0. " + "If you want to remove probe information, use `reset_probe()`." + ), + DeprecationWarning, + stacklevel=2, + ) + self.reset_probe() def set_channel_groups(self, groups, channel_ids=None): if "probes" in self._annotations: @@ -429,7 +498,7 @@ def planarize(self, axes: str = "xy"): BaseRecording The recording with 2D positions """ - assert self.has_3d_locations, "The 'planarize' function needs a recording with 3d locations" + assert self.has_3d_probe(), "The 'planarize' function needs a recording with 3d locations" assert len(axes) == 2, "You need to specify 2 dimensions (e.g. 'xy', 'zy')" probe2d = self.get_probe().to_2d(axes=axes) diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index b56a093ccc..39e8c51225 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -11,7 +11,7 @@ class BaseSnippets(BaseRecordingSnippets): Abstract class representing several multichannel snippets. """ - _main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"] + _main_properties = ["group", "gain_to_uV", "offset_to_uV"] _main_features = [] def __init__(self, sampling_frequency: float, nbefore: int | None, snippet_len: int, channel_ids: list, dtype): @@ -259,9 +259,9 @@ def _save(self, format="npy", **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() - cached.set_probegroup(probegroup) + cached.set_probegroup(probegroup, in_place=True) return cached diff --git a/src/spikeinterface/core/binaryfolder.py b/src/spikeinterface/core/binaryfolder.py index 4b9d7b7d09..8d8ed3c206 100644 --- a/src/spikeinterface/core/binaryfolder.py +++ b/src/spikeinterface/core/binaryfolder.py @@ -3,6 +3,8 @@ import numpy as np +from probeinterface import read_probeinterface + from .binaryrecordingextractor import BinaryRecordingExtractor from .core_tools import define_function_from_class, make_paths_absolute diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 0da4797440..4e933c9e9d 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -2,6 +2,7 @@ import numpy as np +from probeinterface import ProbeGroup from .baserecording import BaseRecording, BaseRecordingSegment @@ -90,31 +91,28 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record break for prop_name, prop_values in property_dict.items(): - if prop_name == "contact_vector": - # remap device channel indices correctly - prop_values["device_channel_indices"] = np.arange(self.get_num_channels()) self.set_property(key=prop_name, values=prop_values) - # if locations are present, check that they are all different! - if "location" in self.get_property_keys(): - location_tuple = [tuple(loc) for loc in self.get_property("location")] - assert len(set(location_tuple)) == self.get_num_channels(), ( - "Locations are not unique! " "Cannot aggregate recordings!" - ) - - planar_contour_keys = [ - key for recording in recording_list for key in recording.get_annotation_keys() if "planar_contour" in key - ] - if len(planar_contour_keys) > 0: - if all( - k == planar_contour_keys[0] for k in planar_contour_keys - ): # we add the 'planar_contour' annotations only if there is a unique one in the recording_list - planar_contour_key = planar_contour_keys[0] - collect_planar_contours = [] - for rec in recording_list: - collect_planar_contours.append(rec.get_annotation(planar_contour_key)) - if all(np.array_equal(arr, collect_planar_contours[0]) for arr in collect_planar_contours): - self.set_annotation(planar_contour_key, collect_planar_contours[0]) + # Aggregate probe information + all_probegroups = [rec.get_probegroup() for rec in recording_list if rec.has_probe()] + if len(all_probegroups) == len(recording_list): + # check that contact positions are unique across all recordings + all_positions = [] + for probegroup in all_probegroups: + for probe in probegroup.probes: + all_positions.extend(probe.contact_positions) + assert len(np.unique(np.array(all_positions), axis=0)) == len( + all_positions + ), "Contact positions are not unique! Cannot aggregate recordings." + + # Now make a new probegroup with all probes and set global device channel indices + all_probes = [] + for probegroup in all_probegroups: + all_probes.extend([p.copy() for p in probegroup.probes]) + probegroup_agg = ProbeGroup() + probegroup_agg.probes = all_probes + probegroup_agg.set_global_device_channel_indices(np.arange(num_all_channels)) + self.set_probegroup(probegroup_agg, in_place=True, raise_if_overlapping_probes=False) # finally add segments, we need a channel mapping ch_id = 0 diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 67d25b2925..5e67acf304 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -62,10 +62,11 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) self._parent = parent_recording # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if self._parent.has_probe(): + parent_probegroup = self._parent.get_probegroup() + sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(self._channel_ids))) + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { @@ -152,10 +153,9 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids) # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if self._parent_snippets.has_probe(): + parent_probegroup = self._parent_snippets.get_probegroup() + self.set_probe(parent_probegroup.get_slice(self._parent_channel_indices)) # update dump dict self._kwargs = { diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index ed98613553..ba11642a5e 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -10,6 +10,7 @@ from collections import namedtuple import inspect +from probeinterface import ProbeGroup import numpy as np @@ -148,6 +149,9 @@ def default(self, obj): if isinstance(obj, Motion): return obj.to_dict() + if isinstance(obj, ProbeGroup): + return obj.to_dict() + # The base-class handles the assertion return super().default(obj) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 48eb2d7fd4..19e54d1fc9 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -1023,8 +1023,6 @@ def get_rec_attributes(recording): The rec_attributes dictionary """ properties_to_attrs = deepcopy(recording._properties) - if "contact_vector" in properties_to_attrs: - del properties_to_attrs["contact_vector"] rec_attributes = dict( channel_ids=recording.channel_ids, sampling_frequency=recording.get_sampling_frequency(), diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8e16757bcc..e712b881d5 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -365,7 +365,6 @@ def create( ) # check that multiple probes are non-overlapping all_probes = recording.get_probegroup().probes - check_probe_do_not_overlap(all_probes) if has_exceeding_spikes(sorting=sorting, recording=recording): warnings.warn( diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 1ebeb677c6..c724cbba98 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -196,7 +196,13 @@ def test_BaseRecording(create_cache_folder): probe.create_auto_shape() rec_p = rec.set_probe(probe, group_mode="auto") + positions2 = rec_p.get_channel_locations() + assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) + rec_p = rec.set_probe(probe, group_mode="by_shank") + positions2 = rec_p.get_channel_locations() + assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) + rec_p = rec.set_probe(probe, group_mode="by_probe") positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) @@ -204,7 +210,6 @@ def test_BaseRecording(create_cache_folder): probe2 = rec_p.get_probe() positions3 = probe2.contact_positions assert np.array_equal(positions2, positions3) - assert np.array_equal(probe2.device_channel_indices, [0, 1]) # test save with probe @@ -284,8 +289,9 @@ def test_BaseRecording(create_cache_folder): rec_int16.set_property("offset_to_uV", [0.0] * 5) # Test deprecated return_scaled parameter - traces_float32_old = rec_int16.get_traces(return_scaled=True) # Keep this for testing the deprecation warning - assert traces_float32_old.dtype == "float32" + with pytest.warns(DeprecationWarning, match="`return_scaled` is deprecated"): + traces_float32_old = rec_int16.get_traces(return_scaled=True) # Keep this for testing the deprecation warning + assert traces_float32_old.dtype == "float32" # Test new return_in_uV parameter traces_float32_new = rec_int16.get_traces(return_in_uV=True) @@ -342,7 +348,7 @@ def test_BaseRecording(create_cache_folder): # test 3d probe rec_3d = generate_recording(ndim=3, num_channels=30) - locations_3d = rec_3d.get_property("location") + locations_3d = rec_3d.get_probe().contact_positions locations_xy = rec_3d.get_channel_locations(axes="xy") assert np.allclose(locations_xy, locations_3d[:, [0, 1]]) @@ -411,8 +417,8 @@ def test_json_pickle_equivalence(create_cache_folder): for key, value in data_json.items(): # skip probe info, since pickle keeps some additional information - if key not in ["properties"]: - if isinstance(value, dict): + if key not in ["properties", "probegroup"]: + if isinstance(value, dict) and isinstance(data_pickle[key], dict): for sub_key, sub_value in value.items(): assert np.all(sub_value == data_pickle[key][sub_key]) else: diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index e58ef4ee68..a0fa2a24a1 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -494,7 +494,7 @@ def add_recording_to_zarr_group( ) # save probe - if recording.get_property("contact_vector") is not None: + if recording.has_probe(): probegroup = recording.get_probegroup() zarr_group.attrs["probe"] = check_json(probegroup.to_dict(array_as_list=True)) diff --git a/src/spikeinterface/sorters/external/hdsort.py b/src/spikeinterface/sorters/external/hdsort.py index 3daaf85b7a..07d59332d2 100644 --- a/src/spikeinterface/sorters/external/hdsort.py +++ b/src/spikeinterface/sorters/external/hdsort.py @@ -276,8 +276,8 @@ def write_hdsort_input_format(cls, recording, save_path, chunk_memory="500M"): [("electrode", np.int32), ("x", np.float64), ("y", np.float64), ("channel", np.int32)] ) - locations = recording.get_property("location") - assert locations is not None, "'location' property is needed to run HDSort" + assert recording.has_probe(), "The recording must have a probe to run HDSort" + locations = recording.get_channel_locations() with h5py.File(save_path, "w") as f: f.create_group("ephys") From 1426bf8e1f80d00a241bf1a22e8412d23c126b0a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 24 Mar 2026 17:14:04 +0100 Subject: [PATCH 03/19] Apply suggestion from @alejoe91 --- src/spikeinterface/core/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index d8d37b2875..67ad9a3fef 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -465,7 +465,6 @@ def copy_metadata( def _extra_metadata_copy(self, other: BaseExtractor): """ This is a hook to copy extra metadata that is not in the annotations/properties dict. - It is used for instance to copy the probe in the FrameSliceRecording. """ pass From c2dbeaf64d5d9764920d6ffaa7cd959f3e9e8c72 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 24 Mar 2026 17:22:44 +0100 Subject: [PATCH 04/19] Apply suggestions from code review Co-authored-by: Alessio Buccino --- src/spikeinterface/core/baserecordingsnippets.py | 8 +++----- src/spikeinterface/core/channelslice.py | 4 +++- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index ce3faa0c32..db033d2393 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -197,7 +197,6 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False, ra probe.device_channel_indices is not None for probe in probegroup.probes ), "Probe must have device_channel_indices" - # TODO: add get_slice for probegroup to handle not connected channels probe_as_numpy_array = probegroup.to_numpy(complete=True) device_channel_indices = probegroup.get_global_device_channel_indices()["device_channel_indices"] keep = device_channel_indices >= 0 @@ -240,8 +239,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False, ra else: sub_recording = self.select_channels(new_channel_ids) - # # create a vector that handle all contacts in property - # sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None) + # Set probegroup sub_recording._probegroup = probegroup # handle groups @@ -431,8 +429,8 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarra def is_probe_3d(self) -> bool: if not self.has_probe(): raise ValueError("is_probe_3d() needs a probe to be attached to the recording") - probe = self.get_probegroup().probes[0] - return probe.ndim == 3 + probegroup = self.get_probegroup() + return probegroup.ndim == 3 def clear_channel_locations(self, channel_ids=None): warnings.warn( diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 5e67acf304..748d052530 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -155,7 +155,9 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): # change the wiring of the probe if self._parent_snippets.has_probe(): parent_probegroup = self._parent_snippets.get_probegroup() - self.set_probe(parent_probegroup.get_slice(self._parent_channel_indices)) + sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(self._channel_ids))) + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { From 15331e5fc73c664555e8eec1c01fe3a75bd8f177 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 25 Mar 2026 12:04:35 +0100 Subject: [PATCH 05/19] Remove contact vector from extractors/sortingcomponents --- src/spikeinterface/extractors/neoextractors/biocam.py | 6 ++++-- .../extractors/neoextractors/maxwell.py | 3 ++- .../extractors/tests/test_iblextractors.py | 4 +++- src/spikeinterface/preprocessing/zero_channel_pad.py | 2 +- .../sortingcomponents/motion/motion_interpolation.py | 11 ++++++----- 5 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 8d1fac0c72..c85e82b574 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -70,9 +70,11 @@ def __init__( if electrode_width is not None: probe_kwargs["electrode_width"] = electrode_width probe = probeinterface.read_3brain(file_path, **probe_kwargs) + rows = probe.contact_annotations["row"] + cols = probe.contact_annotations["col"] self.set_probe(probe, in_place=True) - self.set_property("row", self.get_property("contact_vector")["row"]) - self.set_property("col", self.get_property("contact_vector")["col"]) + self.set_property("row", rows) + self.set_property("col", cols) self._kwargs.update( { diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 932ecee106..5eaa49e6b8 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -74,8 +74,9 @@ def __init__( # rec_name auto set by neo rec_name = self.neo_reader.rec_name probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) + electrodes = probe.contact_annotations["electrode"] self.set_probe(probe, in_place=True) - self.set_property("electrode", self.get_property("contact_vector")["electrode"]) + self.set_property("electrode", electrodes) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name)) @classmethod diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 5306de2441..912d22627c 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -84,7 +84,6 @@ def test_property_keys(self): expected_property_keys = [ "gain_to_uV", "offset_to_uV", - "contact_vector", "location", "group", "shank", @@ -97,6 +96,9 @@ def test_property_keys(self): ] self.assertCountEqual(first=self.recording.get_property_keys(), second=expected_property_keys) + def test_has_probe(self): + assert self.recording.has_probe() is True + def test_trace_shape(self): expected_shape = (21, 384) self.assertTupleEqual(tuple1=self.small_scaled_trace.shape, tuple2=expected_shape) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 35b984449d..aaede9de5e 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -157,7 +157,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: "The new mapping cannot exceed total number of channels " "in the zero-chanenl-padded recording." ) else: - if "locations" in recording.get_property_keys() or "contact_vector" in recording.get_property_keys(): + if recording.has_probe(): self.channel_mapping = np.argsort(recording.get_channel_locations()[:, 1]) else: self.channel_mapping = np.arange(recording.get_num_channels()) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index a50b9609b9..aa6f43936b 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -405,11 +405,12 @@ def __init__( if border_mode == "remove_channels": # change the wiring of the probe - # TODO this is also done in ChannelSliceRecording, this should be done in a common place - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if recording.has_probe(): + probegroup = recording.get_probegroup() + channel_indices = recording.ids_to_indices(channel_ids) + probegroup_sliced = probegroup.get_slice(channel_indices) + probegroup_sliced.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) + self.set_probegroup(probegroup_sliced, in_place=True) # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below From 4ccb318742154d0312df875757fcd74bc34fccce Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 25 Mar 2026 12:15:24 +0100 Subject: [PATCH 06/19] fix: update test_interpolate_bad_channels probe manipulation --- .../tests/test_interpolate_bad_channels.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index a571894374..1294b57a91 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -2,6 +2,8 @@ import numpy as np import os +import probeinterface as pi + import spikeinterface as si import spikeinterface.preprocessing as spre import spikeinterface.extractors as se @@ -125,9 +127,12 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan # distribute default probe locations across 4 shanks if set rng = np.random.default_rng(seed=None) - x = rng.choice(shanks, num_channels) - for idx, __ in enumerate(recording._properties["contact_vector"]): - recording._properties["contact_vector"][idx][1] = x[idx] + x_new = rng.choice(shanks, num_channels) + probe = recording.get_probe() + new_positions = probe.contact_positions.copy() + new_positions[:, 0] = x_new # column 0 is x + recording._probegroup.probes[0]._contact_positions = new_positions + recording.set_probe(probe, in_place=True) # generate random bad channel locations bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False) @@ -161,18 +166,21 @@ def test_output_values(): the non-interpolated channels is also an implicit test these were not accidently changed. """ - recording = generate_recording(num_channels=5, durations=[1]) + recording = generate_recording(num_channels=5, durations=[1], set_probe=False) bad_channel_indexes = np.array([0]) bad_channel_ids = recording.channel_ids[bad_channel_indexes] - new_probe_locs = [ - [5, 7, 3, 5, 5], # 5 channels, a in the center ('bad channel', zero index) - [5, 5, 5, 7, 3], - ] # all others equal distance away. - # Overwrite the probe information with the new locations - for idx, (x, y) in enumerate(zip(*new_probe_locs)): - recording._properties["contact_vector"][idx][1] = x - recording._properties["contact_vector"][idx][2] = y + probe_locs = np.array( + [ + [5, 7, 3, 5, 5], # 5 channels, a in the center ('bad channel', zero index) + [5, 5, 5, 7, 3], + ] # all others equal distance away. + ).T + # Set the probe information with the new locations + probe = pi.Probe(ndim=2) + probe.set_contacts(positions=probe_locs) + probe.set_device_channel_indices(np.arange(len(probe_locs))) + recording.set_probe(probe, in_place=True) # Run interpolation in SI and check the interpolated channel # 0 is a linear combination of other channels @@ -186,8 +194,7 @@ def test_output_values(): # Shift the last channel position so that it is 4 units, rather than 2 # away. Setting sigma_um = p = 1 allows easy calculation of the expected # weights. - recording._properties["contact_vector"][-1][1] = 5 - recording._properties["contact_vector"][-1][2] = 9 + recording._probegroup.probes[0]._contact_positions[-1] = [5, 9] expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) From 485a3548a5cb54d4f6c2ba233037e76c598b95a4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 25 Mar 2026 12:16:12 +0100 Subject: [PATCH 07/19] test: remove 'location' from IBL properties check --- src/spikeinterface/extractors/tests/test_iblextractors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 912d22627c..87eb4df47a 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -84,7 +84,6 @@ def test_property_keys(self): expected_property_keys = [ "gain_to_uV", "offset_to_uV", - "location", "group", "shank", "shank_row", From 4d2c56f5adf48fbc171d13783058094db8e73609 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 25 Mar 2026 14:47:32 +0100 Subject: [PATCH 08/19] fix: extra_metadata not used in copy_metadata if only_main=True --- src/spikeinterface/core/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 67ad9a3fef..4e0bdc1b6a 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -460,7 +460,8 @@ def copy_metadata( if self._preferred_mp_context is not None: other._preferred_mp_context = self._preferred_mp_context - self._extra_metadata_copy(other) + if not only_main: + self._extra_metadata_copy(other) def _extra_metadata_copy(self, other: BaseExtractor): """ From dd265481b6ad6a9f3d6591c0a8b2559975360e0f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 25 Mar 2026 18:00:40 +0100 Subject: [PATCH 09/19] Fix dtype issue in average_across_directions --- src/spikeinterface/preprocessing/average_across_direction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index 113d1e22f1..dfef781ec1 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -63,7 +63,7 @@ def __init__( # my geometry channel_locations = np.zeros( (n_pos_unique, parent_channel_locations.shape[1]), - dtype=parent_channel_locations.dtype, + dtype=np.float32, ) # average other dimensions in the geometry other_dim = np.arange(parent_channel_locations.shape[1]) != dim From db357a0e59a7bbbe25f99569a8f9ef0db5f5a037 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 16 Apr 2026 15:08:05 +0200 Subject: [PATCH 10/19] fix annotations --- src/spikeinterface/core/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 57e9e72765..aa41d343be 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -470,7 +470,7 @@ def copy_metadata( if not only_main: self._extra_metadata_copy(other) - def _extra_metadata_copy(self, other: BaseExtractor): + def _extra_metadata_copy(self, other: "BaseExtractor") -> None: """ This is a hook to copy extra metadata that is not in the annotations/properties dict. """ From bf5a1a4551621b3146582cc4beb918a9d5f293fe Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 16 Apr 2026 18:16:24 +0200 Subject: [PATCH 11/19] fix ibl tests --- .../extractors/tests/test_iblextractors.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 87eb4df47a..fa9d59ec47 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -31,11 +31,13 @@ def setUpClass(cls): cache_dir=None, ) except: + print("Skipping test due to server being down.") pytest.skip("Skipping test due to server being down.") try: cls.recording = read_ibl_recording(eid=cls.eid, stream_name="probe00.ap", one=cls.one) except requests.exceptions.HTTPError as e: if e.response.status_code == 503: + print("Skipping test due to server being down (HTTP 503).") pytest.skip("Skipping test due to server being down (HTTP 503).") else: raise @@ -68,11 +70,11 @@ def test_channel_ids(self): def test_gains(self): expected_gains = 2.34375 * np.ones(shape=384) - assert_array_equal(x=self.recording.get_channel_gains(), y=expected_gains) + assert_array_equal(self.recording.get_channel_gains(), expected_gains) def test_offsets(self): expected_offsets = np.zeros(shape=384) - assert_array_equal(x=self.recording.get_channel_offsets(), y=expected_offsets) + assert_array_equal(self.recording.get_channel_offsets(), expected_offsets) def test_probe_representation(self): probe = self.recording.get_probe() @@ -143,11 +145,11 @@ def test_channel_ids(self): def test_gains(self): expected_gains = np.concatenate([2.34375 * np.ones(shape=384), [1171.875]]) - assert_array_equal(x=self.recording.get_channel_gains(), y=expected_gains) + assert_array_equal(self.recording.get_channel_gains(), expected_gains) def test_offsets(self): expected_offsets = np.zeros(shape=385) - assert_array_equal(x=self.recording.get_channel_offsets(), y=expected_offsets) + assert_array_equal(self.recording.get_channel_offsets(), expected_offsets) def test_probe_representation(self): expected_exception = ValueError From f712b343d440658e176ffd2ed9b6c727c270c01e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Jun 2026 16:46:36 +0200 Subject: [PATCH 12/19] refac: modify set_probe and add select_channels_with_probe --- .../scripts/create_probe_compat_fixtures.py | 96 +++++ .github/workflows/probe_backward_compat.yml | 51 +++ .../core/baserecordingsnippets.py | 347 +++++++++--------- src/spikeinterface/core/basesnippets.py | 2 +- .../core/channelsaggregationrecording.py | 2 +- src/spikeinterface/core/channelslice.py | 4 +- src/spikeinterface/core/generate.py | 6 +- src/spikeinterface/core/recording_tools.py | 89 +++-- src/spikeinterface/core/sortinganalyzer.py | 2 +- .../core/tests/test_baserecording.py | 80 ++-- .../core/tests/test_basesnippets.py | 3 +- .../test_channelsaggregationrecording.py | 59 +++ .../core/tests/test_channelslicerecording.py | 34 +- .../core/tests/test_probe_backward_compat.py | 111 ++++++ .../core/tests/test_sortinganalyzer.py | 2 +- src/spikeinterface/core/zarrextractors.py | 5 +- src/spikeinterface/extractors/bids.py | 4 +- src/spikeinterface/extractors/cbin_ibl.py | 4 +- .../extractors/iblextractors.py | 4 +- .../extractors/neoextractors/biocam.py | 2 +- .../extractors/neoextractors/maxwell.py | 2 +- .../extractors/neoextractors/mearec.py | 2 +- .../extractors/neoextractors/openephys.py | 4 +- .../extractors/neoextractors/spikegadgets.py | 2 +- .../extractors/neoextractors/spikeglx.py | 4 +- .../extractors/shybridextractors.py | 2 +- .../extractors/sinapsrecordingextractors.py | 4 +- src/spikeinterface/generation/drift_tools.py | 2 +- .../postprocessing/localization_tools.py | 2 +- .../tests/test_deepinterpolation.py | 2 +- .../tests/test_detect_bad_channels.py | 4 +- .../tests/test_highpass_spatial_filter.py | 2 +- .../tests/test_interpolate_bad_channels.py | 4 +- .../motion/motion_interpolation.py | 2 +- .../peak_localization/monopolar.py | 4 +- 35 files changed, 668 insertions(+), 281 deletions(-) create mode 100644 .github/scripts/create_probe_compat_fixtures.py create mode 100644 .github/workflows/probe_backward_compat.yml create mode 100644 src/spikeinterface/core/tests/test_probe_backward_compat.py diff --git a/.github/scripts/create_probe_compat_fixtures.py b/.github/scripts/create_probe_compat_fixtures.py new file mode 100644 index 0000000000..e41f7c7804 --- /dev/null +++ b/.github/scripts/create_probe_compat_fixtures.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +""" +Creates probe compatibility fixtures using the *currently installed* spikeinterface. + +Run this script with spikeinterface==0.104.* installed to produce the fixture +files consumed by test_probe_backward_compat.py: + + python create_probe_compat_fixtures.py [output_dir] + +If output_dir is omitted, fixtures are written to ./probe_compat_fixtures. +""" + +import sys +import shutil +import numpy as np +from pathlib import Path + +import spikeinterface + +print(f"Creating fixtures with spikeinterface {spikeinterface.__version__}") + +OUTPUT_DIR = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("probe_compat_fixtures") +if OUTPUT_DIR.exists(): + shutil.rmtree(OUTPUT_DIR) +OUTPUT_DIR.mkdir(parents=True) + +from probeinterface import generate_linear_probe, ProbeGroup +from spikeinterface.core import NumpyRecording + +# ----------------------------------------------------------------------- +# Fixture 1: single probe, sequential device_channel_indices +# ----------------------------------------------------------------------- +n = 8 +probe = generate_linear_probe(num_elec=n, ypitch=20.0) +probe.annotate(name="test_probe", manufacturer="test_vendor") +probe.set_contact_ids([f"e{i}" for i in range(n)]) +probe.set_device_channel_indices(np.arange(n)) +probe.create_auto_shape() + +traces = np.arange(1000 * n, dtype="int16").reshape(1000, n) +rec_single = NumpyRecording([traces], sampling_frequency=30000.0) +rec_single = rec_single.set_probe(probe) # old API: in_place=False, returns new recording + +rec_single.save(folder=str(OUTPUT_DIR / "single_probe_binary")) +rec_single.dump_to_json(str(OUTPUT_DIR / "single_probe.json"), relative_to=None) + +# ----------------------------------------------------------------------- +# Fixture 2: two probes with per-probe name/manufacturer +# ----------------------------------------------------------------------- +n_A, n_B = 8, 8 +probe_A = generate_linear_probe(num_elec=n_A, ypitch=20.0) +probe_A.move([0.0, 0.0]) +probe_A.annotate(name="probe_A", manufacturer="vendor_X") +probe_A.set_contact_ids([f"a{i}" for i in range(n_A)]) +probe_A.set_device_channel_indices(np.arange(n_A)) +probe_A.create_auto_shape() + +probe_B = generate_linear_probe(num_elec=n_B, ypitch=20.0) +probe_B.move([500.0, 0.0]) +probe_B.annotate(name="probe_B", manufacturer="vendor_Y") +probe_B.set_contact_ids([f"b{i}" for i in range(n_B)]) +probe_B.set_device_channel_indices(np.arange(n_A, n_A + n_B)) +probe_B.create_auto_shape() + +pg = ProbeGroup() +pg.add_probe(probe_A) +pg.add_probe(probe_B) + +n_total = n_A + n_B +traces2 = np.arange(1000 * n_total, dtype="int16").reshape(1000, n_total) +rec_two = NumpyRecording([traces2], sampling_frequency=30000.0) +rec_two = rec_two.set_probegroup(pg) # old API: in_place=False, returns new recording + +rec_two.save(folder=str(OUTPUT_DIR / "two_probe_binary")) +rec_two.dump_to_json(str(OUTPUT_DIR / "two_probe.json"), relative_to=None) + +# ----------------------------------------------------------------------- +# Fixture 3: probe with shuffled device_channel_indices +# Verifies that the channel-reordering logic is preserved across versions. +# ----------------------------------------------------------------------- +n = 8 +probe_sh = generate_linear_probe(num_elec=n, ypitch=20.0) +probe_sh.annotate(name="shuffled_probe", manufacturer="shuffle_vendor") +shuffled_dci = np.array([3, 0, 7, 1, 5, 2, 6, 4]) # permutation of 0..7 +probe_sh.set_device_channel_indices(shuffled_dci) + +# traces[:, j] corresponds to recording channel j, which after set_probe +# is mapped to the contact whose dci equals j. +traces3 = np.arange(1000 * n, dtype="int16").reshape(1000, n) +rec_sh = NumpyRecording([traces3], sampling_frequency=30000.0) +rec_sh = rec_sh.set_probe(probe_sh) # old API + +rec_sh.save(folder=str(OUTPUT_DIR / "shuffled_probe_binary")) +rec_sh.dump_to_json(str(OUTPUT_DIR / "shuffled_probe.json"), relative_to=None) + +print(f"Fixtures written to: {OUTPUT_DIR.resolve()}") diff --git a/.github/workflows/probe_backward_compat.yml b/.github/workflows/probe_backward_compat.yml new file mode 100644 index 0000000000..7eeab4089b --- /dev/null +++ b/.github/workflows/probe_backward_compat.yml @@ -0,0 +1,51 @@ +name: Probe backward compatibility + +on: + pull_request: + types: [synchronize, opened, reopened] + branches: + - main + paths: + - 'src/spikeinterface/core/**' + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + probe-backward-compat: + name: Probe compat (SI 0.104.* → current) + runs-on: ubuntu-latest + env: + SI_PROBE_COMPAT_FIXTURES_DIR: ${{ github.workspace }}/probe_compat_fixtures + + steps: + - name: Check out code + uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.11' + + - name: Set up uv + uses: astral-sh/setup-uv@v7 + with: + python-version: '3.11' + enable-cache: false + + # Step 1: install the OLD release and create fixtures. + # The fixture script uses the old in_place=False default (returns a new recording), + # saves to binary folder + JSON, and writes a known probe name/manufacturer/contact_ids. + - name: Install spikeinterface 0.104.* to create fixtures + run: uv pip install --system "spikeinterface[core]==0.104.*" + + - name: Create compatibility fixtures with old version + run: python .github/scripts/create_probe_compat_fixtures.py "$SI_PROBE_COMPAT_FIXTURES_DIR" + + # Step 2: install the NEW version from this PR source and run the load tests. + - name: Install new spikeinterface from source + run: uv pip install --system -e . --group test-core + + - name: Run backward compatibility tests + run: pytest src/spikeinterface/core/tests/test_probe_backward_compat.py -v diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 995354497e..0bb969ae69 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,11 +1,12 @@ from pathlib import Path +from typing import Literal import warnings import numpy as np from probeinterface import Probe, ProbeGroup, write_probeinterface, read_probeinterface, select_axes from .base import BaseExtractor -from .recording_tools import check_probe_do_not_overlap +from .recording_tools import _set_group_property_based_on_probegroup from warnings import warn @@ -75,140 +76,199 @@ def reset_probe(self): """ self._probegroup = None - def set_probe(self, probe, group_mode="auto", in_place=False): + def set_probe( + self, + probe: Probe, + group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto", + in_place: bool | None = None, + ) -> None: """ - Attach a list of Probe object to a recording. + Attach a Probe object to a recording. Parameters ---------- - probe_or_probegroup: Probe, list of Probe, or ProbeGroup - The probe(s) to be attached to the recording + probe: Probe + The probe to be attached to the recording group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" How to add the "group" property. - "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. - in_place: bool - False by default. - Useful internally when extractor do self.set_probegroup(probe) - - Returns - ------- - sub_recording: BaseRecording - A view of the recording (ChannelSlice or clone or itself) + "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks + and two sides are present. + in_place: (deprecated) bool | None, default: None + Deprecated argument to indicate whether to modify the recording in place + or return a new recording. The function is always in place now. + Use the `recording.select_channels_with_probegroup()` method instead of `in_place=False` + to return a new recording with a channel selection to match the probe/probegroup. + + Notes + ----- + Internally, this will construct a ProbeGroup with the probe and call `set_probegroup()`. """ - assert isinstance(probe, Probe), "must give Probe" + assert isinstance(probe, Probe), "The input must be a Probe object" probegroup = ProbeGroup() probegroup.add_probe(probe) - return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) - - def set_probegroup(self, probegroup, group_mode="auto", in_place=False, raise_if_overlapping_probes=True): + # TODO: remove return in 0.106.0 after removing in_place argument + return self.set_probegroup(probegroup, group_mode=group_mode, in_place=in_place) + + def set_probegroup( + self, + probegroup: ProbeGroup | dict, + group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto", + in_place: bool | None = None, + ) -> None: """ - Attach a ProbeGroup to a recording. - For this ProbeGroup.get_global_device_channel_indices() is used to link contacts to recording channels. - If some contacts of the probe group are not connected (device_channel_indices=-1) - then the recording is "sliced" and only connected channel are kept. + Attach a ProbeGroup or dict to a recording. + For this Probe.device_channel_indices is used to link contacts to recording channels. + After removing unconnected contacts, the number of connected contacts must match the + number of channels in the recording. If this is not the case, use the `recording.select_with_probegroup()` + method instead to return a new recording with a channel selection to match the probe/probegroup. - The probe group order is not kept. Channel ids are re-ordered to match the channel_ids of the recording. + Note: The probe order of the probegroup is not kept. Channel ids are re-ordered to match the channel_ids of the recording. Parameters ---------- - probe_or_probegroup: Probe, list of Probe, or ProbeGroup + probe_or_probegroup: ProbeGroup, or dict The probe(s) to be attached to the recording group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" How to add the "group" property. "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. - in_place: bool - False by default. - Useful internally when extractor do self.set_probegroup(probe) - raise_if_overlapping_probes: bool - If True, raises an error if the probes overlap. If False, it will just warn + in_place: (deprecated) bool | None, default: None + Deprecated argument to indicate whether to modify the recording in place + or return a new recording. The function is always in place now. + Use the `recording.select_channels_with_probegroup()` method instead of `in_place=False` + to return a new recording with a channel selection to match the probe/probegroup. + """ + if in_place is not None: + warnings.warn( + "The 'in_place' argument is deprecated and will be removed in version 0.106.0. " + "The `set_probe/probegroup()` are always in place and assume that the probe/probegroup has the " + "same number of connected contacts as the number of channels in the recording. " + "Use the `recording.select_channels_with_probegroup()` method instead to return a new recording with " + "a channel selection to match the probe/probegroup.", + DeprecationWarning, + stacklevel=2, + ) + if not in_place: + return self.select_channels_with_probegroup(probegroup, group_mode=group_mode) + + # Handle several input possibilities: Probe or dict + if isinstance(probegroup, dict): + probegroup = ProbeGroup.from_dict(probegroup) + + probegroup_sorted = self._get_probegroup_based_on_device_channel_indices(probegroup) + + if probegroup_sorted.get_contact_count() != self.get_num_channels(): + raise ValueError( + "The probe/probegroup must have the same number of connected contacts " + f"as the number of channels as the recording, but the probe has {probegroup.get_contact_count()} " + f"connected channels and the recording has {self.get_num_channels()} channels. " + "Use the `recording.select_channels_with_probegroup()` method instead to return a new recording with " + "a channel selection to match the probe/probegroup." + ) + probegroup_sorted.set_global_device_channel_indices(np.arange(probegroup_sorted.get_contact_count())) + self._probegroup = probegroup_sorted + + # Handle and set channel groups + _set_group_property_based_on_probegroup(self, probegroup_sorted, group_mode=group_mode) + + def select_channels_with_probe( + self, probe: Probe, group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto" + ) -> "BaseRecordingSnippets": + """ + Returns a new recording with channels selected based on the probe. + + Parameters + ---------- + probe: Probe + The probe to be used for channel selection + group_mode: "auto" | "by_probe" | "by_shank" | + "by_side", default: "auto" + How to add the "group" property. + "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. Returns ------- sub_recording: BaseRecording A view of the recording (ChannelSlice or clone or itself) """ - return self._set_probes( - probegroup, - group_mode=group_mode, - in_place=in_place, - raise_if_overlapping_probes=raise_if_overlapping_probes, - ) + assert isinstance(probe, Probe), "The input must be a Probe object" + probegroup = ProbeGroup() + probegroup.add_probe(probe) + return self.select_channels_with_probegroup(probegroup, group_mode=group_mode) - def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False, raise_if_overlapping_probes=True): + def select_channels_with_probegroup( + self, probegroup: ProbeGroup, group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto" + ) -> "BaseRecordingSnippets": """ - Attach a list of Probe objects or a ProbeGroup to a recording. - For this Probe.device_channel_indices is used to link contacts to recording channels. - If some contacts of the Probe are not connected (device_channel_indices=-1) - then the recording is "sliced" and only connected channel are kept. - - The probe order is not kept. Channel ids are re-ordered to match the channel_ids of the recording. - + Selects channels based on the given ProbeGroup and returns a new recording with the selected channels. Parameters ---------- - probe_or_probegroup: Probe, list of Probes, ProbeGroup, or dict - The probe(s) to be attached to the recording - group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" + probegroup: ProbeGroup + The probegroup to be used for channel selection + group_mode: "auto" | "by_probe" | "by_shank" | + "by_side", default: "auto" How to add the "group" property. - "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. - in_place: bool - False by default. - Useful internally when extractor do self.set_probegroup(probe) - raise_if_overlapping_probes: bool - If True, raises an error if the probes overlap. If False, it will just warn + "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks + and two sides are present. Returns ------- sub_recording: BaseRecording A view of the recording (ChannelSlice or clone or itself) """ - assert group_mode in ( - "auto", - "by_probe", - "by_shank", - "by_side", - ), "'group_mode' can be 'auto' 'by_probe' 'by_shank' or 'by_side'" - - # handle several input possibilities - if isinstance(probe_or_probegroup, Probe): - probegroup = ProbeGroup() - probegroup.add_probe(probe_or_probegroup) - elif isinstance(probe_or_probegroup, ProbeGroup): - probegroup = probe_or_probegroup - elif isinstance(probe_or_probegroup, list): - assert all([isinstance(e, Probe) for e in probe_or_probegroup]) - probegroup = ProbeGroup() - for probe in probe_or_probegroup: - probegroup.add_probe(probe) - elif isinstance(probe_or_probegroup, dict): - probegroup = ProbeGroup.from_dict(probe_or_probegroup) + probegroup_sorted = self._get_probegroup_based_on_device_channel_indices(probegroup) + if probegroup_sorted.get_contact_count() > 0: + sorted_dci = probegroup_sorted.get_global_device_channel_indices()["device_channel_indices"] + new_channel_ids = self.channel_ids[sorted_dci] + probegroup_sorted.set_global_device_channel_indices(np.arange(len(new_channel_ids))) + if np.array_equal(new_channel_ids, self.channel_ids): + sub_recording = self.clone() + else: + sub_recording = self.select_channels(new_channel_ids) + sub_recording._probegroup = probegroup_sorted + _set_group_property_based_on_probegroup(sub_recording, probegroup_sorted, group_mode=group_mode) else: - raise ValueError("must give Probe or ProbeGroup or list of Probe") + sub_recording = self.select_channels([]) # empty recording + sub_recording._probegroup = ProbeGroup() # empty probegroup + return sub_recording - # check that the probe do not overlap - num_probes = len(probegroup.probes) - if num_probes > 1 and raise_if_overlapping_probes: - check_probe_do_not_overlap(probegroup.probes) + def _get_probegroup_based_on_device_channel_indices(self, probegroup: ProbeGroup) -> ProbeGroup: + """ + Returns a new probegroup sorted based on their device_channel_indices. + This is useful to ensure that the probes are ordered correctly when attached to a recording. + Also checks that the device_channel_indices are consistent with the recording channel count and + contacts are unique across probes in the probegroup. + + Parameters + ---------- + probegroup : ProbeGroup + The probegroup to be sorted. + + Returns + ------- + ProbeGroup + The sorted probegroup. + """ + if not isinstance(probegroup, ProbeGroup): + raise ValueError("The input must be a ProbeGroup or dict") - # handle not connected channels assert all( probe.device_channel_indices is not None for probe in probegroup.probes ), "Probe must have device_channel_indices" - probe_as_numpy_array = probegroup.to_numpy(complete=True) + # Remove unconnected contacts and slice the probe group accordingly device_channel_indices = probegroup.get_global_device_channel_indices()["device_channel_indices"] - keep = device_channel_indices >= 0 - if np.any(~keep): - warn("The given probes have unconnected contacts: they are removed") - device_channel_indices = device_channel_indices[keep] - probe_as_numpy_array = probe_as_numpy_array[keep] - if len(device_channel_indices) > 0: - probegroup = probegroup.get_slice(device_channel_indices) - order = np.argsort(device_channel_indices) - device_channel_indices = device_channel_indices[order] - probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices))) + keep_indices = np.flatnonzero(device_channel_indices >= 0) + if len(keep_indices) < len(device_channel_indices): + warn( + f"The given probes have {len(device_channel_indices) - len(keep_indices)} unconnected contacts: " + "they will be removed" + ) + probegroup = probegroup.get_slice(keep_indices) + device_channel_indices = device_channel_indices[keep_indices] - # check TODO: Where did this came from? + if len(device_channel_indices) > 0: + # Check consistency of device_channel_indices with the recording channel count number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) if number_of_device_channel_indices >= self.get_num_channels(): error_msg = ( @@ -220,56 +280,26 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False, ra f"recording channels are the following: {self.get_channel_ids()} \n" ) raise ValueError(error_msg) + # Now slice the probe using the device channel indices to match the recording channel_ids + order = np.argsort(device_channel_indices) + probegroup = probegroup.get_slice(order) else: - warn("No connected channel in the probe! The probe will be attached but no channel will be selected.") + warn( + "No connected channels in the probegroup! " + "The probegroup will be attached but no channel will be selected." + ) probegroup = ProbeGroup() # empty probegroup - new_channel_ids = self.channel_ids[device_channel_indices] - - # create recording : channel slice or clone or self - if in_place: - if not np.array_equal(new_channel_ids, self.get_channel_ids()): - raise Exception("set_probe(inplace=True) must have all channel indices") - sub_recording = self - else: - if np.array_equal(new_channel_ids, self.get_channel_ids()): - sub_recording = self.clone() - else: - sub_recording = self.select_channels(new_channel_ids) - - # Set probegroup - sub_recording._probegroup = probegroup - - # handle groups - has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields - has_contact_side = "contact_sides" in probe_as_numpy_array.dtype.fields - if group_mode == "auto": - group_keys = ["probe_index"] - if has_shank_id: - group_keys += ["shank_ids"] - if has_contact_side: - group_keys += ["contact_sides"] - elif group_mode == "by_probe": - group_keys = ["probe_index"] - elif group_mode == "by_shank": - assert has_shank_id, "shank_ids is None in probe, you cannot group by shank" - group_keys = ["probe_index", "shank_ids"] - elif group_mode == "by_side": - assert has_contact_side, "contact_sides is None in probe, you cannot group by side" - if has_shank_id: - group_keys = ["probe_index", "shank_ids", "contact_sides"] - else: - group_keys = ["probe_index", "contact_sides"] - groups = np.zeros(probe_as_numpy_array.size, dtype="int64") - unique_keys = np.unique(probe_as_numpy_array[group_keys]) - for group, a in enumerate(unique_keys): - mask = np.ones(probe_as_numpy_array.size, dtype=bool) - for k in group_keys: - mask &= probe_as_numpy_array[k] == a[k] - groups[mask] = group - sub_recording.set_property("group", groups, ids=None) + # In some older SI versions, before #4300, the probe annotations were + # saved to the recording annotations as `probes_info`. If this is the + # case, we can copy the annotations to the probegroup and delete the + # `probes_info` from the recording annotations. + if "probes_info" in self._annotations: + probes_info = self._annotations.pop("probes_info") + for probe, probe_info in zip(probegroup.probes, probes_info): + probe.annotations.update(probe_info) - return sub_recording + return probegroup def get_probe(self): probes = self.get_probes() @@ -285,51 +315,24 @@ def get_probegroup(self): raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") return self._probegroup - # def _build_probegroup_from_properties(self): - # # location and create a dummy probe - # arr = self.get_property("contact_vector") - # if arr is None: - # positions = self.get_property("location") - # if positions is None: - # return None - # else: - # warn("There is no Probe attached to this recording. Creating a dummy one with contact positions") - # probe = self.create_dummy_probe_from_locations(positions) - # # probe.create_auto_shape() - # probegroup = ProbeGroup() - # probegroup.add_probe(probe) - # else: - # probegroup = ProbeGroup.from_numpy(arr) - - # if "probes_info" in self.get_annotation_keys(): - # probes_info = self.get_annotation("probes_info") - # for probe, probe_info in zip(probegroup.probes, probes_info): - # probe.annotations = probe_info - - # for probe_index, probe in enumerate(probegroup.probes): - # contour = self.get_annotation(f"probe_{probe_index}_planar_contour") - # if contour is not None: - # probe.set_planar_contour(contour) - # self.delete_annotation(f"probe_{probe_index}_planar_contour") - # # delete contact_vector as it is not needed anymore - # self.delete_property("contact_vector") - # return probegroup - def _extra_metadata_copy(self, other): if self._probegroup is not None: other._probegroup = self._probegroup.copy() def _extra_metadata_from_folder(self, folder): - # load probe + # load probe from folder + # Note: we don't need any fix for legacy probegroups, since the + # set_probegroup() method will handle the device_channel_indices + # sorting and global contact order folder = Path(folder) probe_file = folder / "probegroup.json" legacy_probe_file = folder / "probe.json" if probe_file.is_file(): probegroup = read_probeinterface(probe_file) - self.set_probegroup(probegroup, in_place=True) + self.set_probegroup(probegroup) elif legacy_probe_file.is_file(): probegroup = read_probeinterface(legacy_probe_file) - self.set_probegroup(probegroup, in_place=True) + self.set_probegroup(probegroup) # remove "contact_vector" property if present as it is not needed anymore if "contact_vector" in self.get_property_keys(): @@ -345,7 +348,7 @@ def _extra_metadata_from_dict(self, dump_dict): # load probe if "probegroup" in dump_dict: probegroup = dump_dict["probegroup"] - self.set_probegroup(probegroup, in_place=True) + self.set_probegroup(probegroup) def _extra_metadata_to_dict(self, dump_dict): # save probe @@ -405,7 +408,7 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params probe = self.create_dummy_probe_from_locations( np.array(locations), shape=shape, shape_params=shape_params, axes=axes ) - self.set_probe(probe, in_place=True) + self.set_probe(probe) def set_channel_locations(self, locations, channel_ids=None): warnings.warn( @@ -503,7 +506,7 @@ def planarize(self, axes: str = "xy"): probe2d = self.get_probe().to_2d(axes=axes) recording2d = self.clone() - recording2d.set_probe(probe2d, in_place=True) + recording2d.set_probe(probe2d) return recording2d diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 39e8c51225..a1b0563186 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -261,7 +261,7 @@ def _save(self, format="npy", **save_kwargs): if self.has_probe(): probegroup = self.get_probegroup() - cached.set_probegroup(probegroup, in_place=True) + cached.set_probegroup(probegroup) return cached diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index b62b808ab4..8371a9a5ee 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -112,7 +112,7 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record probegroup_agg = ProbeGroup() probegroup_agg.probes = all_probes probegroup_agg.set_global_device_channel_indices(np.arange(num_all_channels)) - self.set_probegroup(probegroup_agg, in_place=True, raise_if_overlapping_probes=False) + self.set_probegroup(probegroup_agg) # finally add segments, we need a channel mapping ch_id = 0 diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 1df401f8db..8669e3c90c 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -66,7 +66,7 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parent_probegroup = self._parent.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) sliced_probegroup.set_global_device_channel_indices(np.arange(len(self._channel_ids))) - self.set_probegroup(sliced_probegroup, in_place=True) + self.set_probegroup(sliced_probegroup) # update dump dict self._kwargs = { @@ -157,7 +157,7 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): parent_probegroup = self._parent_snippets.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) sliced_probegroup.set_global_device_channel_indices(np.arange(len(self._channel_ids))) - self.set_probegroup(sliced_probegroup, in_place=True) + self.set_probegroup(sliced_probegroup) # update dump dict self._kwargs = { diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 4fa68ebec0..9ca5cb2df9 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -86,7 +86,7 @@ def generate_recording( if ndim == 3: probe = probe.to_3d() probe.set_device_channel_indices(np.arange(num_channels)) - recording.set_probe(probe, in_place=True) + recording.set_probe(probe) recording.name = "SyntheticRecording" @@ -675,7 +675,7 @@ def generate_snippets( if set_probe: probe = recording.get_probe() - snippets = snippets.set_probe(probe) + snippets.set_probe(probe) return snippets, sorting @@ -2462,7 +2462,7 @@ def generate_ground_truth_recording( upsample_vector=upsample_vector, ) recording.annotate(is_filtered=True) - recording.set_probe(probe, in_place=True) + recording.set_probe(probe) recording.set_channel_gains(1.0) recording.set_channel_offsets(0.0) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 7d1b46a07d..99045501fc 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -8,7 +8,8 @@ import numpy as np -from .core_tools import add_suffix, make_shared_array +from probeinterface import ProbeGroup + from .job_tools import ( ensure_chunk_size, divide_segment_into_chunks, @@ -683,44 +684,60 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y"), return order_f, order_r -def check_probe_do_not_overlap(probes): +def _set_group_property_based_on_probegroup( + recording, probegroup: ProbeGroup, group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] +): """ - When several probes this check that that they do not overlap in space - and so channel positions can be safely concatenated. - - Raises - ------ - Exception : - If probes are overlapping + Set the group property for a recording based on a ProbeGroup. + Use "auto" (default) to automatically determine the grouping based on the available + information in the ProbeGroup (default: probe + shank + side if available). - Returns - ------- - None : None - If the check is successful + Parameters + ---------- + recording : BaseRecording + The recording object + probegroup : ProbeGroup + The ProbeGroup object + group_mode : {"auto", "by_probe", "by_shank", "by_side"} + The mode for grouping channels """ - for i in range(len(probes)): - probe_i = probes[i] - # check that all positions in probe_j are outside probe_i boundaries - x_bounds_i = [ - np.min(probe_i.contact_positions[:, 0]), - np.max(probe_i.contact_positions[:, 0]), - ] - y_bounds_i = [ - np.min(probe_i.contact_positions[:, 1]), - np.max(probe_i.contact_positions[:, 1]), - ] - - for j in range(i + 1, len(probes)): - probe_j = probes[j] - if np.any( - np.array( - [ - x_bounds_i[0] <= cp[0] <= x_bounds_i[1] and y_bounds_i[0] <= cp[1] <= y_bounds_i[1] - for cp in probe_j.contact_positions - ] - ) - ): - raise Exception("Probes are overlapping! Retrieve locations of single probes separately") + if not isinstance(probegroup, ProbeGroup): + raise ValueError("`probegroup` must be a ProbeGroup instance.") + assert group_mode in ( + "auto", + "by_probe", + "by_shank", + "by_side", + ), "'group_mode' can be 'auto' 'by_probe' 'by_shank' or 'by_side'" + + probe_array = probegroup.to_numpy(complete=True) + has_shank_id = "shank_ids" in probe_array.dtype.fields + has_contact_side = "contact_sides" in probe_array.dtype.fields + if group_mode == "auto": + group_keys = ["probe_index"] + if has_shank_id: + group_keys += ["shank_ids"] + if has_contact_side: + group_keys += ["contact_sides"] + elif group_mode == "by_probe": + group_keys = ["probe_index"] + elif group_mode == "by_shank": + assert has_shank_id, "shank_ids is None in probe, you cannot group by shank" + group_keys = ["probe_index", "shank_ids"] + elif group_mode == "by_side": + assert has_contact_side, "contact_sides is None in probe, you cannot group by side" + if has_shank_id: + group_keys = ["probe_index", "shank_ids", "contact_sides"] + else: + group_keys = ["probe_index", "contact_sides"] + groups = np.zeros(probe_array.size, dtype="int64") + unique_keys = np.unique(probe_array[group_keys]) + for group, a in enumerate(unique_keys): + mask = np.ones(probe_array.size, dtype=bool) + for k in group_keys: + mask &= probe_array[k] == a[k] + groups[mask] = group + recording.set_property("group", groups, ids=None) def get_rec_attributes(recording): diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index d9bee0d0a7..5abf58ddfd 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -23,7 +23,7 @@ from spikeinterface.core import BaseRecording, BaseSorting, aggregate_channels, aggregate_units from spikeinterface.core.waveform_tools import has_exceeding_spikes -from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match +from .recording_tools import get_rec_attributes, do_recording_attributes_match from .core_tools import ( check_json, retrieve_importing_provenance, diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 4bbc20fa4f..e36150c43e 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -12,7 +12,13 @@ from probeinterface import Probe, ProbeGroup, generate_linear_probe -from spikeinterface.core import BinaryRecordingExtractor, NumpyRecording, load, get_default_zarr_compressor +from spikeinterface.core import ( + BinaryRecordingExtractor, + NumpyRecording, + load, + get_default_zarr_compressor, + aggregate_channels, +) from spikeinterface.core.base import BaseExtractor from spikeinterface.core.testing import check_recordings_equal @@ -197,15 +203,15 @@ def test_BaseRecording(create_cache_folder): ) probe.create_auto_shape() - rec_p = rec.set_probe(probe, group_mode="auto") + rec_p = rec.select_channels_with_probe(probe, group_mode="auto") positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) - rec_p = rec.set_probe(probe, group_mode="by_shank") + rec_p = rec.select_channels_with_probe(probe, group_mode="by_shank") positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) - rec_p = rec.set_probe(probe, group_mode="by_probe") + rec_p = rec.select_channels_with_probe(probe, group_mode="by_probe") positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) @@ -248,13 +254,13 @@ def test_BaseRecording(create_cache_folder): probe.create_auto_shape() traces = np.zeros((1000, 12), dtype="int16") rec = NumpyRecording([traces], 30000.0) - rec1 = rec.set_probe(probe, group_mode="auto") + rec1 = rec.select_channels_with_probe(probe, group_mode="auto") assert np.unique(rec1.get_property("group")).size == 4 - rec2 = rec.set_probe(probe, group_mode="by_probe") + rec2 = rec.select_channels_with_probe(probe, group_mode="by_probe") assert np.unique(rec2.get_property("group")).size == 1 - rec3 = rec.set_probe(probe, group_mode="by_shank") + rec3 = rec.select_channels_with_probe(probe, group_mode="by_shank") assert np.unique(rec3.get_property("group")).size == 2 - rec4 = rec.set_probe(probe, group_mode="by_side") + rec4 = rec.select_channels_with_probe(probe, group_mode="by_side") assert np.unique(rec4.get_property("group")).size == 4 # set unconnected probe @@ -264,7 +270,7 @@ def test_BaseRecording(create_cache_folder): probe.set_device_channel_indices([-1, -1, -1]) probe.create_auto_shape() - rec_empty_probe = rec.set_probe(probe, group_mode="by_shank") + rec_empty_probe = rec.select_channels_with_probe(probe, group_mode="by_shank") assert rec_empty_probe.channel_ids.size == 0 # test scaling parameters @@ -427,32 +433,46 @@ def test_json_pickle_equivalence(create_cache_folder): assert np.all(value == data_pickle[key]) -def test_interleaved_probegroups(): - recording = generate_recording(durations=[1.0], num_channels=16) +def test_probes_info_annotation_backward_compat(): + """ + Regression test: SI versions before #4300 stored per-probe metadata in a + 'probes_info' annotation on the recording rather than inside the ProbeGroup. + set_probegroup() must migrate those annotations onto the probe objects and + remove the stale 'probes_info' entry from the recording annotations. + """ + from probeinterface import generate_linear_probe, ProbeGroup + + # Simulate probegroup as read from an old probegroup.json: probes exist but + # have no name/manufacturer annotations (old probeinterface did not write them). + probe_A = generate_linear_probe(num_elec=8, ypitch=20.0) + probe_A.move([0.0, 0.0]) + probe_A.set_device_channel_indices(np.arange(8)) - probe1 = generate_linear_probe(num_elec=8, ypitch=20.0) - probe2_overlap = probe1.copy() + probe_B = generate_linear_probe(num_elec=8, ypitch=20.0) + probe_B.move([500.0, 0.0]) + probe_B.set_device_channel_indices(np.arange(8, 16)) - probegroup_overlap = ProbeGroup() - probegroup_overlap.add_probe(probe1) - probegroup_overlap.add_probe(probe2_overlap) - probegroup_overlap.set_global_device_channel_indices(np.arange(16)) + pg = ProbeGroup() + pg.add_probe(probe_A) + pg.add_probe(probe_B) - # setting overlapping probes should raise an error - with pytest.raises(Exception): - recording.set_probegroup(probegroup_overlap) + rec = NumpyRecording([np.zeros((100, 16), dtype="int16")], sampling_frequency=30000.0) + + # Inject the old-style annotation that SI used to write alongside the probegroup. + rec._annotations["probes_info"] = [ + {"name": "probe_A", "manufacturer": "vendor_X"}, + {"name": "probe_B", "manufacturer": "vendor_Y"}, + ] - probe2 = probe1.copy() - probe2.move([100.0, 100.0]) - probegroup = ProbeGroup() - probegroup.add_probe(probe1) - probegroup.add_probe(probe2) - probegroup.set_global_device_channel_indices(np.random.permutation(16)) + rec.set_probegroup(pg) # new default: in_place=None → always in-place - recording.set_probegroup(probegroup) - probegroup_set = recording.get_probegroup() - # check that the probe group is correctly set, by sorting the device channel indices - assert np.array_equal(probegroup_set.get_global_device_channel_indices()["device_channel_indices"], np.arange(16)) + probes = rec.get_probes() + assert len(probes) == 2 + probe_names = {p.annotations.get("name") for p in probes} + assert probe_names == {"probe_A", "probe_B"}, "Probe names must be migrated from probes_info" + manufacturers = {p.annotations.get("manufacturer") for p in probes} + assert manufacturers == {"vendor_X", "vendor_Y"}, "Manufacturers must be migrated from probes_info" + assert "probes_info" not in rec._annotations, "probes_info annotation must be consumed after migration" def test_rename_channels(): diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 751a03460c..05710b1607 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -142,8 +142,7 @@ def test_BaseSnippets(create_cache_folder): probe.set_device_channel_indices([2, -1, 0]) probe.create_auto_shape() - snippets_p = snippets.set_probe(probe, group_mode="auto") - snippets_p = snippets.set_probe(probe, group_mode="by_probe") + snippets_p = snippets.select_channels_with_probe(probe, group_mode="auto") positions2 = snippets_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index d5ba74cfd9..8936e6a650 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -1,10 +1,23 @@ import numpy as np +from probeinterface import generate_linear_probe from spikeinterface.core import aggregate_channels from spikeinterface.core import generate_recording from spikeinterface.core.testing import check_recordings_equal +def _make_rec_with_named_probe(name, manufacturer, x_shift): + """Helper: single-probe recording with annotated name and manufacturer.""" + probe = generate_linear_probe(num_elec=8, ypitch=20.0) + probe.move([x_shift, 0.0]) + probe.annotate(name=name, manufacturer=manufacturer) + probe.set_device_channel_indices(np.arange(8)) + probe.create_auto_shape() + rec = generate_recording(num_channels=8, durations=[1.0], set_probe=False) + rec.set_probe(probe) + return rec + + def test_channelsaggregationrecording(): num_channels = 3 @@ -262,5 +275,51 @@ def test_channel_aggregation_with_string_dtypes_of_different_size(): assert aggregated_recording.channel_ids.dtype == np.dtype(" +with spikeinterface==0.104.* installed. + +The GH Action workflow probe_backward_compat.yml does this automatically. +Set SI_PROBE_COMPAT_FIXTURES_DIR to point at the fixture directory if running locally. +""" + +import os +import numpy as np +import pytest +from pathlib import Path + +from spikeinterface.core import load + +FIXTURES_DIR = Path(os.environ.get("SI_PROBE_COMPAT_FIXTURES_DIR", "probe_compat_fixtures")) + +pytestmark = pytest.mark.skipif( + not FIXTURES_DIR.exists(), + reason=( + f"Probe compatibility fixtures not found at '{FIXTURES_DIR}'. " + "Run .github/scripts/create_probe_compat_fixtures.py with spikeinterface==0.104.* first, " + "or set SI_PROBE_COMPAT_FIXTURES_DIR to the fixture directory." + ), +) + + +# --------------------------------------------------------------------------- +# Shared assertion helpers +# --------------------------------------------------------------------------- + + +def _check_single_probe(rec): + assert rec.has_probe(), "Recording must have a probe after loading" + assert rec.get_num_channels() == 8 + probes = rec.get_probes() + assert len(probes) == 1 + probe = probes[0] + assert probe.annotations.get("name") == "test_probe" + assert probe.annotations.get("manufacturer") == "test_vendor" + assert list(probe.contact_ids) == [f"e{i}" for i in range(8)] + # After loading, device_channel_indices must be sorted 0..N-1 + assert np.array_equal(probe.device_channel_indices, np.arange(8)) + + +def _check_two_probes(rec): + assert rec.has_probe() + assert rec.get_num_channels() == 16 + probes = rec.get_probes() + assert len(probes) == 2, "Both probes must survive after loading" + probe_names = {p.annotations.get("name") for p in probes} + assert probe_names == {"probe_A", "probe_B"}, "Per-probe names must be preserved" + manufacturers = {p.annotations.get("manufacturer") for p in probes} + assert manufacturers == {"vendor_X", "vendor_Y"}, "Per-probe manufacturers must be preserved" + all_contact_ids = set() + for p in probes: + all_contact_ids.update(p.contact_ids.tolist()) + assert all_contact_ids == {f"a{i}" for i in range(8)} | {f"b{i}" for i in range(8)} + groups = rec.get_property("group") + assert len(np.unique(groups)) == 2, "Each probe must have its own group" + + +def _check_shuffled_probe(rec): + assert rec.has_probe() + assert rec.get_num_channels() == 8 + probe = rec.get_probes()[0] + assert probe.annotations.get("name") == "shuffled_probe" + assert probe.annotations.get("manufacturer") == "shuffle_vendor" + # After the old set_probe sorted contacts by device_channel_indices and + # normalised them, the stored probegroup has dci = 0..7. + assert np.array_equal(probe.device_channel_indices, np.arange(8)) + traces = rec.get_traces(segment_index=0) + assert traces.shape == (1000, 8) + + +# --------------------------------------------------------------------------- +# Binary folder fixtures +# --------------------------------------------------------------------------- + + +def test_single_probe_binary_compat(): + _check_single_probe(load(FIXTURES_DIR / "single_probe_binary")) + + +def test_two_probe_binary_compat(): + _check_two_probes(load(FIXTURES_DIR / "two_probe_binary")) + + +def test_shuffled_probe_binary_compat(): + _check_shuffled_probe(load(FIXTURES_DIR / "shuffled_probe_binary")) + + +# --------------------------------------------------------------------------- +# JSON dump fixtures +# --------------------------------------------------------------------------- + + +def test_single_probe_json_compat(): + _check_single_probe(load(FIXTURES_DIR / "single_probe.json")) + + +def test_two_probe_json_compat(): + _check_two_probes(load(FIXTURES_DIR / "two_probe.json")) + + +def test_shuffled_probe_json_compat(): + _check_shuffled_probe(load(FIXTURES_DIR / "shuffled_probe.json")) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..8b40e3e93f 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -318,7 +318,7 @@ def test_SortingAnalyzer_interleaved_probegroup(dataset): probegroup.add_probe(probe2) probegroup.set_global_device_channel_indices(np.random.permutation(num_channels)) - recording = recording.set_probegroup(probegroup) + recording.set_probegroup(probegroup) sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) # check that locations are correct diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 6e4f6569d3..e224a8f289 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -180,12 +180,15 @@ def __init__( probe_dict = self._root.attrs.get("probegroup", self._root.attrs.get("probe", None)) if probe_dict is not None: probegroup = ProbeGroup.from_dict(probe_dict) - self.set_probegroup(probegroup, in_place=True) + self.set_probegroup(probegroup) # load properties if "properties" in self._root: prop_group = self._root["properties"] for key in prop_group.keys(): + # Skip contact_vector property since it is not used anymore to represent probegroup + if key == "contact_vector": + continue values = self._root["properties"][key] self.set_property(key, values) diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py index 777bdd914b..3a48084ab9 100644 --- a/src/spikeinterface/extractors/bids.py +++ b/src/spikeinterface/extractors/bids.py @@ -40,7 +40,7 @@ def read_bids(folder_path): rec.annotate(bids_name=bids_name) rec.extra_requirements.extend("pandas") probegroup = _read_probe_group(file_path.parent, bids_name, rec.channel_ids) - rec = rec.set_probegroup(probegroup) + rec.set_probegroup(probegroup) recordings.append(rec) elif file_path.suffix == ".nix": @@ -54,7 +54,7 @@ def read_bids(folder_path): rec = read_nix(file_path, stream_id=stream_id) rec.extra_requirements.extend("pandas") probegroup = _read_probe_group(file_path.parent, bids_name, rec.channel_ids) - rec = rec.set_probegroup(probegroup) + rec.set_probegroup(probegroup) recordings.append(rec) return recordings diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 2a53b999e3..891cbaee07 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -102,9 +102,9 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", probe = probeinterface.read_spikeglx(meta_file) if probe.shank_ids is not None: - self.set_probe(probe, in_place=True, group_mode="by_shank") + self.set_probe(probe, group_mode="by_shank") else: - self.set_probe(probe, in_place=True) + self.set_probe(probe) sample_shifts = get_neuropixels_sample_shifts_from_probe(probe) self.set_property("inter_sample_shift", sample_shifts) diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 779c36fa23..8a57e40ec3 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -221,9 +221,9 @@ def __init__( probe = probeinterface.read_spikeglx(meta_file) if probe.shank_ids is not None: - self.set_probe(probe, in_place=True, group_mode="by_shank") + self.set_probe(probe, group_mode="by_shank") else: - self.set_probe(probe, in_place=True) + self.set_probe(probe) # set channel properties # sometimes there are missing metadata files on the IBL side diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index c85e82b574..b3ccb92cbd 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -72,7 +72,7 @@ def __init__( probe = probeinterface.read_3brain(file_path, **probe_kwargs) rows = probe.contact_annotations["row"] cols = probe.contact_annotations["col"] - self.set_probe(probe, in_place=True) + self.set_probe(probe) self.set_property("row", rows) self.set_property("col", cols) diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 5eaa49e6b8..38e65096c2 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -75,7 +75,7 @@ def __init__( rec_name = self.neo_reader.rec_name probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) electrodes = probe.contact_annotations["electrode"] - self.set_probe(probe, in_place=True) + self.set_probe(probe) self.set_property("electrode", electrodes) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name)) diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index 7ca82af01e..d4cbe1b0de 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -55,7 +55,7 @@ def __init__(self, file_path: str | Path, all_annotations: bool = False, use_nam probe = probeinterface.read_mearec(file_path) probe.annotations["mearec_name"] = str(probe.annotations["mearec_name"]) - self.set_probe(probe, in_place=True) + self.set_probe(probe) self.annotate(is_filtered=True) if hasattr(self.neo_reader._recgen, "gain_to_uV"): diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 5dc9220aa5..22ae82b117 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -331,9 +331,9 @@ def __init__( settings_file=settings_file, stream_name=oe_stream_name ) if probe.shank_ids is not None: - self.set_probe(probe, in_place=True, group_mode="by_shank") + self.set_probe(probe, group_mode="by_shank") else: - self.set_probe(probe, in_place=True) + self.set_probe(probe) # get inter-sample shifts based on the probe information and mux channels sample_shifts = get_neuropixels_sample_shifts_from_probe(probe) if sample_shifts is not None: diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index adc50df12f..da4a66e1f5 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -79,7 +79,7 @@ def __init__( if saturation_threshold_uV_probe is not None: saturation_thresholds_uV.append(saturation_threshold_uV_probe) - self.set_probegroup(probegroup, in_place=True) + self.set_probegroup(probegroup) if np.all(sample_shifts != -1): self.set_property("inter_sample_shift", sample_shifts) diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index 60b1a98be8..41c2b77bfc 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -86,9 +86,9 @@ def __init__( probe = probeinterface.read_spikeglx(ap_meta_filename) if probe.shank_ids is not None: - self.set_probe(probe, in_place=True, group_mode="by_shank") + self.set_probe(probe, group_mode="by_shank") else: - self.set_probe(probe, in_place=True) + self.set_probe(probe) # get inter-sample shifts based on the probe information and mux channels sample_shifts = get_neuropixels_sample_shifts_from_probe(probe) diff --git a/src/spikeinterface/extractors/shybridextractors.py b/src/spikeinterface/extractors/shybridextractors.py index ff08c1a3f3..eca7d46724 100644 --- a/src/spikeinterface/extractors/shybridextractors.py +++ b/src/spikeinterface/extractors/shybridextractors.py @@ -82,7 +82,7 @@ def __init__(self, file_path): # load probe file probegroup = probeinterface.read_prb(params["probe"]) - self.set_probegroup(probegroup, in_place=True) + self.set_probegroup(probegroup) self._kwargs = {"file_path": str(Path(file_path).absolute())} self.extra_requirements.extend(["hybridizer", "pyyaml"]) diff --git a/src/spikeinterface/extractors/sinapsrecordingextractors.py b/src/spikeinterface/extractors/sinapsrecordingextractors.py index 132a01f300..f47a83bc47 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractors.py +++ b/src/spikeinterface/extractors/sinapsrecordingextractors.py @@ -84,7 +84,7 @@ def __init__(self, file_path: str | Path, stream_name: str = "filt"): if (stream_name == "filt") | (stream_name == "raw"): probe = get_sinaps_probe(probe_type) if probe is not None: - self.set_probe(probe, in_place=True) + self.set_probe(probe) self._kwargs = {"file_path": str(file_path.absolute()), "stream_name": stream_name} @@ -143,7 +143,7 @@ def __init__(self, file_path: str | Path, stream_name: str = "filt"): # set probe probe = get_sinaps_probe(sinaps_info["probe_type"]) if probe is not None: - self.set_probe(probe, in_place=True) + self.set_probe(probe) self._kwargs = {"file_path": str(Path(file_path).absolute()), "stream_name": stream_name} diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 1800138dae..6996800e27 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -477,7 +477,7 @@ def __init__( ) self.add_recording_segment(recording_segment) - self.set_probe(drifting_templates.probe, in_place=True) + self.set_probe(drifting_templates.probe) # templates are too large, we don't serialize them to JSON self._serializability["json"] = False diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 7975097629..dc83ab9596 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -117,7 +117,7 @@ def compute_monopolar_triangulation( # if enforce_decrease: # enforce_decrease_shells_data( - # wf_data, best_channels[unit_id], enforce_decrease_radial_parents, in_place=True + # wf_data, best_channels[unit_id], enforce_decrease_radial_parents # ) unit_location[i] = solve_monopolar_triangulation(wf_data, local_contact_locations, max_distance_um, optimizer) diff --git a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py index c7c37968d9..c8825831b0 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py @@ -30,7 +30,7 @@ def recording_and_shape(): probe = probeinterface.generate_multi_columns_probe(num_columns=num_cols, num_contact_per_column=num_rows) probe.set_device_channel_indices(np.arange(num_cols * num_rows)) recording = generate_recording(num_channels=num_cols * num_rows, durations=[10.0], sampling_frequency=30000) - recording.set_probe(probe, in_place=True) + recording.set_probe(probe) recording = depth_order(recording) recording = zscore(recording) desired_shape = (num_rows, num_cols) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index 35f398f985..5a0e160f92 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -80,7 +80,7 @@ def test_detect_bad_channels_std_mad(): probe = generate_linear_probe(num_elec=num_channels) probe.set_device_channel_indices(np.arange(num_channels)) - rec.set_probe(probe, in_place=True) + rec.set_probe(probe) bad_channels_std, bad_labels_std = detect_bad_channels(rec, method="std") bad_channels_mad, bad_labels_mad = detect_bad_channels(rec, method="std") @@ -125,7 +125,7 @@ def test_detect_bad_channels_extremes(outside_channels_location): probe = generate_linear_probe(num_elec=num_channels) probe.set_device_channel_indices(np.arange(num_channels)) - rec.set_probe(probe, in_place=True) + rec.set_probe(probe) bad_channel_ids, bad_labels = detect_bad_channels( rec, method="coherence+psd", outside_channels_location=outside_channels_location diff --git a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py index bfa4d3d9ae..89e8e36cf8 100644 --- a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py @@ -118,7 +118,7 @@ def test_highpass_spatial_filter_with_dead_channels(): rec_with_dead = NumpyRecording( traces_list=[traces], sampling_frequency=rec.sampling_frequency, channel_ids=rec.channel_ids ) - rec_with_dead.set_probe(rec.get_probe(), in_place=True) + rec_with_dead.set_probe(rec.get_probe()) filtered = spre.highpass_spatial_filter(rec_with_dead, n_channel_pad=2) result = filtered.get_traces() assert result.shape == traces.shape diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index 1294b57a91..c79605a110 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -132,7 +132,7 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan new_positions = probe.contact_positions.copy() new_positions[:, 0] = x_new # column 0 is x recording._probegroup.probes[0]._contact_positions = new_positions - recording.set_probe(probe, in_place=True) + recording.set_probe(probe) # generate random bad channel locations bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False) @@ -180,7 +180,7 @@ def test_output_values(): probe = pi.Probe(ndim=2) probe.set_contacts(positions=probe_locs) probe.set_device_channel_indices(np.arange(len(probe_locs))) - recording.set_probe(probe, in_place=True) + recording.set_probe(probe) # Run interpolation in SI and check the interpolated channel # 0 is a linear combination of other channels diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 80d79171ce..5698e0e142 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -410,7 +410,7 @@ def __init__( channel_indices = recording.ids_to_indices(channel_ids) probegroup_sliced = probegroup.get_slice(channel_indices) probegroup_sliced.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - self.set_probegroup(probegroup_sliced, in_place=True) + self.set_probegroup(probegroup_sliced) # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below diff --git a/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py b/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py index 8840a5a00d..942074ef31 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py +++ b/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py @@ -97,9 +97,7 @@ def compute(self, traces, peaks, waveforms): wf_data = np.abs(wf[self.nbefore]) if self.enforce_decrease_radial_parents is not None: - enforce_decrease_shells_data( - wf_data, peak["channel_index"], self.enforce_decrease_radial_parents, in_place=True - ) + enforce_decrease_shells_data(wf_data, peak["channel_index"], self.enforce_decrease_radial_parents) peak_locations[i] = solve_monopolar_triangulation( wf_data, local_contact_locations, self.max_distance_um, self.optimizer From b907d738b6fa0a5e93926e1fd198fbfb3a0ee5b1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Jun 2026 17:09:13 +0200 Subject: [PATCH 13/19] test: fix backward compatibility test --- .github/scripts/create_probe_compat_fixtures.py | 8 ++++---- .../core/tests/test_probe_backward_compat.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/scripts/create_probe_compat_fixtures.py b/.github/scripts/create_probe_compat_fixtures.py index e41f7c7804..20ab8b83c3 100644 --- a/.github/scripts/create_probe_compat_fixtures.py +++ b/.github/scripts/create_probe_compat_fixtures.py @@ -42,7 +42,7 @@ rec_single = rec_single.set_probe(probe) # old API: in_place=False, returns new recording rec_single.save(folder=str(OUTPUT_DIR / "single_probe_binary")) -rec_single.dump_to_json(str(OUTPUT_DIR / "single_probe.json"), relative_to=None) +rec_single.save(str(OUTPUT_DIR / "single_probe.zarr"), format="zarr") # ----------------------------------------------------------------------- # Fixture 2: two probes with per-probe name/manufacturer @@ -69,10 +69,10 @@ n_total = n_A + n_B traces2 = np.arange(1000 * n_total, dtype="int16").reshape(1000, n_total) rec_two = NumpyRecording([traces2], sampling_frequency=30000.0) -rec_two = rec_two.set_probegroup(pg) # old API: in_place=False, returns new recording +rec_two.set_probegroup(pg) # old API: in_place=False, returns new recording rec_two.save(folder=str(OUTPUT_DIR / "two_probe_binary")) -rec_two.dump_to_json(str(OUTPUT_DIR / "two_probe.json"), relative_to=None) +rec_two.save(str(OUTPUT_DIR / "two_probe.zarr"), format="zarr") # ----------------------------------------------------------------------- # Fixture 3: probe with shuffled device_channel_indices @@ -91,6 +91,6 @@ rec_sh = rec_sh.set_probe(probe_sh) # old API rec_sh.save(folder=str(OUTPUT_DIR / "shuffled_probe_binary")) -rec_sh.dump_to_json(str(OUTPUT_DIR / "shuffled_probe.json"), relative_to=None) +rec_sh.save(str(OUTPUT_DIR / "shuffled_probe.zarr"), format="zarr") print(f"Fixtures written to: {OUTPUT_DIR.resolve()}") diff --git a/src/spikeinterface/core/tests/test_probe_backward_compat.py b/src/spikeinterface/core/tests/test_probe_backward_compat.py index a7164e5597..647fdaf65f 100644 --- a/src/spikeinterface/core/tests/test_probe_backward_compat.py +++ b/src/spikeinterface/core/tests/test_probe_backward_compat.py @@ -99,13 +99,13 @@ def test_shuffled_probe_binary_compat(): # --------------------------------------------------------------------------- -def test_single_probe_json_compat(): - _check_single_probe(load(FIXTURES_DIR / "single_probe.json")) +def test_single_probe_zarr_compat(): + _check_single_probe(load(FIXTURES_DIR / "single_probe.zarr")) -def test_two_probe_json_compat(): - _check_two_probes(load(FIXTURES_DIR / "two_probe.json")) +def test_two_probe_zarr_compat(): + _check_two_probes(load(FIXTURES_DIR / "two_probe.zarr")) -def test_shuffled_probe_json_compat(): - _check_shuffled_probe(load(FIXTURES_DIR / "shuffled_probe.json")) +def test_shuffled_probe_zarr_compat(): + _check_shuffled_probe(load(FIXTURES_DIR / "shuffled_probe.zarr")) From 68735fec67df1c8bf6c1c922ed063c6adf3aaaba Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Jun 2026 17:10:28 +0200 Subject: [PATCH 14/19] oups --- .github/scripts/create_probe_compat_fixtures.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/scripts/create_probe_compat_fixtures.py b/.github/scripts/create_probe_compat_fixtures.py index 20ab8b83c3..a2996795b7 100644 --- a/.github/scripts/create_probe_compat_fixtures.py +++ b/.github/scripts/create_probe_compat_fixtures.py @@ -42,7 +42,7 @@ rec_single = rec_single.set_probe(probe) # old API: in_place=False, returns new recording rec_single.save(folder=str(OUTPUT_DIR / "single_probe_binary")) -rec_single.save(str(OUTPUT_DIR / "single_probe.zarr"), format="zarr") +rec_single.save(folder=str(OUTPUT_DIR / "single_probe.zarr"), format="zarr") # ----------------------------------------------------------------------- # Fixture 2: two probes with per-probe name/manufacturer @@ -72,7 +72,7 @@ rec_two.set_probegroup(pg) # old API: in_place=False, returns new recording rec_two.save(folder=str(OUTPUT_DIR / "two_probe_binary")) -rec_two.save(str(OUTPUT_DIR / "two_probe.zarr"), format="zarr") +rec_two.save(folder=str(OUTPUT_DIR / "two_probe.zarr"), format="zarr") # ----------------------------------------------------------------------- # Fixture 3: probe with shuffled device_channel_indices @@ -91,6 +91,6 @@ rec_sh = rec_sh.set_probe(probe_sh) # old API rec_sh.save(folder=str(OUTPUT_DIR / "shuffled_probe_binary")) -rec_sh.save(str(OUTPUT_DIR / "shuffled_probe.zarr"), format="zarr") +rec_sh.save(folder=str(OUTPUT_DIR / "shuffled_probe.zarr"), format="zarr") print(f"Fixtures written to: {OUTPUT_DIR.resolve()}") From 0e3a0bd702dd9663a855bd754b232397951b70a3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Jun 2026 17:27:00 +0200 Subject: [PATCH 15/19] fix: most tests --- src/spikeinterface/core/baserecordingsnippets.py | 8 ++++++-- src/spikeinterface/core/channelsaggregationrecording.py | 7 +++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 0bb969ae69..83a018bf18 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -264,8 +264,12 @@ def _get_probegroup_based_on_device_channel_indices(self, probegroup: ProbeGroup f"The given probes have {len(device_channel_indices) - len(keep_indices)} unconnected contacts: " "they will be removed" ) - probegroup = probegroup.get_slice(keep_indices) - device_channel_indices = device_channel_indices[keep_indices] + if len(keep_indices) == 0: + probegorup = ProbeGroup() # empty probegroup + device_channel_indices = np.array([], dtype="int64") + else: + probegroup = probegroup.get_slice(keep_indices) + device_channel_indices = device_channel_indices[keep_indices] if len(device_channel_indices) > 0: # Check consistency of device_channel_indices with the recording channel count diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 8371a9a5ee..28f37201b5 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -106,11 +106,10 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record ), "Contact positions are not unique! Cannot aggregate recordings." # Now make a new probegroup with all probes and set global device channel indices - all_probes = [] - for probegroup in all_probegroups: - all_probes.extend([p.copy() for p in probegroup.probes]) probegroup_agg = ProbeGroup() - probegroup_agg.probes = all_probes + for probegroup in all_probegroups: + for probe in probegroup.probes: + probegroup_agg.add_probe(probe.copy()) probegroup_agg.set_global_device_channel_indices(np.arange(num_all_channels)) self.set_probegroup(probegroup_agg) From 6c1ff7970be313356f937482d991d81af3b02277 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Jun 2026 17:38:54 +0200 Subject: [PATCH 16/19] test: fix bacward compat tests --- .github/scripts/create_probe_compat_fixtures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/create_probe_compat_fixtures.py b/.github/scripts/create_probe_compat_fixtures.py index a2996795b7..83fafacff8 100644 --- a/.github/scripts/create_probe_compat_fixtures.py +++ b/.github/scripts/create_probe_compat_fixtures.py @@ -69,7 +69,7 @@ n_total = n_A + n_B traces2 = np.arange(1000 * n_total, dtype="int16").reshape(1000, n_total) rec_two = NumpyRecording([traces2], sampling_frequency=30000.0) -rec_two.set_probegroup(pg) # old API: in_place=False, returns new recording +rec_two = rec_two.set_probegroup(pg) # old API: in_place=False, returns new recording rec_two.save(folder=str(OUTPUT_DIR / "two_probe_binary")) rec_two.save(folder=str(OUTPUT_DIR / "two_probe.zarr"), format="zarr") From c7491899d03a7c2a9d3c4c011af3db7535eb9063 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Jun 2026 18:41:33 +0200 Subject: [PATCH 17/19] docs: fix doc tests --- doc/api.rst | 8 ++++++++ doc/get_started/quickstart.rst | 5 +++-- doc/modules/core.rst | 18 ++++++++++++------ .../forhowto/plot_working_with_tetrodes.py | 6 +++--- examples/get_started/quickstart.py | 4 ++-- .../core/plot_1_recording_extractor.py | 2 +- .../tutorials/core/plot_3_handle_probe_info.py | 10 +++++----- 7 files changed, 34 insertions(+), 19 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index fc55017606..1bc8156aef 100755 --- a/doc/api.rst +++ b/doc/api.rst @@ -16,6 +16,12 @@ spikeinterface.core .. automethod:: BaseRecording.dump_to_json .. automethod:: BaseRecording.dump_to_pickle .. automethod:: BaseRecording.remove_channels + .. automethod:: BaseRecording.set_probe + .. automethod:: BaseRecording.set_probegroup + .. automethod:: BaseRecording.reset_probe + .. automethod:: BaseRecording.select_channels_with_probe + .. automethod:: BaseRecording.select_channels_with_probegroup + .. automethod:: BaseRecording.split_by .. autoclass:: BaseSorting :members: .. automethod:: BaseSorting.save @@ -25,6 +31,8 @@ spikeinterface.core .. automethod:: BaseSorting.dump .. automethod:: BaseSorting.dump_to_json .. automethod:: BaseSorting.dump_to_pickle + .. automethod:: BaseSorting.split_by + .. automethod:: BaseSorting.register_recording .. autoclass:: BaseSnippets :members: .. automethod:: BaseSnippets.save diff --git a/doc/get_started/quickstart.rst b/doc/get_started/quickstart.rst index b5d3c2a985..3fbd113ba7 100644 --- a/doc/get_started/quickstart.rst +++ b/doc/get_started/quickstart.rst @@ -198,8 +198,9 @@ to set it *manually*. If your recording does not have a ``Probe``, you can set it using -``set_probe``. Note: ``set_probe`` creates a copy of the recording with -the new probe, rather than modifying the existing recording in place. +``set_probe``. Note: ``set_probe`` modifies the recording in place. To +get a new recording object with a subset of channels attached to a probe, +use ``select_channels_with_probe``. There is more information `here `__. diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 681542368b..6ea3d25eb6 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -522,12 +522,18 @@ The probe has 4 shanks, which can be loaded as separate groups (and spike sorted # add wiring probe.wiring_to_device('ASSY-156>RHD2164') - # set probe - recording_w_probe = recording.set_probe(probe) - # set probe with group info and return a new recording object - recording_w_probe = recording.set_probe(probe, group_mode="by_shank") - # set probe in place, ie, modify the current recording - recording.set_probe(probe, group_mode="by_shank", in_place=True) + # set probe (modifies the recording in place) + recording.set_probe(probe) + # set probe with group info derived from shank ids (in place) + recording.set_probe(probe, group_mode="by_shank") + + # to get a *new* recording without modifying the original, use select_channels_with_probe + recording_w_probe = recording.select_channels_with_probe(probe) + recording_w_probe = recording.select_channels_with_probe(probe, group_mode="by_shank") + + # multi-probe recordings use set_probegroup / select_channels_with_probegroup + recording.set_probegroup(probegroup) + recording_w_probegroup = recording.select_channels_with_probegroup(probegroup) # retrieve probe probe_from_recording = recording.get_probe() diff --git a/examples/forhowto/plot_working_with_tetrodes.py b/examples/forhowto/plot_working_with_tetrodes.py index 0c652a5186..547e9deae1 100644 --- a/examples/forhowto/plot_working_with_tetrodes.py +++ b/examples/forhowto/plot_working_with_tetrodes.py @@ -62,15 +62,15 @@ # We can now attach the :code:`tetrode_group` to our recording. To check if this worked, we'll # plot the probe map -recording_with_probe = recording.set_probegroup(tetrode_group) -plot_probe_map(recording_with_probe) +recording.set_probegroup(tetrode_group) +plot_probe_map(recording) ############################################################################## # Looks good! Now that the recording is aware of the probe geometry, we can # begin a standard spike sorting pipeline. First, we can apply preprocessing. # Note that we apply this preprocessing on the entire bundle of tetrodes. -preprocessed_recording = spre.bandpass_filter(recording_with_probe) +preprocessed_recording = spre.bandpass_filter(recording) ############################################################################## # WARNING: a very common preprocessing step is to apply a common median diff --git a/examples/get_started/quickstart.py b/examples/get_started/quickstart.py index 2481f8569f..75d5c8d63a 100644 --- a/examples/get_started/quickstart.py +++ b/examples/get_started/quickstart.py @@ -137,8 +137,8 @@ # - # If your recording does not have a `Probe`, you can set it using `set_probe`. -# Note: `set_probe` creates a copy of the recording with the new probe, -# rather than modifying the existing recording in place. +# Note: `set_probe` modifies the recording in place. To get a new recording +# object with a subset of channels attached to a probe, use `select_channels_with_probe`. # There is more information [here](https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_3_handle_probe_info.html). # Using the `spikeinterface.preprocessing` module, you can perform preprocessing on the recordings. diff --git a/examples/tutorials/core/plot_1_recording_extractor.py b/examples/tutorials/core/plot_1_recording_extractor.py index e3bfda5855..477ba165b6 100644 --- a/examples/tutorials/core/plot_1_recording_extractor.py +++ b/examples/tutorials/core/plot_1_recording_extractor.py @@ -70,7 +70,7 @@ probe.set_device_channel_indices(np.arange(7)) # then we need to actually set the probe to the recording object -recording = recording.set_probe(probe) +recording.set_probe(probe) plot_probe(probe) ############################################################################## diff --git a/examples/tutorials/core/plot_3_handle_probe_info.py b/examples/tutorials/core/plot_3_handle_probe_info.py index deff58ebb7..28d2af655a 100644 --- a/examples/tutorials/core/plot_3_handle_probe_info.py +++ b/examples/tutorials/core/plot_3_handle_probe_info.py @@ -43,8 +43,8 @@ print(other_probe) other_probe.set_device_channel_indices(np.arange(32)) -recording_2_shanks = recording.set_probe(other_probe, group_mode="by_shank") -plot_probe(recording_2_shanks.get_probe()) +recording.set_probe(other_probe, group_mode="by_shank") +plot_probe(recording.get_probe()) ############################################################################### # Now let's check what we have loaded. The :code:`group_mode='by_shank'` automatically @@ -53,11 +53,11 @@ # We can access this information either as a dict with :code:`outputs='dict'` (default) # or as a list of recordings with :code:`outputs='list'`. -print(recording_2_shanks) -print(f'\nGroup Property: {recording_2_shanks.get_property("group")}\n') +print(recording) +print(f'\nGroup Property: {recording.get_property("group")}\n') # Here we split as a dict -sub_recording_dict = recording_2_shanks.split_by(property="group", outputs='dict') +sub_recording_dict = recording.split_by(property="group", outputs='dict') # Then we can pull out the individual sub-recordings sub_rec0 = sub_recording_dict[0] From 7cd0a751705b8bacdf293269d38fef1736b573c0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 2 Jul 2026 11:04:42 +0200 Subject: [PATCH 18/19] feat: use progegroup instead of contact_vector --- doc/api.rst | 2 - doc/get_started/quickstart.rst | 5 +- doc/modules/core.rst | 18 +-- examples/get_started/quickstart.py | 4 +- .../core/baserecordingsnippets.py | 150 ++++++------------ src/spikeinterface/core/basesnippets.py | 2 +- .../core/channelsaggregationrecording.py | 2 +- src/spikeinterface/core/channelslice.py | 4 +- src/spikeinterface/core/generate.py | 6 +- .../core/tests/test_baserecording.py | 18 +-- .../core/tests/test_basesnippets.py | 2 +- .../test_channelsaggregationrecording.py | 2 +- .../core/tests/test_channelslicerecording.py | 4 +- .../core/tests/test_sortinganalyzer.py | 2 +- src/spikeinterface/core/zarrextractors.py | 2 +- src/spikeinterface/extractors/bids.py | 4 +- src/spikeinterface/extractors/cbin_ibl.py | 4 +- .../extractors/iblextractors.py | 4 +- .../extractors/neoextractors/biocam.py | 2 +- .../extractors/neoextractors/maxwell.py | 2 +- .../extractors/neoextractors/mearec.py | 2 +- .../extractors/neoextractors/openephys.py | 4 +- .../extractors/neoextractors/spikegadgets.py | 2 +- .../extractors/neoextractors/spikeglx.py | 4 +- .../extractors/shybridextractors.py | 2 +- .../extractors/sinapsrecordingextractors.py | 4 +- src/spikeinterface/generation/drift_tools.py | 2 +- .../tests/test_deepinterpolation.py | 2 +- .../tests/test_detect_bad_channels.py | 4 +- .../tests/test_highpass_spatial_filter.py | 2 +- .../tests/test_interpolate_bad_channels.py | 4 +- .../motion/motion_interpolation.py | 2 +- 32 files changed, 107 insertions(+), 166 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 1bc8156aef..6aca96c4ec 100755 --- a/doc/api.rst +++ b/doc/api.rst @@ -19,8 +19,6 @@ spikeinterface.core .. automethod:: BaseRecording.set_probe .. automethod:: BaseRecording.set_probegroup .. automethod:: BaseRecording.reset_probe - .. automethod:: BaseRecording.select_channels_with_probe - .. automethod:: BaseRecording.select_channels_with_probegroup .. automethod:: BaseRecording.split_by .. autoclass:: BaseSorting :members: diff --git a/doc/get_started/quickstart.rst b/doc/get_started/quickstart.rst index 3fbd113ba7..b5d3c2a985 100644 --- a/doc/get_started/quickstart.rst +++ b/doc/get_started/quickstart.rst @@ -198,9 +198,8 @@ to set it *manually*. If your recording does not have a ``Probe``, you can set it using -``set_probe``. Note: ``set_probe`` modifies the recording in place. To -get a new recording object with a subset of channels attached to a probe, -use ``select_channels_with_probe``. +``set_probe``. Note: ``set_probe`` creates a copy of the recording with +the new probe, rather than modifying the existing recording in place. There is more information `here `__. diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 6ea3d25eb6..681542368b 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -522,18 +522,12 @@ The probe has 4 shanks, which can be loaded as separate groups (and spike sorted # add wiring probe.wiring_to_device('ASSY-156>RHD2164') - # set probe (modifies the recording in place) - recording.set_probe(probe) - # set probe with group info derived from shank ids (in place) - recording.set_probe(probe, group_mode="by_shank") - - # to get a *new* recording without modifying the original, use select_channels_with_probe - recording_w_probe = recording.select_channels_with_probe(probe) - recording_w_probe = recording.select_channels_with_probe(probe, group_mode="by_shank") - - # multi-probe recordings use set_probegroup / select_channels_with_probegroup - recording.set_probegroup(probegroup) - recording_w_probegroup = recording.select_channels_with_probegroup(probegroup) + # set probe + recording_w_probe = recording.set_probe(probe) + # set probe with group info and return a new recording object + recording_w_probe = recording.set_probe(probe, group_mode="by_shank") + # set probe in place, ie, modify the current recording + recording.set_probe(probe, group_mode="by_shank", in_place=True) # retrieve probe probe_from_recording = recording.get_probe() diff --git a/examples/get_started/quickstart.py b/examples/get_started/quickstart.py index 75d5c8d63a..2481f8569f 100644 --- a/examples/get_started/quickstart.py +++ b/examples/get_started/quickstart.py @@ -137,8 +137,8 @@ # - # If your recording does not have a `Probe`, you can set it using `set_probe`. -# Note: `set_probe` modifies the recording in place. To get a new recording -# object with a subset of channels attached to a probe, use `select_channels_with_probe`. +# Note: `set_probe` creates a copy of the recording with the new probe, +# rather than modifying the existing recording in place. # There is more information [here](https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_3_handle_probe_info.html). # Using the `spikeinterface.preprocessing` module, you can perform preprocessing on the recordings. diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 83a018bf18..43ed20dc34 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -80,11 +80,15 @@ def set_probe( self, probe: Probe, group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto", - in_place: bool | None = None, - ) -> None: + in_place: bool = False, + ) -> "BaseRecordingSnippets": """ Attach a Probe object to a recording. + For this Probe.device_channel_indices is used to link contacts to recording channels. + If some contacts of the Probe are not connected (device_channel_indices=-1) + then the recording is "sliced" and only connected channels are kept. + Parameters ---------- probe: Probe @@ -93,11 +97,14 @@ def set_probe( How to add the "group" property. "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. - in_place: (deprecated) bool | None, default: None - Deprecated argument to indicate whether to modify the recording in place - or return a new recording. The function is always in place now. - Use the `recording.select_channels_with_probegroup()` method instead of `in_place=False` - to return a new recording with a channel selection to match the probe/probegroup. + in_place: bool, default: False + If False, a new recording (view or channel selection) is returned. + If True, the recording is modified in place, which requires all channels to be connected. + + Returns + ------- + sub_recording: BaseRecording + A view of the recording (ChannelSlice or clone or itself) Notes ----- @@ -106,130 +113,73 @@ def set_probe( assert isinstance(probe, Probe), "The input must be a Probe object" probegroup = ProbeGroup() probegroup.add_probe(probe) - # TODO: remove return in 0.106.0 after removing in_place argument return self.set_probegroup(probegroup, group_mode=group_mode, in_place=in_place) def set_probegroup( self, probegroup: ProbeGroup | dict, group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto", - in_place: bool | None = None, - ) -> None: + in_place: bool = False, + ) -> "BaseRecordingSnippets": """ Attach a ProbeGroup or dict to a recording. For this Probe.device_channel_indices is used to link contacts to recording channels. - After removing unconnected contacts, the number of connected contacts must match the - number of channels in the recording. If this is not the case, use the `recording.select_with_probegroup()` - method instead to return a new recording with a channel selection to match the probe/probegroup. + If some contacts of the Probe are not connected (device_channel_indices=-1) + then the recording is "sliced" and only connected channels are kept. Note: The probe order of the probegroup is not kept. Channel ids are re-ordered to match the channel_ids of the recording. Parameters ---------- - probe_or_probegroup: ProbeGroup, or dict + probegroup: ProbeGroup, or dict The probe(s) to be attached to the recording group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" How to add the "group" property. "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. - in_place: (deprecated) bool | None, default: None - Deprecated argument to indicate whether to modify the recording in place - or return a new recording. The function is always in place now. - Use the `recording.select_channels_with_probegroup()` method instead of `in_place=False` - to return a new recording with a channel selection to match the probe/probegroup. - """ - if in_place is not None: - warnings.warn( - "The 'in_place' argument is deprecated and will be removed in version 0.106.0. " - "The `set_probe/probegroup()` are always in place and assume that the probe/probegroup has the " - "same number of connected contacts as the number of channels in the recording. " - "Use the `recording.select_channels_with_probegroup()` method instead to return a new recording with " - "a channel selection to match the probe/probegroup.", - DeprecationWarning, - stacklevel=2, - ) - if not in_place: - return self.select_channels_with_probegroup(probegroup, group_mode=group_mode) - - # Handle several input possibilities: Probe or dict - if isinstance(probegroup, dict): - probegroup = ProbeGroup.from_dict(probegroup) - - probegroup_sorted = self._get_probegroup_based_on_device_channel_indices(probegroup) - - if probegroup_sorted.get_contact_count() != self.get_num_channels(): - raise ValueError( - "The probe/probegroup must have the same number of connected contacts " - f"as the number of channels as the recording, but the probe has {probegroup.get_contact_count()} " - f"connected channels and the recording has {self.get_num_channels()} channels. " - "Use the `recording.select_channels_with_probegroup()` method instead to return a new recording with " - "a channel selection to match the probe/probegroup." - ) - probegroup_sorted.set_global_device_channel_indices(np.arange(probegroup_sorted.get_contact_count())) - self._probegroup = probegroup_sorted - - # Handle and set channel groups - _set_group_property_based_on_probegroup(self, probegroup_sorted, group_mode=group_mode) - - def select_channels_with_probe( - self, probe: Probe, group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto" - ) -> "BaseRecordingSnippets": - """ - Returns a new recording with channels selected based on the probe. - - Parameters - ---------- - probe: Probe - The probe to be used for channel selection - group_mode: "auto" | "by_probe" | "by_shank" | - "by_side", default: "auto" - How to add the "group" property. - "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. + in_place: bool, default: False + If False, a new recording (view or channel selection) is returned. + If True, the recording is modified in place, which requires all channels to be connected. Returns ------- sub_recording: BaseRecording A view of the recording (ChannelSlice or clone or itself) """ - assert isinstance(probe, Probe), "The input must be a Probe object" - probegroup = ProbeGroup() - probegroup.add_probe(probe) - return self.select_channels_with_probegroup(probegroup, group_mode=group_mode) - - def select_channels_with_probegroup( - self, probegroup: ProbeGroup, group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto" - ) -> "BaseRecordingSnippets": - """ - Selects channels based on the given ProbeGroup and returns a new recording with the selected channels. - - Parameters - ---------- - probegroup: ProbeGroup - The probegroup to be used for channel selection - group_mode: "auto" | "by_probe" | "by_shank" | - "by_side", default: "auto" - How to add the "group" property. - "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks - and two sides are present. + # Handle several input possibilities: ProbeGroup or dict + if isinstance(probegroup, dict): + probegroup = ProbeGroup.from_dict(probegroup) - Returns - ------- - sub_recording: BaseRecording - A view of the recording (ChannelSlice or clone or itself) - """ probegroup_sorted = self._get_probegroup_based_on_device_channel_indices(probegroup) + if probegroup_sorted.get_contact_count() > 0: sorted_dci = probegroup_sorted.get_global_device_channel_indices()["device_channel_indices"] new_channel_ids = self.channel_ids[sorted_dci] - probegroup_sorted.set_global_device_channel_indices(np.arange(len(new_channel_ids))) - if np.array_equal(new_channel_ids, self.channel_ids): + else: + new_channel_ids = self.channel_ids[[]] # empty selection + + # create recording: itself (in place), clone or channel slice + if in_place: + if not np.array_equal(new_channel_ids, self.get_channel_ids()): + raise ValueError( + "set_probe(in_place=True) requires the probe/probegroup to have the same number of connected " + "contacts as the number of channels in the recording. Use in_place=False to return a new " + "recording with a channel selection to match the probe/probegroup." + ) + sub_recording = self + else: + if np.array_equal(new_channel_ids, self.get_channel_ids()): sub_recording = self.clone() else: sub_recording = self.select_channels(new_channel_ids) + + if probegroup_sorted.get_contact_count() > 0: + probegroup_sorted.set_global_device_channel_indices(np.arange(probegroup_sorted.get_contact_count())) sub_recording._probegroup = probegroup_sorted + # Handle and set channel groups _set_group_property_based_on_probegroup(sub_recording, probegroup_sorted, group_mode=group_mode) else: - sub_recording = self.select_channels([]) # empty recording sub_recording._probegroup = ProbeGroup() # empty probegroup + return sub_recording def _get_probegroup_based_on_device_channel_indices(self, probegroup: ProbeGroup) -> ProbeGroup: @@ -333,10 +283,10 @@ def _extra_metadata_from_folder(self, folder): legacy_probe_file = folder / "probe.json" if probe_file.is_file(): probegroup = read_probeinterface(probe_file) - self.set_probegroup(probegroup) + self.set_probegroup(probegroup, in_place=True) elif legacy_probe_file.is_file(): probegroup = read_probeinterface(legacy_probe_file) - self.set_probegroup(probegroup) + self.set_probegroup(probegroup, in_place=True) # remove "contact_vector" property if present as it is not needed anymore if "contact_vector" in self.get_property_keys(): @@ -352,7 +302,7 @@ def _extra_metadata_from_dict(self, dump_dict): # load probe if "probegroup" in dump_dict: probegroup = dump_dict["probegroup"] - self.set_probegroup(probegroup) + self.set_probegroup(probegroup, in_place=True) def _extra_metadata_to_dict(self, dump_dict): # save probe @@ -412,7 +362,7 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params probe = self.create_dummy_probe_from_locations( np.array(locations), shape=shape, shape_params=shape_params, axes=axes ) - self.set_probe(probe) + self.set_probe(probe, in_place=True) def set_channel_locations(self, locations, channel_ids=None): warnings.warn( @@ -510,7 +460,7 @@ def planarize(self, axes: str = "xy"): probe2d = self.get_probe().to_2d(axes=axes) recording2d = self.clone() - recording2d.set_probe(probe2d) + recording2d.set_probe(probe2d, in_place=True) return recording2d diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index a1b0563186..39e8c51225 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -261,7 +261,7 @@ def _save(self, format="npy", **save_kwargs): if self.has_probe(): probegroup = self.get_probegroup() - cached.set_probegroup(probegroup) + cached.set_probegroup(probegroup, in_place=True) return cached diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 28f37201b5..746f35eb47 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -111,7 +111,7 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record for probe in probegroup.probes: probegroup_agg.add_probe(probe.copy()) probegroup_agg.set_global_device_channel_indices(np.arange(num_all_channels)) - self.set_probegroup(probegroup_agg) + self.set_probegroup(probegroup_agg, in_place=True) # finally add segments, we need a channel mapping ch_id = 0 diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 8669e3c90c..1df401f8db 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -66,7 +66,7 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parent_probegroup = self._parent.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) sliced_probegroup.set_global_device_channel_indices(np.arange(len(self._channel_ids))) - self.set_probegroup(sliced_probegroup) + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { @@ -157,7 +157,7 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): parent_probegroup = self._parent_snippets.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) sliced_probegroup.set_global_device_channel_indices(np.arange(len(self._channel_ids))) - self.set_probegroup(sliced_probegroup) + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 9ca5cb2df9..d3a38931a4 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -86,7 +86,7 @@ def generate_recording( if ndim == 3: probe = probe.to_3d() probe.set_device_channel_indices(np.arange(num_channels)) - recording.set_probe(probe) + recording.set_probe(probe, in_place=True) recording.name = "SyntheticRecording" @@ -675,7 +675,7 @@ def generate_snippets( if set_probe: probe = recording.get_probe() - snippets.set_probe(probe) + snippets.set_probe(probe, in_place=True) return snippets, sorting @@ -2462,7 +2462,7 @@ def generate_ground_truth_recording( upsample_vector=upsample_vector, ) recording.annotate(is_filtered=True) - recording.set_probe(probe) + recording.set_probe(probe, in_place=True) recording.set_channel_gains(1.0) recording.set_channel_offsets(0.0) diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index e36150c43e..5f7786e5ac 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -203,15 +203,15 @@ def test_BaseRecording(create_cache_folder): ) probe.create_auto_shape() - rec_p = rec.select_channels_with_probe(probe, group_mode="auto") + rec_p = rec.set_probe(probe, group_mode="auto") positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) - rec_p = rec.select_channels_with_probe(probe, group_mode="by_shank") + rec_p = rec.set_probe(probe, group_mode="by_shank") positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) - rec_p = rec.select_channels_with_probe(probe, group_mode="by_probe") + rec_p = rec.set_probe(probe, group_mode="by_probe") positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) @@ -254,13 +254,13 @@ def test_BaseRecording(create_cache_folder): probe.create_auto_shape() traces = np.zeros((1000, 12), dtype="int16") rec = NumpyRecording([traces], 30000.0) - rec1 = rec.select_channels_with_probe(probe, group_mode="auto") + rec1 = rec.set_probe(probe, group_mode="auto") assert np.unique(rec1.get_property("group")).size == 4 - rec2 = rec.select_channels_with_probe(probe, group_mode="by_probe") + rec2 = rec.set_probe(probe, group_mode="by_probe") assert np.unique(rec2.get_property("group")).size == 1 - rec3 = rec.select_channels_with_probe(probe, group_mode="by_shank") + rec3 = rec.set_probe(probe, group_mode="by_shank") assert np.unique(rec3.get_property("group")).size == 2 - rec4 = rec.select_channels_with_probe(probe, group_mode="by_side") + rec4 = rec.set_probe(probe, group_mode="by_side") assert np.unique(rec4.get_property("group")).size == 4 # set unconnected probe @@ -270,7 +270,7 @@ def test_BaseRecording(create_cache_folder): probe.set_device_channel_indices([-1, -1, -1]) probe.create_auto_shape() - rec_empty_probe = rec.select_channels_with_probe(probe, group_mode="by_shank") + rec_empty_probe = rec.set_probe(probe, group_mode="by_shank") assert rec_empty_probe.channel_ids.size == 0 # test scaling parameters @@ -464,7 +464,7 @@ def test_probes_info_annotation_backward_compat(): {"name": "probe_B", "manufacturer": "vendor_Y"}, ] - rec.set_probegroup(pg) # new default: in_place=None → always in-place + rec = rec.set_probegroup(pg) probes = rec.get_probes() assert len(probes) == 2 diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 05710b1607..e4132870d6 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -142,7 +142,7 @@ def test_BaseSnippets(create_cache_folder): probe.set_device_channel_indices([2, -1, 0]) probe.create_auto_shape() - snippets_p = snippets.select_channels_with_probe(probe, group_mode="auto") + snippets_p = snippets.set_probe(probe, group_mode="auto") positions2 = snippets_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index 8936e6a650..e437eeec9b 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -14,7 +14,7 @@ def _make_rec_with_named_probe(name, manufacturer, x_shift): probe.set_device_channel_indices(np.arange(8)) probe.create_auto_shape() rec = generate_recording(num_channels=8, durations=[1.0], set_probe=False) - rec.set_probe(probe) + rec.set_probe(probe, in_place=True) return rec diff --git a/src/spikeinterface/core/tests/test_channelslicerecording.py b/src/spikeinterface/core/tests/test_channelslicerecording.py index c563888e07..bcf63c2083 100644 --- a/src/spikeinterface/core/tests/test_channelslicerecording.py +++ b/src/spikeinterface/core/tests/test_channelslicerecording.py @@ -60,7 +60,7 @@ def test_ChannelSliceRecording(create_cache_folder): # with probe and after save() probe = probeinterface.generate_linear_probe(num_elec=num_chan) probe.set_device_channel_indices(np.arange(num_chan)) - rec.set_probe(probe) + rec.set_probe(probe, in_place=True) rec_sliced3 = ChannelSliceRecording(rec, channel_ids=[0, 2], renamed_channel_ids=[3, 4]) probe3 = rec_sliced3.get_probe() locations3 = probe3.contact_positions @@ -117,7 +117,7 @@ def test_select_channels_preserves_probe_metadata(): probegroup.add_probe(probe_B) recording = generate_recording(durations=[1.0], num_channels=16, set_probe=False) - recording.set_probegroup(probegroup) + recording.set_probegroup(probegroup, in_place=True) # Drop all of probe A, keep only probe B sub = recording.select_channels(recording.channel_ids[8:]) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 8b40e3e93f..a9bd71b5c0 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -318,7 +318,7 @@ def test_SortingAnalyzer_interleaved_probegroup(dataset): probegroup.add_probe(probe2) probegroup.set_global_device_channel_indices(np.random.permutation(num_channels)) - recording.set_probegroup(probegroup) + recording = recording.set_probegroup(probegroup) sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) # check that locations are correct diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index e224a8f289..fd427ce942 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -180,7 +180,7 @@ def __init__( probe_dict = self._root.attrs.get("probegroup", self._root.attrs.get("probe", None)) if probe_dict is not None: probegroup = ProbeGroup.from_dict(probe_dict) - self.set_probegroup(probegroup) + self.set_probegroup(probegroup, in_place=True) # load properties if "properties" in self._root: diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py index 3a48084ab9..777bdd914b 100644 --- a/src/spikeinterface/extractors/bids.py +++ b/src/spikeinterface/extractors/bids.py @@ -40,7 +40,7 @@ def read_bids(folder_path): rec.annotate(bids_name=bids_name) rec.extra_requirements.extend("pandas") probegroup = _read_probe_group(file_path.parent, bids_name, rec.channel_ids) - rec.set_probegroup(probegroup) + rec = rec.set_probegroup(probegroup) recordings.append(rec) elif file_path.suffix == ".nix": @@ -54,7 +54,7 @@ def read_bids(folder_path): rec = read_nix(file_path, stream_id=stream_id) rec.extra_requirements.extend("pandas") probegroup = _read_probe_group(file_path.parent, bids_name, rec.channel_ids) - rec.set_probegroup(probegroup) + rec = rec.set_probegroup(probegroup) recordings.append(rec) return recordings diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 891cbaee07..2a53b999e3 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -102,9 +102,9 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", probe = probeinterface.read_spikeglx(meta_file) if probe.shank_ids is not None: - self.set_probe(probe, group_mode="by_shank") + self.set_probe(probe, in_place=True, group_mode="by_shank") else: - self.set_probe(probe) + self.set_probe(probe, in_place=True) sample_shifts = get_neuropixels_sample_shifts_from_probe(probe) self.set_property("inter_sample_shift", sample_shifts) diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 8a57e40ec3..779c36fa23 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -221,9 +221,9 @@ def __init__( probe = probeinterface.read_spikeglx(meta_file) if probe.shank_ids is not None: - self.set_probe(probe, group_mode="by_shank") + self.set_probe(probe, in_place=True, group_mode="by_shank") else: - self.set_probe(probe) + self.set_probe(probe, in_place=True) # set channel properties # sometimes there are missing metadata files on the IBL side diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index b3ccb92cbd..c85e82b574 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -72,7 +72,7 @@ def __init__( probe = probeinterface.read_3brain(file_path, **probe_kwargs) rows = probe.contact_annotations["row"] cols = probe.contact_annotations["col"] - self.set_probe(probe) + self.set_probe(probe, in_place=True) self.set_property("row", rows) self.set_property("col", cols) diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 38e65096c2..5eaa49e6b8 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -75,7 +75,7 @@ def __init__( rec_name = self.neo_reader.rec_name probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) electrodes = probe.contact_annotations["electrode"] - self.set_probe(probe) + self.set_probe(probe, in_place=True) self.set_property("electrode", electrodes) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name)) diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index d4cbe1b0de..7ca82af01e 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -55,7 +55,7 @@ def __init__(self, file_path: str | Path, all_annotations: bool = False, use_nam probe = probeinterface.read_mearec(file_path) probe.annotations["mearec_name"] = str(probe.annotations["mearec_name"]) - self.set_probe(probe) + self.set_probe(probe, in_place=True) self.annotate(is_filtered=True) if hasattr(self.neo_reader._recgen, "gain_to_uV"): diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 22ae82b117..5dc9220aa5 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -331,9 +331,9 @@ def __init__( settings_file=settings_file, stream_name=oe_stream_name ) if probe.shank_ids is not None: - self.set_probe(probe, group_mode="by_shank") + self.set_probe(probe, in_place=True, group_mode="by_shank") else: - self.set_probe(probe) + self.set_probe(probe, in_place=True) # get inter-sample shifts based on the probe information and mux channels sample_shifts = get_neuropixels_sample_shifts_from_probe(probe) if sample_shifts is not None: diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index da4a66e1f5..adc50df12f 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -79,7 +79,7 @@ def __init__( if saturation_threshold_uV_probe is not None: saturation_thresholds_uV.append(saturation_threshold_uV_probe) - self.set_probegroup(probegroup) + self.set_probegroup(probegroup, in_place=True) if np.all(sample_shifts != -1): self.set_property("inter_sample_shift", sample_shifts) diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index 41c2b77bfc..60b1a98be8 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -86,9 +86,9 @@ def __init__( probe = probeinterface.read_spikeglx(ap_meta_filename) if probe.shank_ids is not None: - self.set_probe(probe, group_mode="by_shank") + self.set_probe(probe, in_place=True, group_mode="by_shank") else: - self.set_probe(probe) + self.set_probe(probe, in_place=True) # get inter-sample shifts based on the probe information and mux channels sample_shifts = get_neuropixels_sample_shifts_from_probe(probe) diff --git a/src/spikeinterface/extractors/shybridextractors.py b/src/spikeinterface/extractors/shybridextractors.py index eca7d46724..ff08c1a3f3 100644 --- a/src/spikeinterface/extractors/shybridextractors.py +++ b/src/spikeinterface/extractors/shybridextractors.py @@ -82,7 +82,7 @@ def __init__(self, file_path): # load probe file probegroup = probeinterface.read_prb(params["probe"]) - self.set_probegroup(probegroup) + self.set_probegroup(probegroup, in_place=True) self._kwargs = {"file_path": str(Path(file_path).absolute())} self.extra_requirements.extend(["hybridizer", "pyyaml"]) diff --git a/src/spikeinterface/extractors/sinapsrecordingextractors.py b/src/spikeinterface/extractors/sinapsrecordingextractors.py index f47a83bc47..132a01f300 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractors.py +++ b/src/spikeinterface/extractors/sinapsrecordingextractors.py @@ -84,7 +84,7 @@ def __init__(self, file_path: str | Path, stream_name: str = "filt"): if (stream_name == "filt") | (stream_name == "raw"): probe = get_sinaps_probe(probe_type) if probe is not None: - self.set_probe(probe) + self.set_probe(probe, in_place=True) self._kwargs = {"file_path": str(file_path.absolute()), "stream_name": stream_name} @@ -143,7 +143,7 @@ def __init__(self, file_path: str | Path, stream_name: str = "filt"): # set probe probe = get_sinaps_probe(sinaps_info["probe_type"]) if probe is not None: - self.set_probe(probe) + self.set_probe(probe, in_place=True) self._kwargs = {"file_path": str(Path(file_path).absolute()), "stream_name": stream_name} diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 6996800e27..1800138dae 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -477,7 +477,7 @@ def __init__( ) self.add_recording_segment(recording_segment) - self.set_probe(drifting_templates.probe) + self.set_probe(drifting_templates.probe, in_place=True) # templates are too large, we don't serialize them to JSON self._serializability["json"] = False diff --git a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py index c8825831b0..c7c37968d9 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py @@ -30,7 +30,7 @@ def recording_and_shape(): probe = probeinterface.generate_multi_columns_probe(num_columns=num_cols, num_contact_per_column=num_rows) probe.set_device_channel_indices(np.arange(num_cols * num_rows)) recording = generate_recording(num_channels=num_cols * num_rows, durations=[10.0], sampling_frequency=30000) - recording.set_probe(probe) + recording.set_probe(probe, in_place=True) recording = depth_order(recording) recording = zscore(recording) desired_shape = (num_rows, num_cols) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index 5a0e160f92..35f398f985 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -80,7 +80,7 @@ def test_detect_bad_channels_std_mad(): probe = generate_linear_probe(num_elec=num_channels) probe.set_device_channel_indices(np.arange(num_channels)) - rec.set_probe(probe) + rec.set_probe(probe, in_place=True) bad_channels_std, bad_labels_std = detect_bad_channels(rec, method="std") bad_channels_mad, bad_labels_mad = detect_bad_channels(rec, method="std") @@ -125,7 +125,7 @@ def test_detect_bad_channels_extremes(outside_channels_location): probe = generate_linear_probe(num_elec=num_channels) probe.set_device_channel_indices(np.arange(num_channels)) - rec.set_probe(probe) + rec.set_probe(probe, in_place=True) bad_channel_ids, bad_labels = detect_bad_channels( rec, method="coherence+psd", outside_channels_location=outside_channels_location diff --git a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py index 89e8e36cf8..bfa4d3d9ae 100644 --- a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py @@ -118,7 +118,7 @@ def test_highpass_spatial_filter_with_dead_channels(): rec_with_dead = NumpyRecording( traces_list=[traces], sampling_frequency=rec.sampling_frequency, channel_ids=rec.channel_ids ) - rec_with_dead.set_probe(rec.get_probe()) + rec_with_dead.set_probe(rec.get_probe(), in_place=True) filtered = spre.highpass_spatial_filter(rec_with_dead, n_channel_pad=2) result = filtered.get_traces() assert result.shape == traces.shape diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index c79605a110..1294b57a91 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -132,7 +132,7 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan new_positions = probe.contact_positions.copy() new_positions[:, 0] = x_new # column 0 is x recording._probegroup.probes[0]._contact_positions = new_positions - recording.set_probe(probe) + recording.set_probe(probe, in_place=True) # generate random bad channel locations bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False) @@ -180,7 +180,7 @@ def test_output_values(): probe = pi.Probe(ndim=2) probe.set_contacts(positions=probe_locs) probe.set_device_channel_indices(np.arange(len(probe_locs))) - recording.set_probe(probe) + recording.set_probe(probe, in_place=True) # Run interpolation in SI and check the interpolated channel # 0 is a linear combination of other channels diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 5698e0e142..80d79171ce 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -410,7 +410,7 @@ def __init__( channel_indices = recording.ids_to_indices(channel_ids) probegroup_sliced = probegroup.get_slice(channel_indices) probegroup_sliced.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - self.set_probegroup(probegroup_sliced) + self.set_probegroup(probegroup_sliced, in_place=True) # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below From 868a40f658443c21c758c3869a3621c791a9dc1a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 2 Jul 2026 11:45:20 +0200 Subject: [PATCH 19/19] doc: revert changes to docs --- examples/forhowto/plot_working_with_tetrodes.py | 6 +++--- examples/tutorials/core/plot_1_recording_extractor.py | 2 +- examples/tutorials/core/plot_3_handle_probe_info.py | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/forhowto/plot_working_with_tetrodes.py b/examples/forhowto/plot_working_with_tetrodes.py index 547e9deae1..0c652a5186 100644 --- a/examples/forhowto/plot_working_with_tetrodes.py +++ b/examples/forhowto/plot_working_with_tetrodes.py @@ -62,15 +62,15 @@ # We can now attach the :code:`tetrode_group` to our recording. To check if this worked, we'll # plot the probe map -recording.set_probegroup(tetrode_group) -plot_probe_map(recording) +recording_with_probe = recording.set_probegroup(tetrode_group) +plot_probe_map(recording_with_probe) ############################################################################## # Looks good! Now that the recording is aware of the probe geometry, we can # begin a standard spike sorting pipeline. First, we can apply preprocessing. # Note that we apply this preprocessing on the entire bundle of tetrodes. -preprocessed_recording = spre.bandpass_filter(recording) +preprocessed_recording = spre.bandpass_filter(recording_with_probe) ############################################################################## # WARNING: a very common preprocessing step is to apply a common median diff --git a/examples/tutorials/core/plot_1_recording_extractor.py b/examples/tutorials/core/plot_1_recording_extractor.py index 477ba165b6..e3bfda5855 100644 --- a/examples/tutorials/core/plot_1_recording_extractor.py +++ b/examples/tutorials/core/plot_1_recording_extractor.py @@ -70,7 +70,7 @@ probe.set_device_channel_indices(np.arange(7)) # then we need to actually set the probe to the recording object -recording.set_probe(probe) +recording = recording.set_probe(probe) plot_probe(probe) ############################################################################## diff --git a/examples/tutorials/core/plot_3_handle_probe_info.py b/examples/tutorials/core/plot_3_handle_probe_info.py index 28d2af655a..deff58ebb7 100644 --- a/examples/tutorials/core/plot_3_handle_probe_info.py +++ b/examples/tutorials/core/plot_3_handle_probe_info.py @@ -43,8 +43,8 @@ print(other_probe) other_probe.set_device_channel_indices(np.arange(32)) -recording.set_probe(other_probe, group_mode="by_shank") -plot_probe(recording.get_probe()) +recording_2_shanks = recording.set_probe(other_probe, group_mode="by_shank") +plot_probe(recording_2_shanks.get_probe()) ############################################################################### # Now let's check what we have loaded. The :code:`group_mode='by_shank'` automatically @@ -53,11 +53,11 @@ # We can access this information either as a dict with :code:`outputs='dict'` (default) # or as a list of recordings with :code:`outputs='list'`. -print(recording) -print(f'\nGroup Property: {recording.get_property("group")}\n') +print(recording_2_shanks) +print(f'\nGroup Property: {recording_2_shanks.get_property("group")}\n') # Here we split as a dict -sub_recording_dict = recording.split_by(property="group", outputs='dict') +sub_recording_dict = recording_2_shanks.split_by(property="group", outputs='dict') # Then we can pull out the individual sub-recordings sub_rec0 = sub_recording_dict[0]