diff --git a/.github/scripts/create_probe_compat_fixtures.py b/.github/scripts/create_probe_compat_fixtures.py new file mode 100644 index 0000000000..83fafacff8 --- /dev/null +++ b/.github/scripts/create_probe_compat_fixtures.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +""" +Creates probe compatibility fixtures using the *currently installed* spikeinterface. + +Run this script with spikeinterface==0.104.* installed to produce the fixture +files consumed by test_probe_backward_compat.py: + + python create_probe_compat_fixtures.py [output_dir] + +If output_dir is omitted, fixtures are written to ./probe_compat_fixtures. +""" + +import sys +import shutil +import numpy as np +from pathlib import Path + +import spikeinterface + +print(f"Creating fixtures with spikeinterface {spikeinterface.__version__}") + +OUTPUT_DIR = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("probe_compat_fixtures") +if OUTPUT_DIR.exists(): + shutil.rmtree(OUTPUT_DIR) +OUTPUT_DIR.mkdir(parents=True) + +from probeinterface import generate_linear_probe, ProbeGroup +from spikeinterface.core import NumpyRecording + +# ----------------------------------------------------------------------- +# Fixture 1: single probe, sequential device_channel_indices +# ----------------------------------------------------------------------- +n = 8 +probe = generate_linear_probe(num_elec=n, ypitch=20.0) +probe.annotate(name="test_probe", manufacturer="test_vendor") +probe.set_contact_ids([f"e{i}" for i in range(n)]) +probe.set_device_channel_indices(np.arange(n)) +probe.create_auto_shape() + +traces = np.arange(1000 * n, dtype="int16").reshape(1000, n) +rec_single = NumpyRecording([traces], sampling_frequency=30000.0) +rec_single = rec_single.set_probe(probe) # old API: in_place=False, returns new recording + +rec_single.save(folder=str(OUTPUT_DIR / "single_probe_binary")) +rec_single.save(folder=str(OUTPUT_DIR / "single_probe.zarr"), format="zarr") + +# 
----------------------------------------------------------------------- +# Fixture 2: two probes with per-probe name/manufacturer +# ----------------------------------------------------------------------- +n_A, n_B = 8, 8 +probe_A = generate_linear_probe(num_elec=n_A, ypitch=20.0) +probe_A.move([0.0, 0.0]) +probe_A.annotate(name="probe_A", manufacturer="vendor_X") +probe_A.set_contact_ids([f"a{i}" for i in range(n_A)]) +probe_A.set_device_channel_indices(np.arange(n_A)) +probe_A.create_auto_shape() + +probe_B = generate_linear_probe(num_elec=n_B, ypitch=20.0) +probe_B.move([500.0, 0.0]) +probe_B.annotate(name="probe_B", manufacturer="vendor_Y") +probe_B.set_contact_ids([f"b{i}" for i in range(n_B)]) +probe_B.set_device_channel_indices(np.arange(n_A, n_A + n_B)) +probe_B.create_auto_shape() + +pg = ProbeGroup() +pg.add_probe(probe_A) +pg.add_probe(probe_B) + +n_total = n_A + n_B +traces2 = np.arange(1000 * n_total, dtype="int16").reshape(1000, n_total) +rec_two = NumpyRecording([traces2], sampling_frequency=30000.0) +rec_two = rec_two.set_probegroup(pg) # old API: in_place=False, returns new recording + +rec_two.save(folder=str(OUTPUT_DIR / "two_probe_binary")) +rec_two.save(folder=str(OUTPUT_DIR / "two_probe.zarr"), format="zarr") + +# ----------------------------------------------------------------------- +# Fixture 3: probe with shuffled device_channel_indices +# Verifies that the channel-reordering logic is preserved across versions. +# ----------------------------------------------------------------------- +n = 8 +probe_sh = generate_linear_probe(num_elec=n, ypitch=20.0) +probe_sh.annotate(name="shuffled_probe", manufacturer="shuffle_vendor") +shuffled_dci = np.array([3, 0, 7, 1, 5, 2, 6, 4]) # permutation of 0..7 +probe_sh.set_device_channel_indices(shuffled_dci) + +# traces[:, j] corresponds to recording channel j, which after set_probe +# is mapped to the contact whose dci equals j. 
+traces3 = np.arange(1000 * n, dtype="int16").reshape(1000, n) +rec_sh = NumpyRecording([traces3], sampling_frequency=30000.0) +rec_sh = rec_sh.set_probe(probe_sh) # old API + +rec_sh.save(folder=str(OUTPUT_DIR / "shuffled_probe_binary")) +rec_sh.save(folder=str(OUTPUT_DIR / "shuffled_probe.zarr"), format="zarr") + +print(f"Fixtures written to: {OUTPUT_DIR.resolve()}") diff --git a/.github/workflows/probe_backward_compat.yml b/.github/workflows/probe_backward_compat.yml new file mode 100644 index 0000000000..7eeab4089b --- /dev/null +++ b/.github/workflows/probe_backward_compat.yml @@ -0,0 +1,51 @@ +name: Probe backward compatibility + +on: + pull_request: + types: [synchronize, opened, reopened] + branches: + - main + paths: + - 'src/spikeinterface/core/**' + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + probe-backward-compat: + name: Probe compat (SI 0.104.* → current) + runs-on: ubuntu-latest + env: + SI_PROBE_COMPAT_FIXTURES_DIR: ${{ github.workspace }}/probe_compat_fixtures + + steps: + - name: Check out code + uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.11' + + - name: Set up uv + uses: astral-sh/setup-uv@v7 + with: + python-version: '3.11' + enable-cache: false + + # Step 1: install the OLD release and create fixtures. + # The fixture script uses the old in_place=False default (returns a new recording), + # saves to binary folder + JSON, and writes a known probe name/manufacturer/contact_ids. + - name: Install spikeinterface 0.104.* to create fixtures + run: uv pip install --system "spikeinterface[core]==0.104.*" + + - name: Create compatibility fixtures with old version + run: python .github/scripts/create_probe_compat_fixtures.py "$SI_PROBE_COMPAT_FIXTURES_DIR" + + # Step 2: install the NEW version from this PR source and run the load tests. + - name: Install new spikeinterface from source + run: uv pip install --system -e . 
--group test-core + + - name: Run backward compatibility tests + run: pytest src/spikeinterface/core/tests/test_probe_backward_compat.py -v diff --git a/doc/api.rst b/doc/api.rst index fc55017606..6aca96c4ec 100755 --- a/doc/api.rst +++ b/doc/api.rst @@ -16,6 +16,10 @@ spikeinterface.core .. automethod:: BaseRecording.dump_to_json .. automethod:: BaseRecording.dump_to_pickle .. automethod:: BaseRecording.remove_channels + .. automethod:: BaseRecording.set_probe + .. automethod:: BaseRecording.set_probegroup + .. automethod:: BaseRecording.reset_probe + .. automethod:: BaseRecording.split_by .. autoclass:: BaseSorting :members: .. automethod:: BaseSorting.save @@ -25,6 +29,8 @@ spikeinterface.core .. automethod:: BaseSorting.dump .. automethod:: BaseSorting.dump_to_json .. automethod:: BaseSorting.dump_to_pickle + .. automethod:: BaseSorting.split_by + .. automethod:: BaseSorting.register_recording .. autoclass:: BaseSnippets :members: .. automethod:: BaseSnippets.save diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index fcbafdb6bf..bbd4803748 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -220,6 +220,14 @@ def id_to_index(self, id) -> int: return ind def annotate(self, **new_annotations) -> None: + """Adds annotations. + + Parameters + ---------- + **new_annotations : dict + Key-value pairs of annotations to add. If an annotation key already exists, + it will be overwritten. + """ self._annotations.update(new_annotations) def set_annotation(self, annotation_key: str, value: Any, overwrite=False) -> None: @@ -243,6 +251,24 @@ def set_annotation(self, annotation_key: str, value: Any, overwrite=False) -> No else: raise ValueError(f"{annotation_key} is already an annotation key. Use 'overwrite=True' to overwrite it") + def delete_annotation(self, annotation_key: str) -> None: + """Deletes existing annotation. 
+ + Parameters + ---------- + annotation_key : str + The annotation key to delete + + Raises + ------ + ValueError + If the annotation key does not exist + """ + if annotation_key in self._annotations.keys(): + del self._annotations[annotation_key] + else: + raise ValueError(f"{annotation_key} is not an annotation key") + def get_preferred_mp_context(self): """ Get the preferred context for multiprocessing. @@ -441,6 +467,15 @@ def copy_metadata( if self._preferred_mp_context is not None: other._preferred_mp_context = self._preferred_mp_context + if not only_main: + self._extra_metadata_copy(other) + + def _extra_metadata_copy(self, other: "BaseExtractor") -> None: + """ + This is a hook to copy extra metadata that is not in the annotations/properties dict. + """ + pass + def to_dict( self, include_annotations: bool = False, @@ -574,6 +609,8 @@ def to_dict( folder_metadata = Path(folder_metadata).resolve().absolute().relative_to(relative_to) dump_dict["folder_metadata"] = str(folder_metadata) + self._extra_metadata_to_dict(dump_dict) + return dump_dict @staticmethod @@ -610,8 +647,6 @@ def load_metadata_from_folder(self, folder_metadata): # hack to load probe for recording folder_metadata = Path(folder_metadata) - self._extra_metadata_from_folder(folder_metadata) - # load properties prop_folder = folder_metadata / "properties" if prop_folder.is_dir(): @@ -621,6 +656,8 @@ def load_metadata_from_folder(self, folder_metadata): key = prop_file.stem self.set_property(key, values) + self._extra_metadata_from_folder(folder_metadata) + def save_metadata_to_folder(self, folder_metadata): self._extra_metadata_to_folder(folder_metadata) @@ -862,6 +899,14 @@ def _extra_metadata_to_folder(self, folder): # This implemented in BaseRecording for probe pass + def _extra_metadata_from_dict(self, dump_dict): + # This implemented in BaseRecording for probe + pass + + def _extra_metadata_to_dict(self, dump_dict): + # This implemented in BaseRecording for probe + pass + def save(self, 
**kwargs) -> "BaseExtractor": """ Save a SpikeInterface object. @@ -997,10 +1042,10 @@ def save_to_folder( else: warnings.warn("The extractor is not serializable to file. The provenance will not be saved.") - self.save_metadata_to_folder(folder) - # save data (done the subclass) + self.save_metadata_to_folder(folder) cached = self._save(folder=folder, verbose=verbose, **save_kwargs) + cached.load_metadata_from_folder(folder) # copy properties/ self.copy_metadata(cached) @@ -1155,6 +1200,8 @@ def _load_extractor_from_dict(dic) -> "BaseExtractor": for k, v in dic["properties"].items(): extractor.set_property(k, v) + extractor._extra_metadata_from_dict(dic) + return extractor diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index b0f75930d3..c61d602026 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -20,7 +20,6 @@ class BaseRecording(BaseRecordingSnippets, TimeSeries): _main_annotations = BaseRecordingSnippets._main_annotations + ["is_filtered"] _main_properties = [ "group", - "location", "gain_to_uV", "offset_to_uV", "gain_to_physical_unit", @@ -324,6 +323,8 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): if format == "binary": from .time_series_tools import write_binary + from .binaryrecordingextractor import BinaryRecordingExtractor + from .binaryfolder import BinaryFolderRecording folder = kwargs["folder"] file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] @@ -332,8 +333,6 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): write_binary(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) - from .binaryrecordingextractor import BinaryRecordingExtractor - # This is created so it can be saved as json because the `BinaryFolderRecording` requires it loading # See the __init__ of `BinaryFolderRecording` @@ -351,9 +350,6 @@ def _save(self, format="binary", 
verbose: bool = False, **save_kwargs): offset_to_uV=self.get_channel_offsets(), ) binary_rec.dump(folder / "binary.json", relative_to=folder) - - from .binaryfolder import BinaryFolderRecording - cached = BinaryFolderRecording(folder_path=folder) # timestamps are not saved in binary, so we have to set them explicitly @@ -389,7 +385,6 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): self, zarr_path, storage_options, verbose=verbose, **kwargs, **job_kwargs ) cached = ZarrRecordingExtractor(zarr_path, storage_options) - # timestamps are saved and restored in zarr, so no need to set them explicitly elif format == "nwb": @@ -399,18 +394,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: - probegroup = self.get_probegroup() - cached.set_probegroup(probegroup) - return cached def _extra_metadata_from_folder(self, folder): # load probe - folder = Path(folder) - if (folder / "probe.json").is_file(): - probegroup = read_probeinterface(folder / "probe.json") - self.set_probegroup(probegroup, in_place=True) + super()._extra_metadata_from_folder(folder) # load time vector if any for segment_index, rs in enumerate(self.segments): @@ -420,10 +408,7 @@ def _extra_metadata_from_folder(self, folder): rs.time_vector = time_vector def _extra_metadata_to_folder(self, folder): - # save probe - if self.get_property("contact_vector") is not None: - probegroup = self.get_probegroup() - write_probeinterface(folder / "probe.json", probegroup) + super()._extra_metadata_to_folder(folder) # save time vector if any for segment_index, rs in enumerate(self.segments): diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 58e91ec35c..43ed20dc34 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,11 +1,12 
@@ from pathlib import Path - +from typing import Literal +import warnings import numpy as np from probeinterface import Probe, ProbeGroup, write_probeinterface, read_probeinterface, select_axes from .base import BaseExtractor -from .recording_tools import check_probe_do_not_overlap +from .recording_tools import _set_group_property_based_on_probegroup from warnings import warn @@ -19,6 +20,7 @@ def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = float(sampling_frequency) self._dtype = np.dtype(dtype) + self._probegroup = None @property def channel_ids(self): @@ -51,135 +53,118 @@ def has_scaleable_traces(self) -> bool: return True def has_probe(self) -> bool: - return "contact_vector" in self.get_property_keys() + # probe group is saved and loaded to binary/zarr, so we don't need to check for legacy "contact_vector" property + return self._probegroup is not None + + def has_3d_probe(self) -> bool: + if self.has_probe(): + probes = self.get_probegroup().probes + return len(probes) > 0 and probes[0].ndim == 3 + else: + return False def has_channel_location(self) -> bool: - return self.has_probe() or "location" in self.get_property_keys() + return self.has_probe() def is_filtered(self): # the is_filtered is handle with annotation return self._annotations.get("is_filtered", False) - def set_probe(self, probe, group_mode="auto", in_place=False): + def reset_probe(self): + """ + Removes probe information """ - Attach a list of Probe object to a recording. + self._probegroup = None + + def set_probe( + self, + probe: Probe, + group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto", + in_place: bool = False, + ) -> "BaseRecordingSnippets": + """ + Attach a Probe object to a recording. + + For this Probe.device_channel_indices is used to link contacts to recording channels.
+ If some contacts of the Probe are not connected (device_channel_indices=-1) + then the recording is "sliced" and only connected channels are kept. Parameters ---------- - probe_or_probegroup: Probe, list of Probe, or ProbeGroup - The probe(s) to be attached to the recording + probe: Probe + The probe to be attached to the recording group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" How to add the "group" property. - "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. - in_place: bool - False by default. - Useful internally when extractor do self.set_probegroup(probe) + "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks + and two sides are present. + in_place: bool, default: False + If False, a new recording (view or channel selection) is returned. + If True, the recording is modified in place, which requires all channels to be connected. Returns ------- sub_recording: BaseRecording A view of the recording (ChannelSlice or clone or itself) + + Notes + ----- + Internally, this will construct a ProbeGroup with the probe and call `set_probegroup()`. 
""" - assert isinstance(probe, Probe), "must give Probe" + assert isinstance(probe, Probe), "The input must be a Probe object" probegroup = ProbeGroup() probegroup.add_probe(probe) - return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) - - def set_probegroup(self, probegroup, group_mode="auto", in_place=False): - return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) - - def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): + return self.set_probegroup(probegroup, group_mode=group_mode, in_place=in_place) + + def set_probegroup( + self, + probegroup: ProbeGroup | dict, + group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto", + in_place: bool = False, + ) -> "BaseRecordingSnippets": """ - Attach a list of Probe objects to a recording. + Attach a ProbeGroup or dict to a recording. For this Probe.device_channel_indices is used to link contacts to recording channels. If some contacts of the Probe are not connected (device_channel_indices=-1) - then the recording is "sliced" and only connected channel are kept. - - The probe order is not kept. Channel ids are re-ordered to match the channel_ids of the recording. + then the recording is "sliced" and only connected channels are kept. + Note: The probe order of the probegroup is not kept. Channel ids are re-ordered to match the channel_ids of the recording. Parameters ---------- - probe_or_probegroup: Probe, list of Probe, or ProbeGroup + probegroup: ProbeGroup, or dict The probe(s) to be attached to the recording group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" How to add the "group" property. "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. - in_place: bool - False by default. 
- Useful internally when extractor do self.set_probegroup(probe) + in_place: bool, default: False + If False, a new recording (view or channel selection) is returned. + If True, the recording is modified in place, which requires all channels to be connected. Returns ------- sub_recording: BaseRecording A view of the recording (ChannelSlice or clone or itself) """ - assert group_mode in ( - "auto", - "by_probe", - "by_shank", - "by_side", - ), "'group_mode' can be 'auto' 'by_probe' 'by_shank' or 'by_side'" - - # handle several input possibilities - if isinstance(probe_or_probegroup, Probe): - probegroup = ProbeGroup() - probegroup.add_probe(probe_or_probegroup) - elif isinstance(probe_or_probegroup, ProbeGroup): - probegroup = probe_or_probegroup - elif isinstance(probe_or_probegroup, list): - assert all([isinstance(e, Probe) for e in probe_or_probegroup]) - probegroup = ProbeGroup() - for probe in probe_or_probegroup: - probegroup.add_probe(probe) - else: - raise ValueError("must give Probe or ProbeGroup or list of Probe") + # Handle several input possibilities: ProbeGroup or dict + if isinstance(probegroup, dict): + probegroup = ProbeGroup.from_dict(probegroup) - # check that the probe do not overlap - num_probes = len(probegroup.probes) - if num_probes > 1: - check_probe_do_not_overlap(probegroup.probes) + probegroup_sorted = self._get_probegroup_based_on_device_channel_indices(probegroup) - # handle not connected channels - assert all( - probe.device_channel_indices is not None for probe in probegroup.probes - ), "Probe must have device_channel_indices" - - # this is a vector with complex fileds (dataframe like) that handle all contact attr - probe_as_numpy_array = probegroup.to_numpy(complete=True) - - # keep only connected contact ( != -1) - keep = probe_as_numpy_array["device_channel_indices"] >= 0 - if np.any(~keep): - warn("The given probes have unconnected contacts: they are removed") - - probe_as_numpy_array = probe_as_numpy_array[keep] - - 
device_channel_indices = probe_as_numpy_array["device_channel_indices"] - order = np.argsort(device_channel_indices) - device_channel_indices = device_channel_indices[order] - - # check TODO: Where did this came from? - number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) - if number_of_device_channel_indices >= self.get_num_channels(): - error_msg = ( - f"The given Probe either has 'device_channel_indices' that does not match channel count \n" - f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" - f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" - f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" - f"device_channel_indices are the following: {device_channel_indices} \n" - f"recording channels are the following: {self.get_channel_ids()} \n" - ) - raise ValueError(error_msg) - - new_channel_ids = self.get_channel_ids()[device_channel_indices] - probe_as_numpy_array = probe_as_numpy_array[order] - probe_as_numpy_array["device_channel_indices"] = np.arange(probe_as_numpy_array.size, dtype="int64") + if probegroup_sorted.get_contact_count() > 0: + sorted_dci = probegroup_sorted.get_global_device_channel_indices()["device_channel_indices"] + new_channel_ids = self.channel_ids[sorted_dci] + else: + new_channel_ids = self.channel_ids[[]] # empty selection - # create recording : channel slice or clone or self + # create recording: itself (in place), clone or channel slice if in_place: if not np.array_equal(new_channel_ids, self.get_channel_ids()): - raise Exception("set_probe(inplace=True) must have all channel indices") + raise ValueError( + "set_probe(in_place=True) requires the probe/probegroup to have the same number of connected " + "contacts as the number of channels in the recording. 
Use in_place=False to return a new " + "recording with a channel selection to match the probe/probegroup." + ) sub_recording = self else: if np.array_equal(new_channel_ids, self.get_channel_ids()): @@ -187,62 +172,92 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): else: sub_recording = self.select_channels(new_channel_ids) - # create a vector that handle all contacts in property - sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None) - - # planar_contour is saved in annotations - for probe_index, probe in enumerate(probegroup.probes): - contour = probe.probe_planar_contour - if contour is not None: - sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True) - - # duplicate positions to "locations" property - ndim = probegroup.ndim - locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") - for i, dim in enumerate(["x", "y", "z"][:ndim]): - locations[:, i] = probe_as_numpy_array[dim] - sub_recording.set_property("location", locations, ids=None) - - # handle groups - has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields - has_contact_side = "contact_sides" in probe_as_numpy_array.dtype.fields - if group_mode == "auto": - group_keys = ["probe_index"] - if has_shank_id: - group_keys += ["shank_ids"] - if has_contact_side: - group_keys += ["contact_sides"] - elif group_mode == "by_probe": - group_keys = ["probe_index"] - elif group_mode == "by_shank": - assert has_shank_id, "shank_ids is None in probe, you cannot group by shank" - group_keys = ["probe_index", "shank_ids"] - elif group_mode == "by_side": - assert has_contact_side, "contact_sides is None in probe, you cannot group by side" - if has_shank_id: - group_keys = ["probe_index", "shank_ids", "contact_sides"] - else: - group_keys = ["probe_index", "contact_sides"] - groups = np.zeros(probe_as_numpy_array.size, dtype="int64") - unique_keys = np.unique(probe_as_numpy_array[group_keys]) - for group, 
a in enumerate(unique_keys): - mask = np.ones(probe_as_numpy_array.size, dtype=bool) - for k in group_keys: - mask &= probe_as_numpy_array[k] == a[k] - groups[mask] = group - sub_recording.set_property("group", groups, ids=None) - - # add probe annotations to recording - probes_info = [] - for probe in probegroup.probes: - probes_info.append(probe.annotations) - sub_recording.annotate(probes_info=probes_info) + if probegroup_sorted.get_contact_count() > 0: + probegroup_sorted.set_global_device_channel_indices(np.arange(probegroup_sorted.get_contact_count())) + sub_recording._probegroup = probegroup_sorted + # Handle and set channel groups + _set_group_property_based_on_probegroup(sub_recording, probegroup_sorted, group_mode=group_mode) + else: + sub_recording._probegroup = ProbeGroup() # empty probegroup return sub_recording + def _get_probegroup_based_on_device_channel_indices(self, probegroup: ProbeGroup) -> ProbeGroup: + """ + Returns a new probegroup sorted based on their device_channel_indices. + This is useful to ensure that the probes are ordered correctly when attached to a recording. + Also checks that the device_channel_indices are consistent with the recording channel count and + contacts are unique across probes in the probegroup. + + Parameters + ---------- + probegroup : ProbeGroup + The probegroup to be sorted. + + Returns + ------- + ProbeGroup + The sorted probegroup. 
+ """ + if not isinstance(probegroup, ProbeGroup): + raise ValueError("The input must be a ProbeGroup or dict") + + assert all( + probe.device_channel_indices is not None for probe in probegroup.probes + ), "Probe must have device_channel_indices" + + # Remove unconnected contacts and slice the probe group accordingly + device_channel_indices = probegroup.get_global_device_channel_indices()["device_channel_indices"] + keep_indices = np.flatnonzero(device_channel_indices >= 0) + if len(keep_indices) < len(device_channel_indices): + warn( + f"The given probes have {len(device_channel_indices) - len(keep_indices)} unconnected contacts: " + "they will be removed" + ) + if len(keep_indices) == 0: + probegroup = ProbeGroup() # empty probegroup + device_channel_indices = np.array([], dtype="int64") + else: + probegroup = probegroup.get_slice(keep_indices) + device_channel_indices = device_channel_indices[keep_indices] + + if len(device_channel_indices) > 0: + # Check consistency of device_channel_indices with the recording channel count + number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) + if number_of_device_channel_indices >= self.get_num_channels(): + error_msg = ( + f"The given Probe either has 'device_channel_indices' that does not match channel count \n" + f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" + f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" + f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" + f"device_channel_indices are the following: {device_channel_indices} \n" + f"recording channels are the following: {self.get_channel_ids()} \n" + ) + raise ValueError(error_msg) + # Now slice the probe using the device channel indices to match the recording channel_ids + order = np.argsort(device_channel_indices) + probegroup = probegroup.get_slice(order) + else: +
warn( + "No connected channels in the probegroup! " + "The probegroup will be attached but no channel will be selected." + ) + probegroup = ProbeGroup() # empty probegroup + + # In some older SI versions, before #4300, the probe annotations were + # saved to the recording annotations as `probes_info`. If this is the + # case, we can copy the annotations to the probegroup and delete the + # `probes_info` from the recording annotations. + if "probes_info" in self._annotations: + probes_info = self._annotations.pop("probes_info") + for probe, probe_info in zip(probegroup.probes, probes_info): + probe.annotations.update(probe_info) + + return probegroup + def get_probe(self): probes = self.get_probes() - assert len(probes) == 1, "there are several probe use .get_probes() or get_probegroup()" + assert len(probes) == 1, "There are several probe use .get_probes() or get_probegroup()" return probes[0] def get_probes(self): @@ -250,43 +265,50 @@ def get_probes(self): return probegroup.probes def get_probegroup(self): - arr = self.get_property("contact_vector") - if arr is None: - positions = self.get_property("location") - if positions is None: - raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") - else: - warn("There is no Probe attached to this recording. Creating a dummy one with contact positions") - probe = self.create_dummy_probe_from_locations(positions) - # probe.create_auto_shape() - probegroup = ProbeGroup() - probegroup.add_probe(probe) - else: - probegroup = ProbeGroup.from_numpy(arr) + if self._probegroup is None: + raise ValueError("There is no Probe attached to this recording. Use set_probe(...) 
to attach one.") + return self._probegroup - if "probes_info" in self.get_annotation_keys(): - probes_info = self.get_annotation("probes_info") - for probe, probe_info in zip(probegroup.probes, probes_info): - probe.annotations = probe_info - - for probe_index, probe in enumerate(probegroup.probes): - contour = self.get_annotation(f"probe_{probe_index}_planar_contour") - if contour is not None: - probe.set_planar_contour(contour) - return probegroup + def _extra_metadata_copy(self, other): + if self._probegroup is not None: + other._probegroup = self._probegroup.copy() def _extra_metadata_from_folder(self, folder): - # load probe + # load probe from folder + # Note: we don't need any fix for legacy probegroups, since the + # set_probegroup() method will handle the device_channel_indices + # sorting and global contact order folder = Path(folder) - if (folder / "probe.json").is_file(): - probegroup = read_probeinterface(folder / "probe.json") + probe_file = folder / "probegroup.json" + legacy_probe_file = folder / "probe.json" + if probe_file.is_file(): + probegroup = read_probeinterface(probe_file) + self.set_probegroup(probegroup, in_place=True) + elif legacy_probe_file.is_file(): + probegroup = read_probeinterface(legacy_probe_file) self.set_probegroup(probegroup, in_place=True) + # remove "contact_vector" property if present as it is not needed anymore + if "contact_vector" in self.get_property_keys(): + self.delete_property("contact_vector") + def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): + probegroup = self.get_probegroup() + write_probeinterface(folder / "probegroup.json", probegroup) + + def _extra_metadata_from_dict(self, dump_dict): + # load probe + if "probegroup" in dump_dict: + probegroup = dump_dict["probegroup"] + self.set_probegroup(probegroup, in_place=True) + + def _extra_metadata_to_dict(self, dump_dict): + # save probe + if self.has_probe(): probegroup = 
self.get_probegroup() - write_probeinterface(folder / "probe.json", probegroup) + dump_dict["probegroup"] = probegroup def create_dummy_probe_from_locations(self, locations, shape="circle", shape_params={"radius": 1}, axes="xy"): """ @@ -330,51 +352,55 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params ---------- locations : np.array Array with channel locations (num_channels, ndim) [ndim can be 2 or 3] - shape : str, default: default: "circle" + shape : str, default: "circle" Electrode shapes shape_params : dict, default: {"radius": 1} Shape parameters axes : "xy" | "yz" | "xz", default: "xy" If ndim is 3, indicates the axes that define the plane of the electrodes """ - probe = self.create_dummy_probe_from_locations(locations, shape=shape, shape_params=shape_params, axes=axes) + probe = self.create_dummy_probe_from_locations( + np.array(locations), shape=shape, shape_params=shape_params, axes=axes + ) self.set_probe(probe, in_place=True) def set_channel_locations(self, locations, channel_ids=None): - if self.get_property("contact_vector") is not None: - raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") - self.set_property("location", locations, ids=channel_ids) + warnings.warn( + ( + "set_channel_locations() is deprecated and will be removed in version 0.106.0. " + "If you want to set probe information, use `set_dummy_probe_from_locations()`." 
+ ), + DeprecationWarning, + stacklevel=2, + ) + self.set_dummy_probe_from_locations(locations, axes="xy") def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray: if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - # here we bypass the probe reconstruction so this works both for probe and probegroup - ndim = len(axes) - all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") - for i, dim in enumerate(axes): - all_positions[:, i] = contact_vector[dim] - positions = all_positions[channel_indices] - return positions - else: - locations = self.get_property("location") - if locations is None: - raise Exception("There are no channel locations") - locations = np.asarray(locations)[channel_indices] - return select_axes(locations, axes) + if not self.has_probe(): + raise ValueError("get_channel_locations(..) needs a probe to be attached to the recording") + probegroup = self.get_probegroup() + contact_positions = probegroup.get_global_contact_positions() + return select_axes(contact_positions, axes)[channel_indices] - def has_3d_locations(self) -> bool: - return self.get_property("location").shape[1] == 3 + def is_probe_3d(self) -> bool: + if not self.has_probe(): + raise ValueError("is_probe_3d() needs a probe to be attached to the recording") + probegroup = self.get_probegroup() + return probegroup.ndim == 3 def clear_channel_locations(self, channel_ids=None): - if channel_ids is None: - n = self.get_num_channel() - else: - n = len(channel_ids) - locations = np.zeros((n, 2)) * np.nan - self.set_property("location", locations, ids=channel_ids) + warnings.warn( + ( + "clear_channel_locations() is deprecated and will be removed in version 0.106.0. " + "If you want to remove probe information, use `reset_probe()`." 
+ ), + DeprecationWarning, + stacklevel=2, + ) + self.reset_probe() def set_channel_groups(self, groups, channel_ids=None): if "probes" in self._annotations: @@ -429,7 +455,7 @@ def planarize(self, axes: str = "xy"): BaseRecording The recording with 2D positions """ - assert self.has_3d_locations, "The 'planarize' function needs a recording with 3d locations" + assert self.is_probe_3d(), "The 'planarize' function needs a recording with 3d locations" assert len(axes) == 2, "You need to specify 2 dimensions (e.g. 'xy', 'zy')" probe2d = self.get_probe().to_2d(axes=axes) diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index b56a093ccc..39e8c51225 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -11,7 +11,7 @@ class BaseSnippets(BaseRecordingSnippets): Abstract class representing several multichannel snippets. """ - _main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"] + _main_properties = ["group", "gain_to_uV", "offset_to_uV"] _main_features = [] def __init__(self, sampling_frequency: float, nbefore: int | None, snippet_len: int, channel_ids: list, dtype): @@ -259,9 +259,9 @@ def _save(self, format="npy", **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() - cached.set_probegroup(probegroup) + cached.set_probegroup(probegroup, in_place=True) return cached diff --git a/src/spikeinterface/core/binaryfolder.py b/src/spikeinterface/core/binaryfolder.py index 4b9d7b7d09..8d8ed3c206 100644 --- a/src/spikeinterface/core/binaryfolder.py +++ b/src/spikeinterface/core/binaryfolder.py @@ -3,6 +3,8 @@ import numpy as np +from probeinterface import read_probeinterface + from .binaryrecordingextractor import BinaryRecordingExtractor from .core_tools import define_function_from_class, make_paths_absolute diff --git
a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 697aab875e..746f35eb47 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -2,6 +2,7 @@ import numpy as np +from probeinterface import ProbeGroup from .baserecording import BaseRecording, BaseRecordingSegment @@ -90,31 +91,27 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record break for prop_name, prop_values in property_dict.items(): - if prop_name == "contact_vector": - # remap device channel indices correctly - prop_values["device_channel_indices"] = np.arange(self.get_num_channels()) self.set_property(key=prop_name, values=prop_values) - # if locations are present, check that they are all different! - if "location" in self.get_property_keys(): - location_tuple = [tuple(loc) for loc in self.get_property("location")] - assert len(set(location_tuple)) == self.get_num_channels(), ( - "Locations are not unique! " "Cannot aggregate recordings!" 
- ) - - planar_contour_keys = [ - key for recording in recording_list for key in recording.get_annotation_keys() if "planar_contour" in key - ] - if len(planar_contour_keys) > 0: - if all( - k == planar_contour_keys[0] for k in planar_contour_keys - ): # we add the 'planar_contour' annotations only if there is a unique one in the recording_list - planar_contour_key = planar_contour_keys[0] - collect_planar_contours = [] - for rec in recording_list: - collect_planar_contours.append(rec.get_annotation(planar_contour_key)) - if all(np.array_equal(arr, collect_planar_contours[0]) for arr in collect_planar_contours): - self.set_annotation(planar_contour_key, collect_planar_contours[0]) + # Aggregate probe information + all_probegroups = [rec.get_probegroup() for rec in recording_list if rec.has_probe()] + if len(all_probegroups) == len(recording_list): + # check that contact positions are unique across all recordings + all_positions = [] + for probegroup in all_probegroups: + for probe in probegroup.probes: + all_positions.extend(probe.contact_positions) + assert len(np.unique(np.array(all_positions), axis=0)) == len( + all_positions + ), "Contact positions are not unique! Cannot aggregate recordings." 
+ + # Now make a new probegroup with all probes and set global device channel indices + probegroup_agg = ProbeGroup() + for probegroup in all_probegroups: + for probe in probegroup.probes: + probegroup_agg.add_probe(probe.copy()) + probegroup_agg.set_global_device_channel_indices(np.arange(num_all_channels)) + self.set_probegroup(probegroup_agg, in_place=True) # finally add segments, we need a channel mapping ch_id = 0 diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index de693d5c26..1df401f8db 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -62,10 +62,11 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) self._parent = parent_recording # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if self._parent.has_probe(): + parent_probegroup = self._parent.get_probegroup() + sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(self._channel_ids))) + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { @@ -152,10 +153,11 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids) # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if self._parent_snippets.has_probe(): + parent_probegroup = self._parent_snippets.get_probegroup() + sliced_probegroup = 
parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(self._channel_ids))) + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index fda08ff1b0..02ee9bd915 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -10,6 +10,7 @@ from collections import namedtuple import inspect +from probeinterface import ProbeGroup import numpy as np @@ -148,6 +149,9 @@ def default(self, obj): if isinstance(obj, Motion): return obj.to_dict() + if isinstance(obj, ProbeGroup): + return obj.to_dict() + # The base-class handles the assertion return super().default(obj) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 4fa68ebec0..d3a38931a4 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -675,7 +675,7 @@ def generate_snippets( if set_probe: probe = recording.get_probe() - snippets = snippets.set_probe(probe) + snippets.set_probe(probe, in_place=True) return snippets, sorting diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 74b3ccb56e..99045501fc 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -8,7 +8,8 @@ import numpy as np -from .core_tools import add_suffix, make_shared_array +from probeinterface import ProbeGroup + from .job_tools import ( ensure_chunk_size, divide_segment_into_chunks, @@ -683,44 +684,60 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y"), return order_f, order_r -def check_probe_do_not_overlap(probes): +def _set_group_property_based_on_probegroup( + recording, probegroup: ProbeGroup, group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] +): """ - When several probes this check that 
that they do not overlap in space - and so channel positions can be safely concatenated. - - Raises - ------ - Exception : - If probes are overlapping + Set the group property for a recording based on a ProbeGroup. + Use "auto" (default) to automatically determine the grouping based on the available + information in the ProbeGroup (default: probe + shank + side if available). - Returns - ------- - None : None - If the check is successful + Parameters + ---------- + recording : BaseRecording + The recording object + probegroup : ProbeGroup + The ProbeGroup object + group_mode : {"auto", "by_probe", "by_shank", "by_side"} + The mode for grouping channels """ - for i in range(len(probes)): - probe_i = probes[i] - # check that all positions in probe_j are outside probe_i boundaries - x_bounds_i = [ - np.min(probe_i.contact_positions[:, 0]), - np.max(probe_i.contact_positions[:, 0]), - ] - y_bounds_i = [ - np.min(probe_i.contact_positions[:, 1]), - np.max(probe_i.contact_positions[:, 1]), - ] - - for j in range(i + 1, len(probes)): - probe_j = probes[j] - if np.any( - np.array( - [ - x_bounds_i[0] <= cp[0] <= x_bounds_i[1] and y_bounds_i[0] <= cp[1] <= y_bounds_i[1] - for cp in probe_j.contact_positions - ] - ) - ): - raise Exception("Probes are overlapping! 
Retrieve locations of single probes separately") + if not isinstance(probegroup, ProbeGroup): + raise ValueError("`probegroup` must be a ProbeGroup instance.") + assert group_mode in ( + "auto", + "by_probe", + "by_shank", + "by_side", + ), "'group_mode' can be 'auto' 'by_probe' 'by_shank' or 'by_side'" + + probe_array = probegroup.to_numpy(complete=True) + has_shank_id = "shank_ids" in probe_array.dtype.fields + has_contact_side = "contact_sides" in probe_array.dtype.fields + if group_mode == "auto": + group_keys = ["probe_index"] + if has_shank_id: + group_keys += ["shank_ids"] + if has_contact_side: + group_keys += ["contact_sides"] + elif group_mode == "by_probe": + group_keys = ["probe_index"] + elif group_mode == "by_shank": + assert has_shank_id, "shank_ids is None in probe, you cannot group by shank" + group_keys = ["probe_index", "shank_ids"] + elif group_mode == "by_side": + assert has_contact_side, "contact_sides is None in probe, you cannot group by side" + if has_shank_id: + group_keys = ["probe_index", "shank_ids", "contact_sides"] + else: + group_keys = ["probe_index", "contact_sides"] + groups = np.zeros(probe_array.size, dtype="int64") + unique_keys = np.unique(probe_array[group_keys]) + for group, a in enumerate(unique_keys): + mask = np.ones(probe_array.size, dtype=bool) + for k in group_keys: + mask &= probe_array[k] == a[k] + groups[mask] = group + recording.set_property("group", groups, ids=None) def get_rec_attributes(recording): @@ -738,8 +755,6 @@ def get_rec_attributes(recording): The rec_attributes dictionary """ properties_to_attrs = deepcopy(recording._properties) - if "contact_vector" in properties_to_attrs: - del properties_to_attrs["contact_vector"] rec_attributes = dict( channel_ids=recording.channel_ids, sampling_frequency=recording.get_sampling_frequency(), diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index b5885598fe..5abf58ddfd 100644 --- 
a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -23,7 +23,7 @@ from spikeinterface.core import BaseRecording, BaseSorting, aggregate_channels, aggregate_units from spikeinterface.core.waveform_tools import has_exceeding_spikes -from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match +from .recording_tools import get_rec_attributes, do_recording_attributes_match from .core_tools import ( check_json, retrieve_importing_provenance, @@ -365,7 +365,6 @@ def create( ) # check that multiple probes are non-overlapping all_probes = recording.get_probegroup().probes - check_probe_do_not_overlap(all_probes) if has_exceeding_spikes(sorting=sorting, recording=recording): warnings.warn( diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index bb6db4cb66..5f7786e5ac 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -12,7 +12,13 @@ from probeinterface import Probe, ProbeGroup, generate_linear_probe -from spikeinterface.core import BinaryRecordingExtractor, NumpyRecording, load, get_default_zarr_compressor +from spikeinterface.core import ( + BinaryRecordingExtractor, + NumpyRecording, + load, + get_default_zarr_compressor, + aggregate_channels, +) from spikeinterface.core.base import BaseExtractor from spikeinterface.core.testing import check_recordings_equal @@ -198,7 +204,13 @@ def test_BaseRecording(create_cache_folder): probe.create_auto_shape() rec_p = rec.set_probe(probe, group_mode="auto") + positions2 = rec_p.get_channel_locations() + assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) + rec_p = rec.set_probe(probe, group_mode="by_shank") + positions2 = rec_p.get_channel_locations() + assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) + rec_p = rec.set_probe(probe, group_mode="by_probe") positions2 = 
rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) @@ -206,7 +218,6 @@ def test_BaseRecording(create_cache_folder): probe2 = rec_p.get_probe() positions3 = probe2.contact_positions assert np.array_equal(positions2, positions3) - assert np.array_equal(probe2.device_channel_indices, [0, 1]) # test save with probe @@ -286,8 +297,9 @@ def test_BaseRecording(create_cache_folder): rec_int16.set_property("offset_to_uV", [0.0] * 5) # Test deprecated return_scaled parameter - traces_float32_old = rec_int16.get_traces(return_scaled=True) # Keep this for testing the deprecation warning - assert traces_float32_old.dtype == "float32" + with pytest.warns(DeprecationWarning, match="`return_scaled` is deprecated"): + traces_float32_old = rec_int16.get_traces(return_scaled=True) # Keep this for testing the deprecation warning + assert traces_float32_old.dtype == "float32" # Test new return_in_uV parameter traces_float32_new = rec_int16.get_traces(return_in_uV=True) @@ -344,7 +356,7 @@ def test_BaseRecording(create_cache_folder): # test 3d probe rec_3d = generate_recording(ndim=3, num_channels=30) - locations_3d = rec_3d.get_property("location") + locations_3d = rec_3d.get_probe().contact_positions locations_xy = rec_3d.get_channel_locations(axes="xy") assert np.allclose(locations_xy, locations_3d[:, [0, 1]]) @@ -413,40 +425,54 @@ def test_json_pickle_equivalence(create_cache_folder): for key, value in data_json.items(): # skip probe info, since pickle keeps some additional information - if key not in ["properties"]: - if isinstance(value, dict): + if key not in ["properties", "probegroup"]: + if isinstance(value, dict) and isinstance(data_pickle[key], dict): for sub_key, sub_value in value.items(): assert np.all(sub_value == data_pickle[key][sub_key]) else: assert np.all(value == data_pickle[key]) -def test_interleaved_probegroups(): - recording = generate_recording(durations=[1.0], num_channels=16) +def 
test_probes_info_annotation_backward_compat(): + """ + Regression test: SI versions before #4300 stored per-probe metadata in a + 'probes_info' annotation on the recording rather than inside the ProbeGroup. + set_probegroup() must migrate those annotations onto the probe objects and + remove the stale 'probes_info' entry from the recording annotations. + """ + from probeinterface import generate_linear_probe, ProbeGroup + + # Simulate probegroup as read from a legacy probe.json: probes exist but + # have no name/manufacturer annotations (old probeinterface did not write them). + probe_A = generate_linear_probe(num_elec=8, ypitch=20.0) + probe_A.move([0.0, 0.0]) + probe_A.set_device_channel_indices(np.arange(8)) + + probe_B = generate_linear_probe(num_elec=8, ypitch=20.0) + probe_B.move([500.0, 0.0]) + probe_B.set_device_channel_indices(np.arange(8, 16)) - probe1 = generate_linear_probe(num_elec=8, ypitch=20.0) - probe2_overlap = probe1.copy() + pg = ProbeGroup() + pg.add_probe(probe_A) + pg.add_probe(probe_B) - probegroup_overlap = ProbeGroup() - probegroup_overlap.add_probe(probe1) - probegroup_overlap.add_probe(probe2_overlap) - probegroup_overlap.set_global_device_channel_indices(np.arange(16)) + rec = NumpyRecording([np.zeros((100, 16), dtype="int16")], sampling_frequency=30000.0) - # setting overlapping probes should raise an error - with pytest.raises(Exception): - recording.set_probegroup(probegroup_overlap) + # Inject the old-style annotation that SI used to write alongside the probegroup.
+ rec._annotations["probes_info"] = [ + {"name": "probe_A", "manufacturer": "vendor_X"}, + {"name": "probe_B", "manufacturer": "vendor_Y"}, + ] - probe2 = probe1.copy() - probe2.move([100.0, 100.0]) - probegroup = ProbeGroup() - probegroup.add_probe(probe1) - probegroup.add_probe(probe2) - probegroup.set_global_device_channel_indices(np.random.permutation(16)) + rec = rec.set_probegroup(pg) - recording.set_probegroup(probegroup) - probegroup_set = recording.get_probegroup() - # check that the probe group is correctly set, by sorting the device channel indices - assert np.array_equal(probegroup_set.get_global_device_channel_indices()["device_channel_indices"], np.arange(16)) + probes = rec.get_probes() + assert len(probes) == 2 + probe_names = {p.annotations.get("name") for p in probes} + assert probe_names == {"probe_A", "probe_B"}, "Probe names must be migrated from probes_info" + manufacturers = {p.annotations.get("manufacturer") for p in probes} + assert manufacturers == {"vendor_X", "vendor_Y"}, "Manufacturers must be migrated from probes_info" + assert "probes_info" not in rec._annotations, "probes_info annotation must be consumed after migration" def test_rename_channels(): diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 751a03460c..e4132870d6 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -143,7 +143,6 @@ def test_BaseSnippets(create_cache_folder): probe.create_auto_shape() snippets_p = snippets.set_probe(probe, group_mode="auto") - snippets_p = snippets.set_probe(probe, group_mode="by_probe") positions2 = snippets_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index d5ba74cfd9..e437eeec9b 100644 --- 
a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -1,10 +1,23 @@ import numpy as np +from probeinterface import generate_linear_probe from spikeinterface.core import aggregate_channels from spikeinterface.core import generate_recording from spikeinterface.core.testing import check_recordings_equal +def _make_rec_with_named_probe(name, manufacturer, x_shift): + """Helper: single-probe recording with annotated name and manufacturer.""" + probe = generate_linear_probe(num_elec=8, ypitch=20.0) + probe.move([x_shift, 0.0]) + probe.annotate(name=name, manufacturer=manufacturer) + probe.set_device_channel_indices(np.arange(8)) + probe.create_auto_shape() + rec = generate_recording(num_channels=8, durations=[1.0], set_probe=False) + rec.set_probe(probe, in_place=True) + return rec + + def test_channelsaggregationrecording(): num_channels = 3 @@ -262,5 +275,51 @@ def test_channel_aggregation_with_string_dtypes_of_different_size(): assert aggregated_recording.channel_ids.dtype == np.dtype(" +with spikeinterface==0.104.* installed. + +The GH Action workflow probe_backward_compat.yml does this automatically. +Set SI_PROBE_COMPAT_FIXTURES_DIR to point at the fixture directory if running locally. +""" + +import os +import numpy as np +import pytest +from pathlib import Path + +from spikeinterface.core import load + +FIXTURES_DIR = Path(os.environ.get("SI_PROBE_COMPAT_FIXTURES_DIR", "probe_compat_fixtures")) + +pytestmark = pytest.mark.skipif( + not FIXTURES_DIR.exists(), + reason=( + f"Probe compatibility fixtures not found at '{FIXTURES_DIR}'. " + "Run .github/scripts/create_probe_compat_fixtures.py with spikeinterface==0.104.* first, " + "or set SI_PROBE_COMPAT_FIXTURES_DIR to the fixture directory." 
+ ), +) + + +# --------------------------------------------------------------------------- +# Shared assertion helpers +# --------------------------------------------------------------------------- + + +def _check_single_probe(rec): + assert rec.has_probe(), "Recording must have a probe after loading" + assert rec.get_num_channels() == 8 + probes = rec.get_probes() + assert len(probes) == 1 + probe = probes[0] + assert probe.annotations.get("name") == "test_probe" + assert probe.annotations.get("manufacturer") == "test_vendor" + assert list(probe.contact_ids) == [f"e{i}" for i in range(8)] + # After loading, device_channel_indices must be sorted 0..N-1 + assert np.array_equal(probe.device_channel_indices, np.arange(8)) + + +def _check_two_probes(rec): + assert rec.has_probe() + assert rec.get_num_channels() == 16 + probes = rec.get_probes() + assert len(probes) == 2, "Both probes must survive after loading" + probe_names = {p.annotations.get("name") for p in probes} + assert probe_names == {"probe_A", "probe_B"}, "Per-probe names must be preserved" + manufacturers = {p.annotations.get("manufacturer") for p in probes} + assert manufacturers == {"vendor_X", "vendor_Y"}, "Per-probe manufacturers must be preserved" + all_contact_ids = set() + for p in probes: + all_contact_ids.update(p.contact_ids.tolist()) + assert all_contact_ids == {f"a{i}" for i in range(8)} | {f"b{i}" for i in range(8)} + groups = rec.get_property("group") + assert len(np.unique(groups)) == 2, "Each probe must have its own group" + + +def _check_shuffled_probe(rec): + assert rec.has_probe() + assert rec.get_num_channels() == 8 + probe = rec.get_probes()[0] + assert probe.annotations.get("name") == "shuffled_probe" + assert probe.annotations.get("manufacturer") == "shuffle_vendor" + # After the old set_probe sorted contacts by device_channel_indices and + # normalised them, the stored probegroup has dci = 0..7. 
+ assert np.array_equal(probe.device_channel_indices, np.arange(8)) + traces = rec.get_traces(segment_index=0) + assert traces.shape == (1000, 8) + + +# --------------------------------------------------------------------------- +# Binary folder fixtures +# --------------------------------------------------------------------------- + + +def test_single_probe_binary_compat(): + _check_single_probe(load(FIXTURES_DIR / "single_probe_binary")) + + +def test_two_probe_binary_compat(): + _check_two_probes(load(FIXTURES_DIR / "two_probe_binary")) + + +def test_shuffled_probe_binary_compat(): + _check_shuffled_probe(load(FIXTURES_DIR / "shuffled_probe_binary")) + + +# --------------------------------------------------------------------------- +# Zarr fixtures +# --------------------------------------------------------------------------- + + +def test_single_probe_zarr_compat(): + _check_single_probe(load(FIXTURES_DIR / "single_probe.zarr")) + + +def test_two_probe_zarr_compat(): + _check_two_probes(load(FIXTURES_DIR / "two_probe.zarr")) + + +def test_shuffled_probe_zarr_compat(): + _check_shuffled_probe(load(FIXTURES_DIR / "shuffled_probe.zarr")) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index bbc797c693..fd427ce942 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -177,7 +177,7 @@ def __init__( total_nbytes_stored += nbytes_stored_segment # load probe - probe_dict = self._root.attrs.get("probe", None) + probe_dict = self._root.attrs.get("probegroup", self._root.attrs.get("probe", None)) if probe_dict is not None: probegroup = ProbeGroup.from_dict(probe_dict) self.set_probegroup(probegroup, in_place=True) @@ -186,6 +186,9 @@ def __init__( if "properties" in self._root: prop_group = self._root["properties"] for key in prop_group.keys(): + # Skip contact_vector property since it is not used anymore to represent probegroup + if key == "contact_vector": +
continue values = self._root["properties"][key] self.set_property(key, values) @@ -548,9 +551,9 @@ def add_recording_to_zarr_group( ) # save probe - if recording.get_property("contact_vector") is not None: + if recording.has_probe(): probegroup = recording.get_probegroup() - zarr_group.attrs["probe"] = check_json(probegroup.to_dict(array_as_list=True)) + zarr_group.attrs["probegroup"] = check_json(probegroup.to_dict(array_as_list=True)) # save time vector if any t_starts = np.zeros(recording.get_num_segments(), dtype="float64") * np.nan diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 8d1fac0c72..c85e82b574 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -70,9 +70,11 @@ def __init__( if electrode_width is not None: probe_kwargs["electrode_width"] = electrode_width probe = probeinterface.read_3brain(file_path, **probe_kwargs) + rows = probe.contact_annotations["row"] + cols = probe.contact_annotations["col"] self.set_probe(probe, in_place=True) - self.set_property("row", self.get_property("contact_vector")["row"]) - self.set_property("col", self.get_property("contact_vector")["col"]) + self.set_property("row", rows) + self.set_property("col", cols) self._kwargs.update( { diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 932ecee106..5eaa49e6b8 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -74,8 +74,9 @@ def __init__( # rec_name auto set by neo rec_name = self.neo_reader.rec_name probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) + electrodes = probe.contact_annotations["electrode"] self.set_probe(probe, in_place=True) - self.set_property("electrode", self.get_property("contact_vector")["electrode"]) + 
self.set_property("electrode", electrodes) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name)) @classmethod diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index bd0d2184d4..fa9d59ec47 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -31,11 +31,13 @@ def setUpClass(cls): cache_dir=None, ) except: + print("Skipping test due to server being down.") pytest.skip("Skipping test due to server being down.") try: cls.recording = read_ibl_recording(eid=cls.eid, stream_name="probe00.ap", one=cls.one) except requests.exceptions.HTTPError as e: if e.response.status_code == 503: + print("Skipping test due to server being down (HTTP 503).") pytest.skip("Skipping test due to server being down (HTTP 503).") else: raise @@ -84,8 +86,6 @@ def test_property_keys(self): expected_property_keys = [ "gain_to_uV", "offset_to_uV", - "contact_vector", - "location", "group", "shank", "shank_row", @@ -97,6 +97,9 @@ def test_property_keys(self): ] self.assertCountEqual(first=self.recording.get_property_keys(), second=expected_property_keys) + def test_has_probe(self): + assert self.recording.has_probe() is True + def test_trace_shape(self): expected_shape = (21, 384) self.assertTupleEqual(tuple1=self.small_scaled_trace.shape, tuple2=expected_shape) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 7975097629..dc83ab9596 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -117,7 +117,7 @@ def compute_monopolar_triangulation( # if enforce_decrease: # enforce_decrease_shells_data( - # wf_data, best_channels[unit_id], enforce_decrease_radial_parents, in_place=True + # wf_data, best_channels[unit_id], enforce_decrease_radial_parents 
# ) unit_location[i] = solve_monopolar_triangulation(wf_data, local_contact_locations, max_distance_um, optimizer) diff --git a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index 8d1c4475cd..23e0a1d5ae 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -63,7 +63,7 @@ def __init__( # my geometry channel_locations = np.zeros( (n_pos_unique, parent_channel_locations.shape[1]), - dtype=parent_channel_locations.dtype, + dtype=np.float32, ) # average other dimensions in the geometry other_dim = np.arange(parent_channel_locations.shape[1]) != dim diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index 61996e9036..1294b57a91 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -2,6 +2,8 @@ import numpy as np import os +import probeinterface as pi + import spikeinterface as si import spikeinterface.preprocessing as spre import spikeinterface.extractors as se @@ -125,9 +127,12 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan # distribute default probe locations across 4 shanks if set rng = np.random.default_rng(seed=None) - x = rng.choice(shanks, num_channels) - for idx, __ in enumerate(recording._properties["contact_vector"]): - recording._properties["contact_vector"][idx]["x"] = x[idx] + x_new = rng.choice(shanks, num_channels) + probe = recording.get_probe() + new_positions = probe.contact_positions.copy() + new_positions[:, 0] = x_new # column 0 is x + recording._probegroup.probes[0]._contact_positions = new_positions + recording.set_probe(probe, in_place=True) # generate random bad channel locations bad_channel_indexes = rng.choice(num_channels, 
rng.integers(1, int(num_channels / 5)), replace=False) @@ -161,18 +166,21 @@ def test_output_values(): the non-interpolated channels is also an implicit test these were not accidently changed. """ - recording = generate_recording(num_channels=5, durations=[1]) + recording = generate_recording(num_channels=5, durations=[1], set_probe=False) bad_channel_indexes = np.array([0]) bad_channel_ids = recording.channel_ids[bad_channel_indexes] - new_probe_locs = [ - [5, 7, 3, 5, 5], # 5 channels, a in the center ('bad channel', zero index) - [5, 5, 5, 7, 3], - ] # all others equal distance away. - # Overwrite the probe information with the new locations - for idx, (x, y) in enumerate(zip(*new_probe_locs)): - recording._properties["contact_vector"][idx]["x"] = x - recording._properties["contact_vector"][idx]["y"] = y + probe_locs = np.array( + [ + [5, 7, 3, 5, 5], # 5 channels, a in the center ('bad channel', zero index) + [5, 5, 5, 7, 3], + ] # all others equal distance away. + ).T + # Set the probe information with the new locations + probe = pi.Probe(ndim=2) + probe.set_contacts(positions=probe_locs) + probe.set_device_channel_indices(np.arange(len(probe_locs))) + recording.set_probe(probe, in_place=True) # Run interpolation in SI and check the interpolated channel # 0 is a linear combination of other channels @@ -186,8 +194,7 @@ def test_output_values(): # Shift the last channel position so that it is 4 units, rather than 2 # away. Setting sigma_um = p = 1 allows easy calculation of the expected # weights. 
- recording._properties["contact_vector"][-1]["x"] = 5 - recording._properties["contact_vector"][-1]["y"] = 9 + recording._probegroup.probes[0]._contact_positions[-1] = [5, 9] expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 45d4809cd8..4854d94dba 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -157,7 +157,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: "The new mapping cannot exceed total number of channels " "in the zero-chanenl-padded recording." ) else: - if "locations" in recording.get_property_keys() or "contact_vector" in recording.get_property_keys(): + if recording.has_probe(): self.channel_mapping = np.argsort(recording.get_channel_locations()[:, 1]) else: self.channel_mapping = np.arange(recording.get_num_channels()) diff --git a/src/spikeinterface/sorters/external/hdsort.py b/src/spikeinterface/sorters/external/hdsort.py index 3daaf85b7a..07d59332d2 100644 --- a/src/spikeinterface/sorters/external/hdsort.py +++ b/src/spikeinterface/sorters/external/hdsort.py @@ -276,8 +276,8 @@ def write_hdsort_input_format(cls, recording, save_path, chunk_memory="500M"): [("electrode", np.int32), ("x", np.float64), ("y", np.float64), ("channel", np.int32)] ) - locations = recording.get_property("location") - assert locations is not None, "'location' property is needed to run HDSort" + assert recording.has_probe(), "The recording must have a probe to run HDSort" + locations = recording.get_channel_locations() with h5py.File(save_path, "w") as f: f.create_group("ephys") diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index f616888166..80d79171ce 100644 --- 
a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -405,11 +405,12 @@ def __init__( if border_mode == "remove_channels": # change the wiring of the probe - # TODO this is also done in ChannelSliceRecording, this should be done in a common place - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if recording.has_probe(): + probegroup = recording.get_probegroup() + channel_indices = recording.ids_to_indices(channel_ids) + probegroup_sliced = probegroup.get_slice(channel_indices) + probegroup_sliced.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) + self.set_probegroup(probegroup_sliced, in_place=True) # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below diff --git a/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py b/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py index 8840a5a00d..942074ef31 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py +++ b/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py @@ -97,9 +97,7 @@ def compute(self, traces, peaks, waveforms): wf_data = np.abs(wf[self.nbefore]) if self.enforce_decrease_radial_parents is not None: - enforce_decrease_shells_data( - wf_data, peak["channel_index"], self.enforce_decrease_radial_parents, in_place=True - ) + enforce_decrease_shells_data(wf_data, peak["channel_index"], self.enforce_decrease_radial_parents) peak_locations[i] = solve_monopolar_triangulation( wf_data, local_contact_locations, self.max_distance_um, self.optimizer