diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 6f3f60c..4d1b22e 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -8,7 +8,7 @@ from spikeinterface.widgets.utils import get_unit_colors from spikeinterface import compute_sparsity -from spikeinterface.core import get_template_extremum_channel, BaseEvent +from spikeinterface.core import BaseEvent from spikeinterface.core.sorting_tools import spike_vector_to_indices from spikeinterface.curation import validate_curation_dict from spikeinterface.curation.curation_model import Curation @@ -275,9 +275,7 @@ def __init__( t0 = time.perf_counter() - self._extremum_channel = get_template_extremum_channel(self.analyzer, - mode="extremum", peak_sign='both', outputs='index') - + self._main_channels = self.analyzer.get_main_channels(outputs="index", with_dict=True) # spikeinterface handle colors in matplotlib style tuple values in range (0,1) self.refresh_colors() @@ -297,7 +295,7 @@ def __init__( self.num_spikes = self.analyzer.sorting.count_num_spikes_per_unit(outputs="dict") # print("self.num_spikes", self.num_spikes) - spike_vector = self.analyzer.sorting.to_spike_vector(concatenated=True, extremum_channel_inds=self._extremum_channel) + spike_vector = self.analyzer.sorting.to_spike_vector(concatenated=True, extremum_channel_inds=self._main_channels) # spike_vector = self.analyzer.sorting.to_spike_vector(concatenated=True) self.random_spikes_indices = self.analyzer.get_extension("random_spikes").get_data() @@ -586,8 +584,8 @@ def get_spike_colors(self, unit_indices): return colors - def get_extremum_channel(self, unit_id): - chan_ind = self._extremum_channel[unit_id] + def get_main_channel(self, unit_id): + chan_ind = self._main_channels[unit_id] return chan_ind # unit visibility zone @@ -762,7 +760,7 @@ def get_template_upsampling_factor(self): def get_upsampled_templates(self, unit_id): template_metrics_ext = self.analyzer.get_extension("template_metrics") unit_index = list(self.unit_ids).index(unit_id) - chan_ind = self.get_extremum_channel(unit_id) + chan_ind = self.get_main_channel(unit_id) template = self.templates_average[unit_index, :, chan_ind] if template_metrics_ext is None or "peaks_data" not in template_metrics_ext.data: return template, None, None diff --git a/spikeinterface_gui/unitlistview.py b/spikeinterface_gui/unitlistview.py index c501dc8..66f8128 100644 --- a/spikeinterface_gui/unitlistview.py +++ b/spikeinterface_gui/unitlistview.py @@ -284,7 +284,7 @@ def _qt_full_table_refresh(self): item.unit_id = unit_id self.items_visibility[unit_id] = item - channel_index = self.controller.get_extremum_channel(unit_id) + channel_index = self.controller.get_main_channel(unit_id) channel_id = self.controller.channel_ids[channel_index] item = CustomItem(f'{channel_id}') item.setFlags(QT.Qt.ItemIsEnabled|QT.Qt.ItemIsSelectable) @@ -505,7 +505,7 @@ def _panel_make_layout(self): {"id": str(unit_id), "color": mcolors.to_hex(self.controller.get_unit_color(unit_id))} ) data["channel_id"].append( - self.controller.channel_ids[self.controller.get_extremum_channel(unit_id)] + self.controller.channel_ids[self.controller.get_main_channel(unit_id)] ) for col in self.controller.displayed_unit_properties: data[col] = self.controller.units_table[col]