Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions spikeinterface_gui/basescatterview.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def _qt_make_layout(self):
tb = self.qt_widget.view_toolbar
self.combo_seg = QT.QComboBox()
tb.addWidget(self.combo_seg)
self.combo_seg.addItems([ f'Segment {segment_index}' for segment_index in range(self.controller.num_segments) ])
self.combo_seg.addItems([f'Segment {segment_index}' for segment_index in range(self.controller.num_segments)])
self.combo_seg.currentIndexChanged.connect(self._qt_change_segment)
add_stretch_to_qtoolbar(tb)
self.lasso_but = QT.QPushButton("select", checkable = True)
Expand Down Expand Up @@ -310,7 +310,7 @@ def _qt_refresh(self, set_scatter_range=False):
# make a copy of the color
color = QT.QColor(self.get_unit_color(unit_id))
color.setAlpha(int(self.settings['alpha']*255))
self.scatter.addPoints(x=spike_times, y=spike_data, pen=pg.mkPen(None), brush=color)
self.scatter.addPoints(x=spike_times, y=spike_data, pen=pg.mkPen(None), brush=color)

color = self.get_unit_color(unit_id)
curve = pg.PlotCurveItem(hist_count, hist_bins[:-1], fillLevel=None, fillOutline=True, brush=color, pen=color)
Expand Down
47 changes: 36 additions & 11 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@

from spikeinterface.widgets.utils import get_unit_colors
from spikeinterface import compute_sparsity
from spikeinterface.core import get_template_extremum_channel
from spikeinterface.core import get_template_extremum_channel, BaseEvent
from spikeinterface.core.sorting_tools import spike_vector_to_indices
from spikeinterface.curation import validate_curation_dict
from spikeinterface.curation.curation_model import CurationModel
from spikeinterface.widgets.utils import make_units_table_from_analyzer

from .curation_tools import add_merge, default_label_definitions, empty_curation_data
from .event_tools import parse_events

