diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index 44ea159..0a301dc 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -208,7 +208,7 @@ def _qt_make_layout(self): tb = self.qt_widget.view_toolbar self.combo_seg = QT.QComboBox() tb.addWidget(self.combo_seg) - self.combo_seg.addItems([ f'Segment {segment_index}' for segment_index in range(self.controller.num_segments) ]) + self.combo_seg.addItems([f'Segment {segment_index}' for segment_index in range(self.controller.num_segments)]) self.combo_seg.currentIndexChanged.connect(self._qt_change_segment) add_stretch_to_qtoolbar(tb) self.lasso_but = QT.QPushButton("select", checkable = True) @@ -310,7 +310,7 @@ def _qt_refresh(self, set_scatter_range=False): # make a copy of the color color = QT.QColor(self.get_unit_color(unit_id)) color.setAlpha(int(self.settings['alpha']*255)) - self.scatter.addPoints(x=spike_times, y=spike_data, pen=pg.mkPen(None), brush=color) + self.scatter.addPoints(x=spike_times, y=spike_data, pen=pg.mkPen(None), brush=color) color = self.get_unit_color(unit_id) curve = pg.PlotCurveItem(hist_count, hist_bins[:-1], fillLevel=None, fillOutline=True, brush=color, pen=color) diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 013f399..0a3e397 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -8,13 +8,14 @@ from spikeinterface.widgets.utils import get_unit_colors from spikeinterface import compute_sparsity -from spikeinterface.core import get_template_extremum_channel +from spikeinterface.core import get_template_extremum_channel, BaseEvent from spikeinterface.core.sorting_tools import spike_vector_to_indices from spikeinterface.curation import validate_curation_dict from spikeinterface.curation.curation_model import CurationModel from spikeinterface.widgets.utils import make_units_table_from_analyzer from .curation_tools import add_merge, default_label_definitions, empty_curation_data +from .event_tools import parse_events spike_dtype =[('sample_index', 'int64'), ('unit_index', 'int64'), ('channel_index', 'int64'), ('segment_index', 'int64'), @@ -46,6 +47,7 @@ def __init__( extra_unit_properties=None, skip_extensions=None, disable_save_settings_button=False, + events=None, external_data=None, curation_callback=None, curation_callback_kwargs=None, @@ -250,6 +252,16 @@ def __init__( self.pc_ext = pc_ext self._potential_merges = None + # some direct attribute + self.num_segments = self.analyzer.get_num_segments() + self.sampling_frequency = self.analyzer.sampling_frequency + + # parse events + self.events = None + if events is not None: + self.events = parse_events(events, self, verbose=verbose) + if len(self.events) == 0: + self.events = None t1 = time.perf_counter() if verbose: @@ -260,10 +272,6 @@ def __init__( self._extremum_channel = get_template_extremum_channel(self.analyzer, mode="extremum", peak_sign='both', outputs='index') - # some direct attribute - self.num_segments = self.analyzer.get_num_segments() - self.sampling_frequency = self.analyzer.sampling_frequency - # spikeinterface handle colors in matplotlib style tuple values in range (0,1) self.refresh_colors() @@ -462,9 +470,12 @@ def update_time_info(self): else: self.time_info['time_by_seg'] = time_by_seg - def get_t_start_t_stop(self): - segment_index = self.time_info["segment_index"] - if self.main_settings["use_times"] and self.has_extension("recording"): + def get_t_start_t_stop(self, use_times=None, segment_index=None): + if segment_index is None: + segment_index = self.time_info["segment_index"] + if use_times is None: + use_times = self.main_settings["use_times"] + if use_times and self.has_extension("recording"): t_start = self.analyzer.recording.get_start_time(segment_index=segment_index) t_stop = self.analyzer.recording.get_end_time(segment_index=segment_index) return t_start, t_stop @@ -502,14 +513,26 @@ def sample_index_to_time(self, sample_index): else: return sample_index / self.sampling_frequency - def time_to_sample_index(self, time): - segment_index = self.time_info["segment_index"] - if self.main_settings["use_times"] and self.has_extension("recording"): + def time_to_sample_index(self, time, segment_index=None, use_times=None): + if segment_index is None: + segment_index = self.time_info["segment_index"] + if use_times is None: + use_times = self.main_settings["use_times"] + if use_times and self.has_extension("recording"): time = self.analyzer.recording.time_to_sample_index(time, segment_index=segment_index) return time else: return int(time * self.sampling_frequency) + def get_events(self, event_name, segment_index=None): + if self.events is None: + return None + if event_name not in self.events: + return None + if segment_index is None: + segment_index = self.time_info['segment_index'] + return self.events[event_name][segment_index] + def get_information_txt(self): nseg = self.analyzer.get_num_segments() nchan = self.analyzer.get_num_channels() @@ -762,6 +785,8 @@ def set_channel_visibility(self, visible_channel_inds): def has_extension(self, extension_name): if extension_name == 'recording': return self.analyzer.has_recording() or self.analyzer.has_temporary_recording() + elif extension_name == 'events': + return self.events is not None else: # extension needs to be loaded if extension_name in self.skip_extensions: diff --git a/spikeinterface_gui/event_tools.py b/spikeinterface_gui/event_tools.py new file mode 100644 index 0000000..9a96cf1 --- /dev/null +++ b/spikeinterface_gui/event_tools.py @@ -0,0 +1,89 @@ +import numpy as np + +from spikeinterface.core import BaseEvent + +def parse_events(events, controller, verbose=False): + """Parse events input into a standard format. + + Parameters + ---------- + events : dict | BaseEvent + BaseEvent object or a dictionary where keys are event names and values are dictionaries with at least a + 'samples' key or 'times' key. + controller : Controller + Controller object managing the event parsing. + verbose : bool, default: False + Whether to print verbose messages. + + Returns + ------- + parsed_events : dict + Parsed events dictionary. The keys are event names, and the values are lists of numpy arrays of event sample indices. + Each element corresponds to a segment in the recording. + """ + parsed_events = {} + if isinstance(events, dict): + for key, val in events.items(): + if not isinstance(val, dict): + if verbose: + print(f'\tSkipping event {key}: not a dict') + continue + if 'samples' not in val and 'times' not in val: + if verbose: + print(f'\tSkipping event {key}: missing samples or times') + continue + if 'times' in val: + samples_data = val['times'] + convert_to_samples = True + else: + samples_data = val['samples'] + convert_to_samples = False + if controller.num_segments > 1: + if not len(samples_data) == controller.num_segments: + if verbose: + print(f'\tSkipping event {key}: inconsistent number of samples') + continue + else: + # here we make sure samples is a list of list + if np.array(samples_data).ndim == 1: + samples_data = [samples_data] + if convert_to_samples: + # filter events based on recording start/stop times + filtered_samples_data = [] + parsed_events[key] = [] + for segment_index in range(controller.num_segments): + t_start, t_end = controller.get_t_start_t_stop(use_times=True, segment_index=segment_index) + s = samples_data[segment_index] + filtered_samples_data = s[(s >= t_start) & (s <= t_end)] + parsed_events[key].append( + np.sort( + controller.time_to_sample_index( + filtered_samples_data, + segment_index=segment_index, + use_times=True + ) + ) + ) + + else: + parsed_events[key] = [np.sort(s) for s in samples_data] + elif isinstance(events, BaseEvent): + event_names = events.channel_ids + parsed_events = { + event_name: [] for event_name in event_names + } + for event_name in event_names: + for segment_index in range(controller.num_segments): + event_times_segment = events.get_event_times( + channel_id=event_name, + segment_index=segment_index + ) + event_samples_segment = controller.time_to_sample_index( + event_times_segment, segment_index=segment_index, use_times=True + ) + parsed_events[event_name].append(np.sort(event_samples_segment)) + else: + if verbose: + print('\tSkipping events: not a dict or BaseEvent') + + return parsed_events \ No newline at end of file diff --git a/spikeinterface_gui/eventview.py b/spikeinterface_gui/eventview.py new file mode 100644 index 0000000..b055904 --- /dev/null +++ b/spikeinterface_gui/eventview.py @@ -0,0 +1,310 @@ +import numpy as np +from .view_base import ViewBase + +class EventView(ViewBase): + id = "event" + _supported_backend = ['qt', 'panel'] + _depend_on = ["events"] + _settings = [ + {'name': 'max_trials', 'type': 'int', 'value' : 50 }, + {'name': 'window_start', 'type': 'float', 'value': -0.2}, + {'name': 'window_end', 'type': 'float', 'value': 0.5}, + {'name': 'alpha_psth', 'type': 'float', 'value': 0.5}, + {'name': 'num_bins', 'type': 'int', 'value': 50}, + ] + _need_compute = False + + def __init__(self, controller=None, parent=None, backend="qt"): + self.mode = 'rasters' # or 'psth' + self.selected_unit = None + self.selected_event_key = None + ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) + + + def get_aligned_spikes(self, unit_ids): + window_s = [self.settings['window_start'], self.settings['window_end']] + window_samples = [int(w * self.controller.sampling_frequency) for w in window_s] + # gather all events across segments + all_event_times = [] + num_segments = self.controller.num_segments + num_events = 0 + for segment_index in range(num_segments): + events = self.controller.get_events(self.selected_event_key, segment_index=segment_index) + if events is not None: + all_event_times.append(events) + num_events += len(events) + else: + all_event_times.append(np.array([])) + + if num_events > self.settings['max_trials']: + # sub-sample events based on segment durations + segment_samples = [self.controller.get_num_samples(segment_index) for segment_index in range(num_segments)] + total_samples = sum(segment_samples) + events_per_segment = [] + for segment_index in range(num_segments - 1): + events_per_segment.append( + min(int(segment_samples[segment_index] / total_samples * self.settings['max_trials']), len(all_event_times[segment_index])) + ) + # assign remaining to last segment to ensure total is max_trials + events_per_segment.append(min(self.settings['max_trials'] - sum(events_per_segment), len(all_event_times[-1]))) + event_times = [np.random.choice(et, size=events_per_segment[i], replace=False) for i, et in enumerate(all_event_times)] + + aligned_spikes_dict = {} + for selected_unit in unit_ids: + aligned_spikes = [] + + for segment_index in range(num_segments): + inds = self.controller.get_spike_indices(selected_unit, segment_index=segment_index) + spike_times = self.controller.spikes["sample_index"][inds] + + for et in event_times[segment_index]: + rel_spikes = spike_times - et + rel_spikes = rel_spikes[(rel_spikes >= window_samples[0]) & (rel_spikes <= window_samples[1])] + aligned_spikes.append(rel_spikes / self.controller.sampling_frequency) # convert to seconds + aligned_spikes_dict[selected_unit] = aligned_spikes + return aligned_spikes_dict + + def _qt_make_layout(self): + import pyqtgraph as pg + from .myqt import QT, QtWidgets + + layout = QtWidgets.QVBoxLayout() + # Mode selection + toolbar = QtWidgets.QHBoxLayout() + self.mode_combo = QtWidgets.QComboBox() + self.mode_combo.addItems(['Rasters', 'PSTH']) + self.mode_combo.currentIndexChanged.connect(self._qt_on_mode_changed) + toolbar.addWidget(self.mode_combo) + # Event key selection + event_keys = list(self.controller.events.keys()) + if len(event_keys) > 1: + self.event_combo = QtWidgets.QComboBox() + self.event_combo.addItems(event_keys) + self.event_combo.currentIndexChanged.connect(self._qt_on_event_changed) + toolbar.addWidget(self.event_combo) + self.selected_event_key = event_keys[0] if event_keys else None + layout.addLayout(toolbar) + # Pyqtgraph PlotWidget + self.pg_plot = pg.PlotWidget() + self.scatter = pg.ScatterPlotItem(size=10, pxMode=True) + self.pg_plot.addItem(self.scatter) + + # Create vertical line at x=0 once + self.zero_line = pg.InfiniteLine(pos=0, angle=90, pen=pg.mkPen('gray', width=2, style=QT.Qt.DashLine)) + self.pg_plot.addItem(self.zero_line) + + layout.addWidget(self.pg_plot) + self.layout = layout + + def _qt_on_mode_changed(self, idx): + self.mode = 'rasters' if idx == 0 else 'psth' + self._qt_refresh() + + def _qt_on_event_changed(self, idx): + self.selected_event_key = self.event_combo.currentText() + self._qt_refresh() + + def _qt_refresh(self): + from .myqt import QT + import pyqtgraph as pg + + self.scatter.clear() + # Clear everything including scatter + self.pg_plot.clear() + self.pg_plot.addItem(self.zero_line) + + if self.mode == 'rasters': + # Clear all plot items except scatter + self.pg_plot.addItem(self.scatter) + + # Get visible units from controller + visible_units = self.controller.get_visible_unit_ids() + if not visible_units or self.selected_event_key is None: + return + + aligned_spikes_by_unit = self.get_aligned_spikes(visible_units) + window_s = [self.settings['window_start'], self.settings['window_end']] + + for selected_unit in visible_units: + aligned_spikes = aligned_spikes_by_unit[selected_unit] + color = QT.QColor(self.get_unit_color(selected_unit)) + + if self.mode == 'rasters': + all_x = [] + all_y = [] + for i, trial in enumerate(aligned_spikes): + if len(trial) > 0: + all_x.extend(trial) + y = [i]*len(trial) + all_y.extend(y) + if all_x: + self.scatter.addPoints(x=np.array(all_x), y=np.array(all_y), pen=pg.mkPen(None), brush=color, symbol="|") + else: + from pyqtgraph import BarGraphItem + + all_spikes = np.concatenate(aligned_spikes) if aligned_spikes else np.array([]) + all_y_hists = [] + if len(all_spikes) > 0: + bins = np.linspace(window_s[0], window_s[1], 51) + y, x = np.histogram(all_spikes, bins=bins) + # Use bin centers for plotting + bin_centers = (x[:-1] + x[1:]) / 2 + # Create a bar graph item instead of using stepMode + width = (x[1] - x[0]) * 0.8 # 80% of bin width + color.setAlpha(int(self.settings['alpha_psth']*255)) + bg = BarGraphItem(x=bin_centers, height=y, width=width, brush=color, pen=pg.mkPen(color, width=2)) + self.pg_plot.addItem(bg) + all_y_hists.extend(y) + # Set ranges + if self.mode == 'rasters': + self.pg_plot.setYRange(-0.5, len(aligned_spikes)+0.5, padding=0) + self.pg_plot.setXRange(window_s[0], window_s[1], padding=0) + self.pg_plot.setLabel('left', 'Event #') + self.pg_plot.setLabel('bottom', 'Time (s)') + self.pg_plot.setTitle(f'Rasters aligned to {self.selected_event_key}') + else: + self.pg_plot.setXRange(window_s[0], window_s[1], padding=0.05) + if len(all_y_hists) > 0: + self.pg_plot.setYRange(0, max(all_y_hists)*1.1, padding=0) + self.pg_plot.setLabel('left', 'Spike count') + self.pg_plot.setLabel('bottom', 'Time (s)') + self.pg_plot.setTitle(f'PSTH aligned to {self.selected_event_key}') + + + def _panel_make_layout(self): + import panel as pn + import bokeh.plotting as bpl + from bokeh.models import ColumnDataSource, Span, Range1d + from .utils_panel import _bg_color + + top_items = [] + self.panel_mode_select = pn.widgets.Select(name="Mode", value="Rasters", options=["Rasters", "PSTH"]) + self.panel_mode_select.param.watch(self._panel_on_mode_changed, 'value') + top_items.append(self.panel_mode_select) + event_keys = list(self.controller.events.keys()) + if len(event_keys) > 1: + self.panel_event_select = pn.widgets.Select(name="Event", value=event_keys[0], options=event_keys) + self.panel_event_select.param.watch( self._panel_on_event_changed, 'value') + top_items.append(self.panel_event_select) + self.selected_event_key = event_keys[0] + + top_bar = pn.Row(*top_items, sizing_mode="stretch_width") + self.bins = np.linspace( + self.settings["window_start"], + self.settings["window_end"], + self.settings["num_bins"] + ) + self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2 + self.scatter_source = ColumnDataSource(data={"x": [], "y": [], "color": []}) + self.hist_source = ColumnDataSource(data={"center": [], "height": [], "color": []}) + self.x_range = Range1d(start=self.settings['window_start'], end=self.settings['window_end']) + self.panel_fig = bpl.figure( + sizing_mode="stretch_both", + tools="reset,wheel_zoom", + active_scroll="wheel_zoom", + background_fill_color=_bg_color, + border_fill_color=_bg_color, + outline_line_color="white", + x_range=self.x_range, + styles={"flex": "1"} + ) + self.scatter = self.panel_fig.scatter( + "x", + "y", + source=self.scatter_source, + color="color", + ) + self.bar = self.panel_fig.vbar( + x="center", + top="height", + width=self.bins[1] - self.bins[0], + color="color", + source=self.hist_source, + alpha=self.settings['alpha_psth'] + ) + self.vline = Span(location=0, dimension='height', line_color='white', line_width=2, line_dash='dashed') + + self.panel_fig.yaxis.axis_label = 'Event #' + self.panel_fig.xaxis.axis_label = 'Time (s)' + self.panel_fig.toolbar.logo = None + self.panel_fig.add_layout(self.vline) + self.panel_plot_pane = pn.pane.Bokeh(self.panel_fig, sizing_mode="stretch_both") + self.layout = pn.Column( + top_bar, + self.panel_plot_pane, + sizing_mode="stretch_both", + ) + + def _panel_on_mode_changed(self, event): + self.mode = 'rasters' if event.new == 'Rasters' else 'psth' + self._panel_refresh() + + def _panel_on_event_changed(self, event): + self.selected_event_key = event.new + self._panel_refresh() + + def _panel_refresh(self): + import numpy as np + import bokeh.plotting as bpl + + visible_units = self.controller.get_visible_unit_ids() + aligned_spikes_by_unit = self.get_aligned_spikes(visible_units) + if self.mode == 'rasters': + self.hist_source.data = {"center": [], "height": [], "color": []} # Clear histogram data + self.panel_fig.title.text = f'Rasters aligned to {self.selected_event_key}' + self.panel_fig.yaxis.axis_label = 'Event #' + all_x = [] + all_y = [] + all_colors = [] + for selected_unit in visible_units: + aligned_spikes = aligned_spikes_by_unit[selected_unit] + color = self.get_unit_color(selected_unit) + for i, trial in enumerate(aligned_spikes): + if len(trial) > 0: + all_x.extend(trial) + y = [i] * len(trial) + all_y.extend(y) + all_colors.extend([color] * len(trial)) + self.scatter_source.data = { + "x": np.array(all_x), + "y": np.array(all_y), + "color": all_colors + } + else: + self.scatter_source.data = {"x": [], "y": [], "color": []} # Clear scatter data + + all_centers = [] + all_heights = [] + all_colors = [] + for selected_unit in visible_units: + aligned_spikes = aligned_spikes_by_unit[selected_unit] + all_spikes = np.concatenate(aligned_spikes) if aligned_spikes else np.array([]) + hist, _ = np.histogram(all_spikes, bins=self.bins) + all_centers.extend(list(self.bin_centers)) + all_heights.extend(list(hist)) + all_colors.extend([self.get_unit_color(selected_unit)] * len(hist)) + self.hist_source.data = { + "center": all_centers, + "height": all_heights, + "color": all_colors + } + self.panel_fig.yaxis.axis_label = 'Spike count' + self.panel_fig.title.text = f'PSTH aligned to {self.selected_event_key}' + + # adjust x_range if needed + if self.settings["window_start"] != self.x_range.start: + self.x_range.start = self.settings["window_start"] + if self.settings["window_end"] != self.x_range.end: + self.x_range.end = self.settings["window_end"] + + def _panel_on_settings_changed(self): + self.bins = np.linspace( + self.settings["window_start"], + self.settings["window_end"], + self.settings["num_bins"] + ) + self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2 + self.x_range.start = self.settings['window_start'] + self.x_range.end = self.settings['window_end'] + self.bar.glyph.width = self.bins[1] - self.bins[0] + self._panel_refresh() diff --git a/spikeinterface_gui/layout_presets.py b/spikeinterface_gui/layout_presets.py index d7f5e05..8478f51 100644 --- a/spikeinterface_gui/layout_presets.py +++ b/spikeinterface_gui/layout_presets.py @@ -56,7 +56,7 @@ def get_layout_description(preset_name, layout=None): default_layout = dict( zone1=['curation', 'spikelist'], zone2=['unitlist', 'merge'], - zone3=['trace', 'tracemap', 'spikeamplitude', 'amplitudescalings', 'spikedepth', 'spikerate'], + zone3=['trace', 'tracemap', 'spikeamplitude', 'amplitudescalings', 'spikedepth', 'spikerate', 'event'], zone4=[], zone5=['probe'], zone6=['ndscatter', 'similarity'], diff --git a/spikeinterface_gui/main.py b/spikeinterface_gui/main.py index 3117a73..c43cb2c 100644 --- a/spikeinterface_gui/main.py +++ b/spikeinterface_gui/main.py @@ -2,11 +2,12 @@ import argparse import json from pathlib import Path +from typing import Callable import numpy as np import warnings from spikeinterface import load_sorting_analyzer, load -from spikeinterface.core.core_tools import is_path_remote +from spikeinterface.core import BaseRecording, SortingAnalyzer, BaseEvent from spikeinterface.core.sortinganalyzer import get_available_analyzer_extensions from .utils_global import get_config_folder from spikeinterface_gui.layout_presets import get_layout_description @@ -16,29 +17,30 @@ from spikeinterface_gui.viewlist import get_all_possible_views def run_mainwindow( - analyzer, - mode="desktop", - with_traces=True, - curation=False, - curation_dict=None, - label_definitions=None, - displayed_unit_properties=None, - extra_unit_properties=None, - skip_extensions=None, - recording=None, - start_app=True, - layout_preset=None, - layout=None, - external_data=None, - curation_callback=None, - curation_callback_kwargs=None, - address="localhost", - port=0, - panel_start_server_kwargs=None, - panel_window_servable=True, - verbose=False, - user_settings=None, - disable_save_settings_button=False, + analyzer: SortingAnalyzer, + mode: str = "desktop", + with_traces: bool = True, + curation: bool = False, + curation_dict: dict | None = None, + label_definitions: dict | None = None, + displayed_unit_properties: list | None=None, + extra_unit_properties: list | None=None, + skip_extensions: list | None = None, + recording: BaseRecording | None = None, + events: BaseEvent | dict | None = None, + start_app: bool = True, + layout_preset: str | None = None, + layout: dict | None = None, + external_data: dict | None = None, + curation_callback: Callable | None = None, + curation_callback_kwargs: dict | None = None, + address: str = "localhost", + port: int = 0, + panel_start_server_kwargs: dict | None = None, + panel_window_servable: bool = True, + verbose: bool = False, + user_settings: dict | None = None, + disable_save_settings_button: bool = False, ): """ Create the main window and start the QT app loop. @@ -68,6 +70,9 @@ def run_mainwindow( recording: RecordingExtractor | None, default: None The recording object to display traces. This can be used when the SortingAnalyzer is recordingless. + events: BaseEvent | dict | None, default: None + The events to display in the GUI. This can be a BaseEvent object or a dictionary + with keys as event names and another dictionary as values with "samples" or "times". start_qt_app: bool, default: True If True, the QT app loop is started layout_preset : str | None @@ -156,6 +161,7 @@ def run_mainwindow( extra_unit_properties=extra_unit_properties, skip_extensions=skip_extensions, disable_save_settings_button=disable_save_settings_button, + events=events, external_data=external_data, curation_callback=curation_callback, curation_callback_kwargs=curation_callback_kwargs, diff --git a/spikeinterface_gui/tests/test_mainwindow_panel.py b/spikeinterface_gui/tests/test_mainwindow_panel.py index 2725745..3971078 100644 --- a/spikeinterface_gui/tests/test_mainwindow_panel.py +++ b/spikeinterface_gui/tests/test_mainwindow_panel.py @@ -31,7 +31,7 @@ def teardown_module(): clean_all(test_folder) -def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, from_si_api=False, port=0): +def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, events=False, port=0): analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer") @@ -60,6 +60,16 @@ def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_ext ) win = None + events_dict = None + if events: + events_dict = {"event1": {"samples": []}, "event2": {"samples": []}} + for segment_index in range(analyzer.get_num_segments()): + events_dict["event1"]["samples"].append( + np.random.choice(np.arange(analyzer.get_num_samples(segment_index)), 30) + ) + events_dict["event2"]["samples"].append( + np.random.choice(np.arange(analyzer.get_num_samples(segment_index)), 50) + ) for segment_index in range(analyzer.get_num_segments()): shift = (segment_index + 1) * 100 # add a gap to times @@ -82,6 +92,7 @@ def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_ext layout_preset='default', # address="10.69.168.40", port=port, + events=events_dict # user_settings={"mainsettings": {"color_mode": "color_by_visibility", "max_visible_units": 5}} ) return win @@ -111,6 +122,7 @@ def test_launcher(verbose=True): parser = ArgumentParser() parser.add_argument('--dataset', default="small", help='Path to the dataset folder') +parser.add_argument('--events', action="store_true", help='Simulate and add events') if __name__ == '__main__': args = parser.parse_args() @@ -126,7 +138,7 @@ def test_launcher(verbose=True): if not test_folder.is_dir(): setup_module() - win = test_mainwindow(start_app=True, verbose=True, curation=True, port=0) + win = test_mainwindow(start_app=True, verbose=True, curation=True, events=args.events, port=0) # test_launcher(verbose=True) diff --git a/spikeinterface_gui/tests/test_mainwindow_qt.py b/spikeinterface_gui/tests/test_mainwindow_qt.py index bff2285..3349eba 100644 --- a/spikeinterface_gui/tests/test_mainwindow_qt.py +++ b/spikeinterface_gui/tests/test_mainwindow_qt.py @@ -35,7 +35,7 @@ def teardown_module(): clean_all(test_folder) -def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, from_si_api=False): +def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, events=False): analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer") @@ -67,6 +67,7 @@ def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_ext yop=np.array([f"yop{i}" for i in range(n)]), yip=np.array([f"yip{i}" for i in range(n)]), ) + for segment_index in range(analyzer.get_num_segments()): shift = (segment_index + 1) * 100 @@ -80,6 +81,25 @@ def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_ext segment_index=segment_index ) + events_dict = None + if events: + events_dict = {"event1": {"times": []}, "event2": {"times": []}} + for segment_index in range(analyzer.get_num_segments()): + times = analyzer.recording.get_times(segment_index) + events_dict["event1"]["times"].append( + np.random.choice(times, 30) + ) + events_dict["event2"]["times"].append( + np.random.choice(times, 50) + ) + # add some events outside of recording times to test filtering + events_dict["event1"]["times"][-1] = np.concatenate( + [events_dict["event1"]["times"][-1], [times[0] - 10, times[-1] + 20]] + ) + events_dict["event2"]["times"][-1] = np.concatenate( + [events_dict["event2"]["times"][-1], [times[0] - 5, times[-1] + 15]] + ) + win = run_mainwindow( analyzer, mode="desktop", @@ -89,6 +109,7 @@ def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_ext displayed_unit_properties=None, extra_unit_properties=extra_unit_properties, layout_preset='default', + events=events_dict # user_settings={"mainsettings": {"color_mode": "color_by_visibility", "max_visible_units": 5}} ) @@ -122,6 +143,7 @@ def test_launcher(verbose=True): parser = ArgumentParser() parser.add_argument('--dataset', default="small", help='Path to the dataset folder') +parser.add_argument('--events', action="store_true", help='Simulate and add events') if __name__ == '__main__': args = parser.parse_args() @@ -133,7 +155,7 @@ def test_launcher(verbose=True): if not test_folder.is_dir(): setup_module() - win = test_mainwindow(start_app=True, verbose=True, curation=True) + win = test_mainwindow(start_app=True, verbose=True, curation=True, events=args.events) # win = test_mainwindow(start_app=True, verbose=True, curation=False) # test_launcher(verbose=True) diff --git a/spikeinterface_gui/tests/testingtools.py b/spikeinterface_gui/tests/testingtools.py index 2b0a764..dcd87af 100644 --- a/spikeinterface_gui/tests/testingtools.py +++ b/spikeinterface_gui/tests/testingtools.py @@ -129,12 +129,6 @@ def make_analyzer_folder(test_folder, case="small", unit_dtype="str"): sorting_analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs) - qm = sorting_analyzer.get_extension("quality_metrics").get_data() - # print(qm.index) - # print(qm.index.dtype) - # print(sorting_analyzer.unit_ids.dtype) - - def make_curation_dict(analyzer): unit_ids = analyzer.unit_ids.tolist() curation_dict = { diff --git a/spikeinterface_gui/tracemapview.py b/spikeinterface_gui/tracemapview.py index b09a23f..d5b75fc 100644 --- a/spikeinterface_gui/tracemapview.py +++ b/spikeinterface_gui/tracemapview.py @@ -78,8 +78,7 @@ def _qt_make_layout(self, **kargs): self.layout = QT.QVBoxLayout() - self._qt_create_toolbar() - + self._qt_create_toolbars() # create graphic view and 2 scroll bar g = QT.QGridLayout() @@ -93,14 +92,7 @@ def _qt_make_layout(self, **kargs): self.scatter = pg.ScatterPlotItem(size=10, pxMode = True) self.plot.addItem(self.scatter) - - self.scroll_time = QT.QScrollBar(orientation=QT.Qt.Horizontal) - g.addWidget(self.scroll_time, 1,1) - self.scroll_time.valueChanged.connect(self._qt_on_scroll_time) - - - # self.on_params_changed(do_refresh=False) - #this do refresh + self.layout.addWidget(self.bottom_toolbar) self._qt_change_segment(0) def _qt_on_settings_changed(self, do_refresh=True): @@ -169,6 +161,10 @@ def _qt_seek(self, t): self.plot.setXRange(t1, t2, padding=0.0) self.plot.setYRange(0, num_chans, padding=0.0) + # events + self._qt_add_event_lines(t1, t2) + + # group separation lines if self.chan_group_offsets is not None: for ch in self.chan_group_offsets: hline = pg.InfiniteLine(pos=ch, angle=0, movable=False, pen=pg.mkPen("black")) @@ -256,19 +252,24 @@ def _panel_make_layout(self): x="x", y="y", size=10, fill_color="color", fill_alpha=self.settings['alpha'], source=self.spike_source ) + self.event_source = ColumnDataSource({"x": [], "y": []}) + self.event_renderer = self.figure.line( + x="x", y="y", source=self.event_source, line_color="yellow", line_width=2, line_dash='dashed' + ) if self.chan_group_offsets is not None: self.figure.hspan(y=list(self.chan_group_offsets), line_color="yellow") # # Add hover tool for spikes # hover_spikes = HoverTool(renderers=[self.spike_renderer], tooltips=[("Unit", "@unit_id")]) # self.figure.add_tools(hover_spikes) - self._panel_create_toolbar() + self.toolbar = self._panel_create_toolbar() + self.bottom_toolbar = self._panel_create_bottom_toolbar() self.layout = pn.Column( pn.Column( # Main content area self.toolbar, self.figure, - self.time_slider, + self.bottom_toolbar, styles={"flex": "1"}, sizing_mode="stretch_both" ), @@ -277,6 +278,7 @@ def _panel_make_layout(self): ) def _panel_refresh(self): + self._panel_remove_event_line() t, segment_index = self.controller.get_time() xsize = self.xsize t1, t2 = t - xsize / 3.0, t + xsize * 2 / 3.0 @@ -316,6 +318,8 @@ def _panel_refresh(self): self.figure.x_range.end = t2 self.figure.y_range.end = data_curves.shape[1] + self._panel_add_event_lines(t1, t2) + def _panel_on_settings_changed(self): self.make_color_lut() self.refresh() @@ -352,15 +356,15 @@ def _do_update(): def _panel_on_time_info_updated(self): # Update segment and time slider range time, segment_index = self.controller.get_time() - self._block_auto_refresh_and_notify = True self._panel_change_segment(segment_index) - - # Update time slider value + self._block_auto_refresh_and_notify = False + self.refresh() + # Update slider visually after refresh, blocking the callback to avoid + # a Bokeh round-trip loop (browser sends value_throttled back to server) + self._block_auto_refresh_and_notify = True self.time_slider.value = time - self._block_auto_refresh_and_notify = False - # we don't need a refresh in panel because changing tab triggers a refresh def _panel_on_use_times_updated(self): # Update time seeker diff --git a/spikeinterface_gui/traceview.py b/spikeinterface_gui/traceview.py index 8a461ff..a5bfbf2 100644 --- a/spikeinterface_gui/traceview.py +++ b/spikeinterface_gui/traceview.py @@ -100,7 +100,7 @@ def get_data_in_chunk(self, t1, t2, segment_index): return times_chunk, data_curves, scatter_x, scatter_y, scatter_colors ## Qt ## - def _qt_create_toolbar(self): + def _qt_create_toolbars(self): from .myqt import QT import pyqtgraph as pg from .utils_qt import TimeSeeker, add_stretch_to_qtoolbar @@ -133,6 +133,47 @@ def _qt_create_toolbar(self): but.clicked.connect(self.auto_scale) tb.addWidget(but) + + self.scroll_time = QT.QScrollBar(orientation=QT.Qt.Horizontal) + self.scroll_time.valueChanged.connect(self._qt_on_scroll_time) + if self.controller.has_extension("events"): + bottom_layout = QT.QHBoxLayout() + bottom_layout.addWidget(self.scroll_time, stretch=8) + bottom_layout.addStretch() # Push button to the right + + event_layout = QT.QHBoxLayout() + event_keys = list(self.controller.events.keys()) + if len(event_keys) > 1: + self.event_type_combo = QT.QComboBox() + self.event_type_combo.addItems(event_keys) + self.event_type_combo.currentIndexChanged.connect(self._qt_on_event_type_changed) + event_layout.addWidget(QT.QLabel("Event:"), stretch=2) + event_layout.addWidget(self.event_type_combo, stretch=3) + self.event_key = event_keys[0] + else: + self.event_key = event_keys[0] + + self.prev_event_button = QT.QPushButton("◀") + self.next_event_button = QT.QPushButton("▶") + self.prev_event_button.setMaximumWidth(30) + self.next_event_button.setMaximumWidth(30) + self.next_event_button.clicked.connect(self._qt_on_next_event) + self.prev_event_button.clicked.connect(self._qt_on_prev_event) + event_layout.addWidget(self.prev_event_button, stretch=1) + event_layout.addWidget(self.next_event_button, stretch=1) + + # Wrap event_layout in a QWidget + event_widget = QT.QWidget() + event_widget.setLayout(event_layout) + bottom_layout.addWidget(event_widget) + bottom_widget = QT.QWidget() + bottom_widget.setLayout(bottom_layout) + self.event_lines = [] + else: + bottom_widget = self.scroll_time + self.event_lines = None + + self.bottom_toolbar = bottom_widget def _qt_initialize_plot(self): from .myqt import QT @@ -245,9 +286,62 @@ def _qt_scatter_item_clicked(self, x, y): self.notify_spike_selection_changed() self._qt_seek_with_selected_spike() + # change selected unit + unit_id = self.controller.unit_ids[self.controller.spikes[ind_spike_nearest]["unit_index"]] + self.controller.set_visible_unit_ids([unit_id]) + self.notify_unit_visibility_changed() + + def _qt_on_event_type_changed(self): + self.event_key = self.event_type_combo.currentText() + self.refresh() + + def _qt_add_event_lines(self, t1, t2): + import pyqtgraph as pg + from .myqt import QT + + if self.controller.has_extension("events"): + self._qt_remove_event_lines() + # Add vertical lines at event time + sample_start, sample_end = self.controller.get_chunk_indices(t1, t2, self.controller.get_time()[1]) + events = self.controller.get_events(self.event_key) + events_in_range = events[(events >= sample_start) & (events <= sample_end)] + for evt in events_in_range: + evt_time = self.controller.sample_index_to_time(evt) + pen = pg.mkPen(color=(255, 255, 0, 180), width=2, style=QT.Qt.DotLine) + event_line = pg.InfiniteLine(pos=evt_time, angle=90, movable=False, pen=pen) + self.event_lines.append(event_line) + self.plot.addItem(event_line) + + def _qt_remove_event_lines(self): + if hasattr(self, 'event_lines'): + for event_line in self.event_lines: + self.plot.removeItem(event_line) + self.event_lines = [] + + def _qt_on_next_event(self): + current_sample = self.controller.time_to_sample_index(self.controller.get_time()[0]) + event_samples = self.controller.get_events(self.event_key) + next_events = event_samples[event_samples > current_sample] + if next_events.size > 0: + next_evt_sample = next_events[0] + evt_time = self.controller.sample_index_to_time(next_evt_sample) + self.controller.set_time(time=evt_time) + self._qt_on_time_info_updated() + self.notify_time_info_updated() + + def _qt_on_prev_event(self): + current_sample = self.controller.time_to_sample_index(self.controller.get_time()[0]) + event_samples = self.controller.get_events(self.event_key) + prev_events = event_samples[event_samples < current_sample] + if prev_events.size > 0: + prev_evt_sample = prev_events[-1] + evt_time = self.controller.sample_index_to_time(prev_evt_sample) + self.controller.set_time(time=evt_time) + self._qt_on_time_info_updated() + self.notify_time_info_updated() ## panel ## - def _panel_create_toolbar(self): + def _panel_create_toolbars(self): import panel as pn segment_index = self.controller.get_time()[1] @@ -272,21 +366,56 @@ def _panel_create_toolbar(self): self.xsize_spinner.param.watch(self._panel_on_xsize_changed, "value") self.auto_scale_button.on_click(self._panel_auto_scale) - self.toolbar = pn.Row( + toolbar = pn.Row( self.segment_selector, xsize, self.auto_scale_button, sizing_mode="stretch_width", ) + return toolbar + + def _panel_create_bottom_toolbar(self): + import panel as pn + from bokeh.models import ColumnDataSource - # Time slider - segment_index = self.controller.get_time()[1] # update with controller.get_t_start/get_t_end t_start, t_stop = self.controller.get_t_start_t_stop() self.time_slider = pn.widgets.FloatSlider(name="Time (s)", start=t_start, end=t_stop, value=0, step=0.1, value_throttled=0, sizing_mode="stretch_width") self.time_slider.param.watch(self._panel_on_time_slider_changed, "value_throttled") + bottom_bar_items = [self.time_slider] + if self.controller.has_extension("events"): + self.event_line = None + if self.controller.has_extension("events"): + self.event_source = ColumnDataSource({"xs": [], "ys": []}) + self.event_line = self.figure.multi_line( + source=self.event_source, + xs="xs", ys="ys", line_color="yellow", line_dash="dashed", line_width=2, line_alpha=0.8 + ) + event_keys = list(self.controller.events.keys()) + if len(event_keys) > 1: + self.event_selector = pn.widgets.Select( + name="", + options=event_keys, + value=event_keys[0], + ) + self.event_key = event_keys[0] + else: + self.event_selector = None + self.event_key = event_keys[0] + + self.prev_event_button = pn.widgets.Button(name="◀", button_type="default", width=40) + self.next_event_button = pn.widgets.Button(name="▶", button_type="default", width=40) + + self.prev_event_button.on_click(self._panel_on_prev_event) + self.next_event_button.on_click(self._panel_on_next_event) + bottom_bar_items.extend([self.prev_event_button, self.next_event_button]) + if self.event_selector is not None: + bottom_bar_items.append(self.event_selector) + return pn.Row(*bottom_bar_items, sizing_mode="stretch_width") + + def _panel_on_segment_changed(self, event): segment_index = int(event.new.split()[-1]) self._panel_change_segment(segment_index) @@ -315,6 +444,48 @@ def _panel_on_time_slider_changed(self, event): self.refresh() self.notify_time_info_updated() + def _panel_on_event_type_changed(self): + self.event_key = self.event_selector.value + self.refresh() + + def _panel_add_event_lines(self, t1, t2): + if self.event_line is not None: + event_samples = self.controller.get_events(self.event_key) + segment_index = self.controller.get_time()[1] + start_sample, end_sample = self.controller.get_chunk_indices(t1, t2, segment_index) + events_in_range = event_samples[(event_samples >= start_sample) & (event_samples <= end_sample)] + xs = [] + ys = [] + y_min = self.figure.y_range.start + y_max = self.figure.y_range.end + for evt in events_in_range: + evt_time = self.controller.sample_index_to_time(evt) + xs.append([evt_time, evt_time]) + ys.append([y_min, y_max]) + self.event_source.data = {"xs": xs, "ys": ys} + + def _panel_on_next_event(self, event): + current_sample = self.controller.time_to_sample_index(self.controller.get_time()[0]) + event_samples = self.controller.get_events(self.event_key) + next_events = event_samples[event_samples > current_sample] + if next_events.size > 0: + next_evt_sample = next_events[0] + evt_time = self.controller.sample_index_to_time(next_evt_sample) + self.controller.set_time(time=evt_time) + self._panel_on_time_info_updated() + self.notify_time_info_updated() + + def _panel_on_prev_event(self, event): + current_sample = self.controller.time_to_sample_index(self.controller.get_time()[0]) + event_samples = self.controller.get_events(self.event_key) + prev_events = event_samples[event_samples < current_sample] + if prev_events.size > 0: + prev_evt_sample = prev_events[-1] + evt_time = self.controller.sample_index_to_time(prev_evt_sample) + self.controller.set_time(time=evt_time) + self._panel_on_time_info_updated() + self.notify_time_info_updated() + def _panel_seek_with_selected_spike(self): import panel as pn @@ -362,6 +533,49 @@ def _panel_on_double_tap(self, event): self.controller.set_indices_spike_selected([ind_spike_nearest]) self._panel_seek_with_selected_spike() self.notify_spike_selection_changed() + # change selected unit + unit_id = self.controller.unit_ids[self.controller.spikes[ind_spike_nearest]["unit_index"]] + self.controller.set_visible_unit_ids([unit_id]) + self.notify_unit_visibility_changed() + + def _panel_on_event_type_changed(self, event): + self.event_key = event.new + self.refresh() + + def _panel_on_next_event(self, event): + current_sample = self.controller.time_to_sample_index(self.controller.get_time()[0]) + event_samples = self.controller.get_events(self.event_key) + next_events = event_samples[event_samples > current_sample] + if next_events.size > 0: + next_evt_sample = next_events[0] + evt_time = self.controller.sample_index_to_time(next_evt_sample) + self.controller.set_time(time=evt_time) + self.time_slider.value = evt_time + self._panel_refresh() + self._panel_add_event_line() + + def _panel_on_prev_event(self, event): + current_sample = self.controller.time_to_sample_index(self.controller.get_time()[0]) + event_samples = self.controller.get_events(self.event_key) + prev_events = event_samples[event_samples < current_sample] + if prev_events.size > 0: + prev_evt_sample = prev_events[-1] + evt_time = self.controller.sample_index_to_time(prev_evt_sample) + self.controller.set_time(time=evt_time) + self.time_slider.value = evt_time + self._panel_refresh() + self._panel_add_event_line() + + def _panel_add_event_line(self): + # Add vertical line at event time + evt_time = self.controller.get_time()[0] + # get yspan from self.figure + fig = self.figure + yspan = [fig.y_range.start, fig.y_range.end] + self.event_source.data = dict(x=[evt_time, evt_time], y=yspan) + + def _panel_remove_event_line(self): + self.event_source.data = dict(x=[], y=[]) # TODO: pan behavior like Qt? # def _panel_on_pan_start(self, event): @@ -449,24 +663,20 @@ def _qt_make_layout(self): self.layout = QT.QVBoxLayout() # self.setLayout(self.layout) - self._qt_create_toolbar() - - + self._qt_create_toolbars() + # create graphic view and 2 scroll bar - g = QT.QGridLayout() - self.layout.addLayout(g) + # g = QT.QGridLayout() + # self.layout.addLayout(g) self.graphicsview = pg.GraphicsView() - g.addWidget(self.graphicsview, 0,1) + # g.addWidget(self.graphicsview, 0, 1) + self.layout.addWidget(self.graphicsview) MixinViewTrace._qt_initialize_plot(self) self.scatter = pg.ScatterPlotItem(size=10, pxMode = True) self.plot.addItem(self.scatter) - self.scroll_time = QT.QScrollBar(orientation=QT.Qt.Horizontal) - g.addWidget(self.scroll_time, 1,1) - self.scroll_time.valueChanged.connect(self._qt_on_scroll_time) - - + self.layout.addWidget(self.bottom_toolbar) self._qt_update_scroll_limits() def _qt_on_settings_changed(self): @@ -542,6 +752,9 @@ def _qt_seek(self, t): # ranges self.plot.setXRange(t1, t2, padding=0.0) self.plot.setYRange(-.5, visible_channel_inds.size - .5, padding=0.0) + # events + self._qt_add_event_lines(t1, t2) + def _qt_on_time_info_updated(self): # Update segment and time slider range @@ -609,20 +822,27 @@ def _panel_make_layout(self): x="x", y="y", size=10, fill_color="color", fill_alpha=self.settings['alpha'], source=self.spike_source ) + self.event_source = ColumnDataSource({"x": [], "y": []}) + self.event_renderer = self.figure.line( + x="x", y="y", source=self.event_source, line_color="yellow", line_width=2, line_dash='dashed' + ) + self.figure.on_event(DoubleTap, self._panel_on_double_tap) - self._panel_create_toolbar() + self.toolbar = self._panel_create_toolbar() + self.bottom_toolbar = self._panel_create_bottom_toolbar() self.layout = pn.Column( self.toolbar, self.figure, - self.time_slider, + self.bottom_toolbar, styles={"display": "flex", "flex-direction": "column"}, sizing_mode="stretch_both" ) def _panel_refresh(self): + self._panel_remove_event_line() t, segment_index = self.controller.get_time() xsize = self.xsize t1, t2 = t - xsize / 3.0, t + xsize * 2 / 3.0 @@ -666,8 +886,9 @@ def _panel_refresh(self): self.figure.x_range.end = t2 self.figure.y_range.end = n - 0.5 - # TODO: if from a different unit, change unit visibility + self._panel_add_event_lines(t1, t2) + # TODO: if from a different unit, change unit visibility def _panel_on_spike_selection_changed(self): self._panel_seek_with_selected_spike() @@ -686,9 +907,12 @@ def _panel_on_time_info_updated(self): time, segment_index = self.controller.get_time() self._block_auto_refresh_and_notify = True self._panel_change_segment(segment_index) - # Update time slider value - self.time_slider.value = time + self._block_auto_refresh_and_notify = False self.refresh() + # Update slider visually after refresh, blocking the callback to avoid + # a Bokeh round-trip loop (browser sends value_throttled back to server) + self._block_auto_refresh_and_notify = True + self.time_slider.value = time self._block_auto_refresh_and_notify = False def _panel_on_use_times_updated(self): diff --git a/spikeinterface_gui/viewlist.py b/spikeinterface_gui/viewlist.py index e1cac5b..39cc70d 100644 --- a/spikeinterface_gui/viewlist.py +++ b/spikeinterface_gui/viewlist.py @@ -20,6 +20,7 @@ from .metricsview import MetricsView from .spikerateview import SpikeRateView from .maintemplateview import MainTemplateView +from .eventview import EventView # probe and mainsettings view are first, since they affect other views (e.g., time info) builtin_views = [ @@ -27,7 +28,7 @@ TraceView, TraceMapView, WaveformView, WaveformHeatMapView, ISIView, CorrelogramView, NDScatterView, SimilarityView, SpikeAmplitudeView, SpikeDepthView, SpikeRateView, CurationView, MetricsView, SpikeListView, - AmplitudeScalingsView, MainTemplateView + AmplitudeScalingsView, MainTemplateView, EventView ] def get_all_possible_views(): @@ -53,4 +54,4 @@ def get_all_possible_views(): # Log but don't crash if a plugin fails to load print(f"Warning: Failed to load plugin view '{ep.name}': {e}") - return possible_class_views \ No newline at end of file + return possible_class_views