Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from spikeinterface.widgets.utils import get_unit_colors
from spikeinterface import compute_sparsity
from spikeinterface.core import get_template_extremum_channel, BaseEvent
from spikeinterface.core import BaseEvent
from spikeinterface.core.sorting_tools import spike_vector_to_indices
from spikeinterface.curation import validate_curation_dict
from spikeinterface.curation.curation_model import Curation
Expand Down Expand Up @@ -275,9 +275,7 @@ def __init__(

t0 = time.perf_counter()

self._extremum_channel = get_template_extremum_channel(self.analyzer,
mode="extremum", peak_sign='both', outputs='index')

self._main_channels = self.analyzer.get_main_channels(outputs="index", with_dict=True)
# spikeinterface handle colors in matplotlib style tuple values in range (0,1)
self.refresh_colors()

Expand All @@ -297,7 +295,7 @@ def __init__(
self.num_spikes = self.analyzer.sorting.count_num_spikes_per_unit(outputs="dict")
# print("self.num_spikes", self.num_spikes)

spike_vector = self.analyzer.sorting.to_spike_vector(concatenated=True, extremum_channel_inds=self._extremum_channel)
spike_vector = self.analyzer.sorting.to_spike_vector(concatenated=True, extremum_channel_inds=self._main_channels)
# spike_vector = self.analyzer.sorting.to_spike_vector(concatenated=True)

self.random_spikes_indices = self.analyzer.get_extension("random_spikes").get_data()
Expand Down Expand Up @@ -586,8 +584,8 @@ def get_spike_colors(self, unit_indices):
return colors


def get_extremum_channel(self, unit_id):
chan_ind = self._extremum_channel[unit_id]
def get_main_channel(self, unit_id):
chan_ind = self._main_channels[unit_id]
return chan_ind

# unit visibility zone
Expand Down Expand Up @@ -762,7 +760,7 @@ def get_template_upsampling_factor(self):
def get_upsampled_templates(self, unit_id):
template_metrics_ext = self.analyzer.get_extension("template_metrics")
unit_index = list(self.unit_ids).index(unit_id)
chan_ind = self.get_extremum_channel(unit_id)
chan_ind = self.get_main_channel(unit_id)
template = self.templates_average[unit_index, :, chan_ind]
if template_metrics_ext is None or "peaks_data" not in template_metrics_ext.data:
return template, None, None
Expand Down
4 changes: 2 additions & 2 deletions spikeinterface_gui/unitlistview.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def _qt_full_table_refresh(self):
item.unit_id = unit_id
self.items_visibility[unit_id] = item

channel_index = self.controller.get_extremum_channel(unit_id)
channel_index = self.controller.get_main_channel(unit_id)
channel_id = self.controller.channel_ids[channel_index]
item = CustomItem(f'{channel_id}')
item.setFlags(QT.Qt.ItemIsEnabled|QT.Qt.ItemIsSelectable)
Expand Down Expand Up @@ -505,7 +505,7 @@ def _panel_make_layout(self):
{"id": str(unit_id), "color": mcolors.to_hex(self.controller.get_unit_color(unit_id))}
)
data["channel_id"].append(
self.controller.channel_ids[self.controller.get_extremum_channel(unit_id)]
self.controller.channel_ids[self.controller.get_main_channel(unit_id)]
)
for col in self.controller.displayed_unit_properties:
data[col] = self.controller.units_table[col]
Expand Down
Loading