spike_dtype =[('sample_index', 'int64'), ('unit_index', 'int64'),
('channel_index', 'int64'), ('segment_index', 'int64'),
Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(
extra_unit_properties=None,
skip_extensions=None,
disable_save_settings_button=False,
events=None,
external_data=None,
curation_callback=None,
curation_callback_kwargs=None,
Expand Down Expand Up @@ -250,6 +252,16 @@ def __init__(
self.pc_ext = pc_ext

self._potential_merges = None
# some direct attribute
self.num_segments = self.analyzer.get_num_segments()
self.sampling_frequency = self.analyzer.sampling_frequency

# parse events
self.events = None
if events is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets move this events handling in a separate function no ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

self.events = parse_events(events, self, verbose=verbose)
if len(self.events) == 0:
self.events = None

t1 = time.perf_counter()
if verbose:
Expand All @@ -260,10 +272,6 @@ def __init__(
self._extremum_channel = get_template_extremum_channel(self.analyzer,
mode="extremum", peak_sign='both', outputs='index')

# some direct attribute
self.num_segments = self.analyzer.get_num_segments()
self.sampling_frequency = self.analyzer.sampling_frequency

# spikeinterface handle colors in matplotlib style tuple values in range (0,1)
self.refresh_colors()

Expand Down Expand Up @@ -462,9 +470,12 @@ def update_time_info(self):
else:
self.time_info['time_by_seg'] = time_by_seg

def get_t_start_t_stop(self):
segment_index = self.time_info["segment_index"]
if self.main_settings["use_times"] and self.has_extension("recording"):
def get_t_start_t_stop(self, use_times=None, segment_index=None):
if segment_index is None:
segment_index = self.time_info["segment_index"]
if use_times is None:
use_times = self.main_settings["use_times"]
if use_times and self.has_extension("recording"):
t_start = self.analyzer.recording.get_start_time(segment_index=segment_index)
t_stop = self.analyzer.recording.get_end_time(segment_index=segment_index)
return t_start, t_stop
Expand Down Expand Up @@ -502,14 +513,26 @@ def sample_index_to_time(self, sample_index):
else:
return sample_index / self.sampling_frequency

def time_to_sample_index(self, time):
segment_index = self.time_info["segment_index"]
if self.main_settings["use_times"] and self.has_extension("recording"):
def time_to_sample_index(self, time, segment_index=None, use_times=None):
if segment_index is None:
segment_index = self.time_info["segment_index"]
if use_times is None:
use_times = self.main_settings["use_times"]
if use_times and self.has_extension("recording"):
time = self.analyzer.recording.time_to_sample_index(time, segment_index=segment_index)
return time
else:
return int(time * self.sampling_frequency)

def get_events(self, event_name, segment_index=None):
if self.events is None:
return None
if event_name not in self.events:
return None
if segment_index is None:
segment_index = self.time_info['segment_index']
return self.events[event_name][segment_index]

def get_information_txt(self):
nseg = self.analyzer.get_num_segments()
nchan = self.analyzer.get_num_channels()
Expand Down Expand Up @@ -762,6 +785,8 @@ def set_channel_visibility(self, visible_channel_inds):
def has_extension(self, extension_name):
if extension_name == 'recording':
return self.analyzer.has_recording() or self.analyzer.has_temporary_recording()
elif extension_name == 'events':
return self.events is not None
else:
# extension needs to be loaded
if extension_name in self.skip_extensions:
Expand Down
89 changes: 89 additions & 0 deletions spikeinterface_gui/event_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import numpy as np

from spikeinterface.core import BaseEvent

def parse_events(events, controller, verbose=False):
"""Parse events input into a standard format.

Parameters
----------
events : dict | BaseEvent
BaseEvent object or a dictionary where keys are event names and values are dictionaries with at least a
'samples' key or 'times' key.
controller : Controller
Controller object managing the event parsing.
verbose : bool, default: False
Whether to print verbose messages.

Returns
-------
parsed_events : dict
Parsed events dictionary. The keys are event names, and the values are lists of numpy arrays of event sample indices.
Each element corresponds to a segment in the recording.
"""
parsed_events = {}
if isinstance(events, dict):
for key, val in events.items():
if not isinstance(val, dict):
if verbose:
print(f'\tSkipping event {key}: not a dict')
continue
if 'samples' not in val and 'times' not in val:
if verbose:
print(f'\tSkipping event {key}: missing samples or times')
continue
if 'times' in val:
samples_data = val['times']
convert_to_samples = True
else:
samples_data = val['samples']
convert_to_samples = False
if controller.num_segments > 1:
if not len(samples_data) == controller.num_segments:
if verbose:
print(f'\tSkipping event {key}: inconsistent number of samples')
continue
else:
# here we make sure samples is a list of list
if np.array(samples_data).ndim == 1:
samples_data = [samples_data]
if convert_to_samples:
# filter events based on recording start/stop times
filtered_samples_data = []
parsed_events[key] = []
for segment_index in range(controller.num_segments):
t_start, t_end = controller.get_t_start_t_stop(use_times=True, segment_index=segment_index)
s = samples_data[segment_index]
filtered_samples_data = s[(s >= t_start) & (s <= t_end)]
parsed_events[key].append(
np.sort(
controller.time_to_sample_index(
filtered_samples_data,
segment_index=segment_index,
use_times=True
)
)
)

else:
parsed_events[key] = [np.sort(s) for s in samples_data]
elif isinstance(events, BaseEvent):
event_names = events.channel_ids
parsed_events = {
event_name: [] for event_name in event_names
}
for event_name in event_names:
for segment_index in range(controller.num_segments):
event_times_segment = events.get_event_times(
channel_id=event_name,
segment_index=segment_index
)
event_samples_segment = controller.time_to_sample_index(
event_times_segment, segment_index=segment_index, use_times=True
)
parsed_events[event_name].append(np.sort(event_samples_segment))
else:
if verbose:
print('\tSkipping events: not a dict or BaseEvent')

return parsed_events
Loading
Loading