diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index 2085296e..b99823d6 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -314,7 +314,7 @@ def _qt_refresh(self, set_scatter_range=False): # set x range to time range of the current segment for scatter, and max count for histogram # set y range to min and max of visible spike amplitudes - if set_scatter_range or not self._first_refresh_done: + if len(ymins) > 0 and (set_scatter_range or not self._first_refresh_done): ymin = np.min(ymins) ymax = np.max(ymaxs) t_start, t_stop = self.controller.get_t_start_t_stop() @@ -490,6 +490,7 @@ def _panel_make_layout(self): self.plotted_inds = [] def _panel_refresh(self, set_scatter_range=False): + import panel as pn from bokeh.models import FixedTicker self.plotted_inds = [] @@ -555,7 +556,8 @@ def _panel_refresh(self, set_scatter_range=False): # handle selected spikes self._panel_update_selected_spikes() - # set y range to min and max of visible spike amplitudes + # Defer Range updates to avoid nested document lock issues + # def update_ranges(): if set_scatter_range or not self._first_refresh_done: self.y_range.start = np.min(ymins) self.y_range.end = np.max(ymaxs) @@ -563,20 +565,36 @@ def _panel_refresh(self, set_scatter_range=False): self.hist_fig.x_range.end = max_count self.hist_fig.xaxis.ticker = FixedTicker(ticks=[0, max_count // 2, max_count]) + # Schedule the update to run after the current event loop iteration + # pn.state.execute(update_ranges, schedule=True) + def _panel_on_select_button(self, event): - if self.select_toggle_button.value: - self.scatter_fig.toolbar.active_drag = self.lasso_tool - else: - self.scatter_fig.toolbar.active_drag = None - self.scatter_source.selected.indices = [] + import panel as pn + + value = self.select_toggle_button.value + + def _do_update(): + if value: + self.scatter_fig.toolbar.active_drag = self.lasso_tool + else: + self.scatter_fig.toolbar.active_drag = None + self.scatter_source.selected.indices = [] + + pn.state.execute(_do_update, schedule=True) def _panel_change_segment(self, event): + import panel as pn + self._current_selected = 0 segment_index = int(self.segment_selector.value.split()[-1]) self.controller.set_time(segment_index=segment_index) t_start, t_end = self.controller.get_t_start_t_stop() - self.scatter_fig.x_range.start = t_start - self.scatter_fig.x_range.end = t_end + + def _do_update(): + self.scatter_fig.x_range.start = t_start + self.scatter_fig.x_range.end = t_end + + pn.state.execute(_do_update, schedule=True) self.refresh(set_scatter_range=True) self.notify_time_info_updated() @@ -618,9 +636,17 @@ def _panel_split(self, event): self.split() def _panel_update_selected_spikes(self): + import panel as pn + # handle selected spikes selected_spike_indices = self.controller.get_indices_spike_selected() selected_spike_indices = np.intersect1d(selected_spike_indices, self.plotted_inds) + if len(selected_spike_indices) == 1: + selected_segment = self.controller.spikes[selected_spike_indices[0]]['segment_index'] + segment_index = self.controller.get_time()[1] + if selected_segment != segment_index: + self.segment_selector.value = f"Segment {selected_segment}" + self._panel_change_segment(None) if len(selected_spike_indices) > 0: # map absolute indices to visible spikes segment_index = self.controller.get_time()[1] @@ -634,23 +660,16 @@ def _panel_update_selected_spikes(self): # set selected spikes in scatter plot if self.settings["auto_decimate"] and len(selected_indices) > 0: selected_indices, = np.nonzero(np.isin(self.plotted_inds, selected_spike_indices)) - self.scatter_source.selected.indices = list(selected_indices) else: - self.scatter_source.selected.indices = [] + selected_indices = [] + + def _do_update(): + self.scatter_source.selected.indices = list(selected_indices) + + pn.state.execute(_do_update, schedule=True) def _panel_on_spike_selection_changed(self): - # set selection in scatter plot - selected_indices = self.controller.get_indices_spike_selected() - if len(selected_indices) == 0: - self.scatter_source.selected.indices = [] - return - elif len(selected_indices) == 1: - selected_segment = self.controller.spikes[selected_indices[0]]['segment_index'] - segment_index = self.controller.get_time()[1] - if selected_segment != segment_index: - self.segment_selector.value = f"Segment {selected_segment}" - self._panel_change_segment(None) - # update selected spikes + # update selected spikes (scheduled via pn.state.execute inside) self._panel_update_selected_spikes() def _panel_handle_shortcut(self, event): diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 3bb7eb94..41b63981 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -9,10 +9,7 @@ from spikeinterface.widgets.utils import get_unit_colors from spikeinterface import compute_sparsity from spikeinterface.core import get_template_extremum_channel -import spikeinterface.postprocessing -import spikeinterface.qualitymetrics from spikeinterface.core.sorting_tools import spike_vector_to_indices -from spikeinterface.core.core_tools import check_json from spikeinterface.curation import validate_curation_dict from spikeinterface.curation.curation_model import CurationModel from spikeinterface.widgets.utils import make_units_table_from_analyzer @@ -345,6 +342,7 @@ def __init__( if curation_data is not None: # validate the curation data + curation_data = deepcopy(curation_data) format_version = curation_data.get("format_version", None) # assume version 2 if not present if format_version is None: @@ -354,24 +352,6 @@ def __init__( except Exception as e: raise ValueError(f"Invalid curation data.\nError: {e}") - if curation_data.get("merges") is None: - curation_data["merges"] = [] - else: - # here we reset the merges for better formatting (str) - existing_merges = curation_data["merges"] - new_merges = [] - for m in existing_merges: - if "unit_ids" not in m: - continue - if len(m["unit_ids"]) < 2: - continue - new_merges = add_merge(new_merges, m["unit_ids"]) - curation_data["merges"] = new_merges - if curation_data.get("splits") is None: - curation_data["splits"] = [] - if curation_data.get("removed") is None: - curation_data["removed"] = [] - elif self.analyzer.format == "binary_folder": json_file = self.analyzer.folder / "spikeinterface_gui" / "curation_data.json" if json_file.exists(): @@ -386,24 +366,27 @@ def __init__( if curation_data is None: curation_data = deepcopy(empty_curation_data) + curation_data["unit_ids"] = self.unit_ids.tolist() - self.curation_data = curation_data - - self.has_default_quality_labels = False - if "label_definitions" not in self.curation_data: + if "label_definitions" not in curation_data: if label_definitions is not None: - self.curation_data["label_definitions"] = label_definitions + curation_data["label_definitions"] = label_definitions else: - self.curation_data["label_definitions"] = default_label_definitions.copy() + curation_data["label_definitions"] = default_label_definitions.copy() - if "quality" in self.curation_data["label_definitions"]: - curation_dict_quality_labels = self.curation_data["label_definitions"]["quality"]["label_options"] + # This will enable the default shortcuts if has default quality labels + self.has_default_quality_labels = False + if "quality" in curation_data["label_definitions"]: + curation_dict_quality_labels = curation_data["label_definitions"]["quality"]["label_options"] default_quality_labels = default_label_definitions["quality"]["label_options"] if set(curation_dict_quality_labels) == set(default_quality_labels): if self.verbose: print('Curation quality labels are the default ones') self.has_default_quality_labels = True + curation_data = CurationModel(**curation_data).model_dump() + self.curation_data = curation_data + def check_is_view_possible(self, view_name): from .viewlist import get_all_possible_views possible_class_views = get_all_possible_views() @@ -825,7 +808,24 @@ def compute_auto_merge(self, **params): ) return merge_unit_groups, extra - + + def set_curation_data(self, curation_data): + print("Setting curation data") + new_curation_data = empty_curation_data.copy() + new_curation_data.update(curation_data) + + if "unit_ids" not in curation_data: + print("Setting unit_ids from controller") + new_curation_data["unit_ids"] = self.unit_ids.tolist() + + if "label_definitions" not in curation_data: + print("Setting default label definitions") + new_curation_data["label_definitions"] = default_label_definitions.copy() + + # validate the curation data + model = CurationModel(**new_curation_data) + self.curation_data = model.model_dump() + def curation_can_be_saved(self): return self.analyzer.format != "memory" diff --git a/spikeinterface_gui/curation_tools.py b/spikeinterface_gui/curation_tools.py index 42b7b06b..1aff9fcd 100644 --- a/spikeinterface_gui/curation_tools.py +++ b/spikeinterface_gui/curation_tools.py @@ -38,3 +38,32 @@ def add_merge(previous_merges, new_merge_unit_ids): # Ensure the uniqueness new_merges = [{"unit_ids": list(set(gp))} for gp in new_merge_units] return new_merges + + +# def cast_unit_dtypes_in_curation(curation_data, unit_ids_dtype): +# """Cast unit ids in curation data to the correct dtype.""" +# if "unit_ids" in curation_data: +# curation_data["unit_ids"] = [unit_ids_dtype(uid) for uid in curation_data["unit_ids"]] + +# if "merges" in curation_data: +# for merge in curation_data["merges"]: +# merge["unit_ids"] = [unit_ids_dtype(uid) for uid in merge["unit_ids"]] +# new_unit_id = merge.get("new_unit_id", None) +# if new_unit_id is not None: +# merge["new_unit_id"] = unit_ids_dtype(new_unit_id) + +# if "splits" in curation_data: +# for split in curation_data["splits"]: +# split["unit_id"] = unit_ids_dtype(split["unit_id"]) +# new_unit_ids = split.get("new_unit_ids", None) +# if new_unit_ids is not None: +# split["new_unit_ids"] = [unit_ids_dtype(uid) for uid in new_unit_ids] + +# if "removed" in curation_data: +# curation_data["removed"] = [unit_ids_dtype(uid) for uid in curation_data["removed"]] + +# if "manual_labels" in curation_data: +# for label_entry in curation_data["manual_labels"]: +# label_entry["unit_id"] = unit_ids_dtype(label_entry["unit_id"]) + +# return curation_data \ No newline at end of file diff --git a/spikeinterface_gui/curationview.py b/spikeinterface_gui/curationview.py index 5394e72c..53f01786 100644 --- a/spikeinterface_gui/curationview.py +++ b/spikeinterface_gui/curationview.py @@ -6,20 +6,18 @@ from spikeinterface.core.core_tools import check_json - - class CurationView(ViewBase): id = "curation" - _supported_backend = ['qt', 'panel'] + _supported_backend = ["qt", "panel"] _need_compute = False def __init__(self, controller=None, parent=None, backend="qt"): self.active_table = "merge" - ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) + ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) # TODO: Cast unit ids to the correct type here def restore_units(self): - if self.backend == 'qt': + if self.backend == "qt": unit_ids = self._qt_get_delete_table_selection() else: unit_ids = self._panel_get_delete_table_selection() @@ -30,7 +28,7 @@ def restore_units(self): self.refresh() def unmerge(self): - if self.backend == 'qt': + if self.backend == "qt": merge_indices = self._qt_get_merge_table_row() else: merge_indices = self._panel_get_merge_table_row() @@ -40,7 +38,7 @@ def unmerge(self): self.refresh() def unsplit(self): - if self.backend == 'qt': + if self.backend == "qt": split_indices = self._qt_get_split_table_row() else: split_indices = self._panel_get_split_table_row() @@ -55,8 +53,8 @@ def select_and_notify_split(self, split_unit_id): self.controller.set_visible_unit_ids([split_unit_id]) self.notify_unit_visibility_changed() spike_inds = self.controller.get_spike_indices(split_unit_id, segment_index=None) - active_split = [s for s in self.controller.curation_data['splits'] if s['unit_id'] == split_unit_id][0] - split_indices = active_split['indices'][0] + active_split = [s for s in self.controller.curation_data["splits"] if s["unit_id"] == split_unit_id][0] + split_indices = active_split["indices"][0] self.controller.set_indices_spike_selected(spike_inds[split_indices]) self.notify_spike_selection_changed() @@ -73,25 +71,26 @@ def _qt_make_layout(self): but = QT.QPushButton("Save in analyzer") tb.addWidget(but) but.clicked.connect(self.save_in_analyzer) + but = QT.QPushButton("Export JSON") - but.clicked.connect(self._qt_export_json) + but.clicked.connect(self._qt_export_json) tb.addWidget(but) h = QT.QHBoxLayout() self.layout.addLayout(h) - v = QT.QVBoxLayout() h.addLayout(v) - self.table_delete = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection, - selectionBehavior=QT.QAbstractItemView.SelectRows) + self.table_delete = QT.QTableWidget( + selectionMode=QT.QAbstractItemView.SingleSelection, selectionBehavior=QT.QAbstractItemView.SelectRows + ) v.addWidget(self.table_delete) self.table_delete.setContextMenuPolicy(QT.Qt.CustomContextMenu) self.table_delete.customContextMenuRequested.connect(self._qt_open_context_menu_delete) self.table_delete.itemSelectionChanged.connect(self._qt_on_item_selection_changed_delete) self.delete_menu = QT.QMenu() - act = self.delete_menu.addAction('Restore') + act = self.delete_menu.addAction("Restore") act.triggered.connect(self.restore_units) shortcut_restore = QT.QShortcut(self.qt_widget) shortcut_restore.setKey(QT.QKeySequence("ctrl+r")) @@ -99,8 +98,9 @@ def _qt_make_layout(self): v = QT.QVBoxLayout() h.addLayout(v) - self.table_merge = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection, - selectionBehavior=QT.QAbstractItemView.SelectRows) + self.table_merge = QT.QTableWidget( + selectionMode=QT.QAbstractItemView.SingleSelection, selectionBehavior=QT.QAbstractItemView.SelectRows + ) # self.table_merge.setContextMenuPolicy(QT.Qt.CustomContextMenu) v.addWidget(self.table_merge) @@ -109,7 +109,7 @@ def _qt_make_layout(self): self.table_merge.itemSelectionChanged.connect(self._qt_on_item_selection_changed_merge) self.merge_menu = QT.QMenu() - act = self.merge_menu.addAction('Remove merge') + act = self.merge_menu.addAction("Remove merge") act.triggered.connect(self.unmerge) shortcut_unmerge = QT.QShortcut(self.qt_widget) shortcut_unmerge.setKey(QT.QKeySequence("ctrl+u")) @@ -117,14 +117,15 @@ def _qt_make_layout(self): v = QT.QVBoxLayout() h.addLayout(v) - self.table_split = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection, - selectionBehavior=QT.QAbstractItemView.SelectRows) + self.table_split = QT.QTableWidget( + selectionMode=QT.QAbstractItemView.SingleSelection, selectionBehavior=QT.QAbstractItemView.SelectRows + ) v.addWidget(self.table_split) self.table_split.setContextMenuPolicy(QT.Qt.CustomContextMenu) self.table_split.customContextMenuRequested.connect(self._qt_open_context_menu_split) self.table_split.itemSelectionChanged.connect(self._qt_on_item_selection_changed_split) self.split_menu = QT.QMenu() - act = self.split_menu.addAction('Remove split') + act = self.split_menu.addAction("Remove split") act.triggered.connect(self.unsplit) shortcut_unsplit = QT.QShortcut(self.qt_widget) shortcut_unsplit.setKey(QT.QKeySequence("ctrl+x")) @@ -132,6 +133,7 @@ def _qt_make_layout(self): def _qt_refresh(self): from .myqt import QT + # Merged merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]] self.table_merge.clear() @@ -141,12 +143,12 @@ def _qt_refresh(self): self.table_merge.setSortingEnabled(False) for ix, group in enumerate(merged_units): item = QT.QTableWidgetItem(str(group)) - item.setFlags(QT.Qt.ItemIsEnabled|QT.Qt.ItemIsSelectable) + item.setFlags(QT.Qt.ItemIsEnabled | QT.Qt.ItemIsSelectable) self.table_merge.setItem(ix, 0, item) for i in range(self.table_merge.columnCount()): self.table_merge.resizeColumnToContents(i) - # Removed + # Removed removed_units = self.controller.curation_data["removed"] self.table_delete.clear() self.table_delete.setRowCount(len(removed_units)) @@ -155,12 +157,12 @@ def _qt_refresh(self): self.table_delete.setSortingEnabled(False) for i, unit_id in enumerate(removed_units): color = self.get_unit_color(unit_id) - pix = QT.QPixmap(16,16) + pix = QT.QPixmap(16, 16) pix.fill(color) icon = QT.QIcon(pix) - item = QT.QTableWidgetItem( f'{unit_id}') - item.setFlags(QT.Qt.ItemIsEnabled|QT.Qt.ItemIsSelectable) - self.table_delete.setItem(i,0, item) + item = QT.QTableWidgetItem(f"{unit_id}") + item.setFlags(QT.Qt.ItemIsEnabled | QT.Qt.ItemIsSelectable) + self.table_delete.setItem(i, 0, item) item.setIcon(icon) item.unit_id = unit_id self.table_delete.resizeColumnToContents(0) @@ -178,13 +180,11 @@ def _qt_refresh(self): num_spikes = self.controller.num_spikes[unit_id] num_splits = f"({num_indices}-{num_spikes - num_indices})" item = QT.QTableWidgetItem(f"{unit_id} {num_splits}") - item.setFlags(QT.Qt.ItemIsEnabled|QT.Qt.ItemIsSelectable) + item.setFlags(QT.Qt.ItemIsEnabled | QT.Qt.ItemIsSelectable) self.table_split.setItem(i, 0, item) item.unit_id = unit_id self.table_split.resizeColumnToContents(0) - - def _qt_get_delete_table_selection(self): selected_items = self.table_delete.selectedItems() if len(selected_items) == 0: @@ -198,7 +198,7 @@ def _qt_get_merge_table_row(self): return None else: return [s.row() for s in selected_items] - + def _qt_get_split_table_row(self): selected_items = self.table_split.selectedItems() if len(selected_items) == 0: @@ -214,7 +214,7 @@ def _qt_open_context_menu_merge(self): def _qt_open_context_menu_split(self): self.split_menu.popup(self.qt_widget.cursor().pos()) - + def _qt_on_item_selection_changed_merge(self): if len(self.table_merge.selectedIndexes()) == 0: return @@ -274,15 +274,16 @@ def _qt_on_unit_visibility_changed(self): def on_manual_curation_updated(self): self.refresh() - + def save_in_analyzer(self): self.controller.save_curation_in_analyzer() def _qt_export_json(self): from .myqt import QT + fd = QT.QFileDialog(fileMode=QT.QFileDialog.AnyFile, acceptMode=QT.QFileDialog.AcceptSave) - fd.setNameFilters(['JSON (*.json);']) - fd.setDefaultSuffix('json') + fd.setNameFilters(["JSON (*.json);"]) + fd.setDefaultSuffix("json") fd.setViewMode(QT.QFileDialog.Detail) if fd.exec_(): json_file = Path(fd.selectedFiles()[0]) @@ -296,14 +297,17 @@ def _panel_make_layout(self): import pandas as pd import panel as pn - from .utils_panel import KeyboardShortcut, KeyboardShortcuts, SelectableTabulator + from .utils_panel import KeyboardShortcut, KeyboardShortcuts, SelectableTabulator, PostMessageListener, IFrameDetector pn.extension("tabulator") + # Initialize listenet_pane as None + self.listener_pane = None + # Create dataframe delete_df = pd.DataFrame({"removed": []}) merge_df = pd.DataFrame({"merges": []}) - split_df = pd.DataFrame({"splits": []}) + split_df = pd.DataFrame({"splits": []}) # Create tables self.table_delete = SelectableTabulator( @@ -311,11 +315,11 @@ def _panel_make_layout(self): show_index=False, disabled=True, sortable=False, + selectable=True, formatters={"removed": "plaintext"}, sizing_mode="stretch_width", # SelectableTabulator functions parent_view=self, - # refresh_table_function=self.refresh, conditional_shortcut=self._conditional_refresh_delete, column_callbacks={"removed": self._panel_on_deleted_col}, ) @@ -329,7 +333,6 @@ def _panel_make_layout(self): sizing_mode="stretch_width", # SelectableTabulator functions parent_view=self, - # refresh_table_function=self.refresh, conditional_shortcut=self._conditional_refresh_merge, column_callbacks={"merges": self._panel_on_merged_col}, ) @@ -347,59 +350,43 @@ def _panel_make_layout(self): column_callbacks={"splits": self._panel_on_split_col}, ) - self.table_delete.param.watch(self._panel_update_unit_visibility, "selection") - self.table_merge.param.watch(self._panel_update_unit_visibility, "selection") - self.table_split.param.watch(self._panel_update_unit_visibility, "selection") + # Watch selection changes instead of calling from column callbacks + self.table_delete.param.watch(self._panel_on_table_selection_changed, "selection") + self.table_merge.param.watch(self._panel_on_table_selection_changed, "selection") + self.table_split.param.watch(self._panel_on_table_selection_changed, "selection") # Create buttons - save_button = pn.widgets.Button( - name="Save in analyzer", - button_type="primary", - height=30 - ) - save_button.on_click(self._panel_save_in_analyzer) + buttons_row = [] + self.save_button = None + if self.controller.curation_can_be_saved(): + self.save_button = pn.widgets.Button(name="Save in analyzer", button_type="primary", height=30) + self.save_button.on_click(self._panel_save_in_analyzer) + buttons_row.append(self.save_button) - download_button = pn.widgets.FileDownload( - button_type="primary", - filename="curation.json", - callback=self._panel_generate_json, - height=30 + self.download_button = pn.widgets.FileDownload( + button_type="primary", filename="curation.json", callback=self._panel_generate_json, height=30 ) + buttons_row.append(self.download_button) - restore_button = pn.widgets.Button( - name="Restore", - button_type="primary", - height=30 - ) + restore_button = pn.widgets.Button(name="Restore", button_type="primary", height=30) restore_button.on_click(self._panel_restore_units) - remove_merge_button = pn.widgets.Button( - name="Unmerge", - button_type="primary", - height=30 - ) + remove_merge_button = pn.widgets.Button(name="Unmerge", button_type="primary", height=30) remove_merge_button.on_click(self._panel_unmerge) - submit_button = pn.widgets.Button( - name="Submit to parent", - button_type="primary", - height=30 - ) + remove_split = pn.widgets.Button(name="Unsplit", button_type="primary", height=30) + remove_split.on_click(self._panel_unsplit) # Create layout - buttons_save = pn.Row( - save_button, - download_button, - submit_button, - sizing_mode="stretch_width", - ) - save_sections = pn.Column( - buttons_save, + self.buttons_save = pn.Row( + *buttons_row, sizing_mode="stretch_width", ) + buttons_curate = pn.Row( restore_button, remove_merge_button, + remove_split, sizing_mode="stretch_width", ) @@ -413,29 +400,19 @@ def _panel_make_layout(self): shortcuts_component.on_msg(self._panel_handle_shortcut) # Create main layout with proper sizing - sections = pn.Row(self.table_delete, self.table_merge, self.table_split, - sizing_mode="stretch_width") + sections = pn.Row(self.table_delete, self.table_merge, self.table_split, sizing_mode="stretch_width") self.layout = pn.Column( - save_sections, - buttons_curate, - sections, - shortcuts_component, - scroll=True, - sizing_mode="stretch_both" + self.buttons_save, buttons_curate, sections, shortcuts_component, scroll=True, sizing_mode="stretch_both" ) - # Add a custom JavaScript callback to the button that doesn't interact with Bokeh models - submit_button.on_click(self._panel_submit_to_parent) - - # Add a hidden div to store the data - self.data_div = pn.pane.HTML("", width=0, height=0, margin=0, sizing_mode="fixed") - self.layout.append(self.data_div) - + self.iframe_detector = IFrameDetector() + self.iframe_detector.param.watch(self._panel_on_iframe_change, "in_iframe") + self.layout.append(self.iframe_detector) def _panel_refresh(self): import pandas as pd - ## deleted + ## deleted removed_units = self.controller.curation_data["removed"] removed_units = [str(unit_id) for unit_id in removed_units] df = pd.DataFrame({"removed": removed_units}) @@ -473,7 +450,7 @@ def _panel_refresh(self): def ensure_save_warning_message(self): - if self.layout[0].name == 'curation_save_warning': + if self.layout[0].name == "curation_save_warning": return import panel as pn @@ -482,13 +459,13 @@ def ensure_save_warning_message(self): f"""⚠️⚠️⚠️ Your curation is not saved""", hard_line_break=True, styles={"color": "red", "font-size": "16px"}, - name="curation_save_warning" + name="curation_save_warning", ) self.layout.insert(0, alert_markdown) def ensure_no_message(self): - if self.layout[0].name == 'curation_save_warning': + if self.layout[0].name == "curation_save_warning": self.layout.pop(0) def _panel_update_unit_visibility(self, event): @@ -518,13 +495,16 @@ def _panel_restore_units(self, event): def _panel_unmerge(self, event): self.unmerge() + def _panel_unsplit(self, event): + self.unsplit() + def _panel_save_in_analyzer(self, event): self.save_in_analyzer() self.refresh() def _panel_generate_json(self): # Get the path from the text input - export_path = "curation.json" + export_path = Path("curation.json") # Save the JSON file curation_model = self.controller.construct_final_curation() with export_path.open("w") as f: @@ -536,40 +516,119 @@ def _panel_generate_json(self): return export_path - def _panel_submit_to_parent(self, event): + def _panel_submit_to_parent(self, event): """Send the curation data to the parent window""" + import time + # Get the curation data and convert it to a JSON string curation_model = self.controller.construct_final_curation() + curation_data = curation_model.model_dump_json() + # Trigger the JavaScript function via the TextInput + # Update the value to trigger the jscallback + self.submit_trigger.value = curation_data + f"_{int(time.time() * 1000)}" - # Create a JavaScript snippet that will send the data to the parent window - js_code = f""" - - """ - - # Update the hidden div with the JavaScript code - self.data_div.object = js_code # Submitting to parent is a way to "save" the curation (the parent can handle it) self.controller.current_curation_saved = True + self.ensure_no_message() + print(f"Curation data sent to parent app!") + + def _panel_set_curation_data(self, event): + """ + Handler for PostMessageListener.on_msg. + + event.data is whatever the JS side passed to model.send_msg(...). + Expected shape: + { + "payload": {"type": "curation-data", "data": }, + } + """ + msg = event.data + payload = (msg or {}).get("payload", {}) + curation_data = payload.get("data", None) + + if curation_data is None: + print("Received message without curation data:", msg) + return + + # Optional: validate basic structure + if not isinstance(curation_data, dict): + print("Invalid curation_data type:", type(curation_data), curation_data) + return + + self.controller.set_curation_data(curation_data) self.refresh() + def _panel_on_iframe_change(self, event): + import panel as pn + from .utils_panel import PostMessageListener + + in_iframe = event.new + print(f"CurationView detected iframe mode: {in_iframe}") + if in_iframe: + # Remove save in analyzer button and add submit to parent button + self.submit_button = pn.widgets.Button(name="Submit to parent", button_type="primary", height=30) + self.submit_button.on_click(self._panel_submit_to_parent) + + self.buttons_save = pn.Row( + self.submit_button, + self.download_button, + sizing_mode="stretch_width", + ) + self.layout[0] = self.buttons_save + + # Create objects to submit and listen + self.submit_trigger = pn.widgets.TextInput(value="", visible=False) + # Add JavaScript callback that triggers when the TextInput value changes + self.submit_trigger.jscallback( + value=""" + // Extract just the JSON data (remove timestamp suffix) + const fullValue = cb_obj.value; + const lastUnderscore = fullValue.lastIndexOf('_'); + const dataStr = lastUnderscore > 0 ? fullValue.substring(0, lastUnderscore) : fullValue; + + if (dataStr && dataStr.length > 0) { + try { + const data = JSON.parse(dataStr); + console.log('Sending data to parent:', data); + parent.postMessage({ + type: 'panel-data', + data: data + }, + '*'); + console.log('Data sent successfully to parent window'); + } catch (error) { + console.error('Error sending data to parent:', error); + } + } + """ + ) + self.layout.append(self.submit_trigger) + # Set up listener for external curation changes + self.listener = PostMessageListener() + self.listener.on_msg(self._panel_set_curation_data) + self.layout.append(self.listener) + + # Notify parent that panel is ready to receive messages + self.ready_trigger = pn.pane.HTML( + """ + + """ + ) + self.layout.append(self.ready_trigger) + def _panel_get_delete_table_selection(self): selected_items = self.table_delete.selection if len(selected_items) == 0: @@ -606,22 +665,46 @@ def _panel_on_unit_visibility_changed(self): def _panel_on_deleted_col(self, row): self.active_table = "delete" - self.table_merge.selection = [] - self.table_split.selection = [] def _panel_on_merged_col(self, row): self.active_table = "merge" - self.table_delete.selection = [] - self.table_split.selection = [] def _panel_on_split_col(self, row): self.active_table = "split" - self.table_delete.selection = [] - self.table_merge.selection = [] - # set split selection - split_unit_str = self.table_split.value["splits"].values[row] - split_unit_id = self.controller.unit_ids.dtype.type(split_unit_str.split(" ")[0]) - self.select_and_notify_split(split_unit_id) + + def _panel_on_table_selection_changed(self, event): + """ + Unified handler for all table selection changes. + Determines which table was changed and updates visibility accordingly. + """ + import panel as pn + + print(f"Selection changed in table: {self.active_table}") + # Determine which table triggered the change + if self.active_table == "delete" and len(self.table_delete.selection) > 0: + self.table_merge.selection = [] + self.table_split.selection = [] + # Defer to avoid nested Bokeh callbacks (especially from ctrl+click) + pn.state.execute(lambda: self._panel_update_unit_visibility(None), schedule=True) + elif self.active_table == "merge" and len(self.table_merge.selection) > 0: + self.table_delete.selection = [] + self.table_split.selection = [] + # Defer to avoid nested Bokeh callbacks (especially from ctrl+click) + pn.state.execute(lambda: self._panel_update_unit_visibility(None), schedule=True) + elif self.active_table == "split" and len(self.table_split.selection) > 0: + self.table_delete.selection = [] + self.table_merge.selection = [] + # Handle split specially (sets selected spikes) - also deferred + def handle_split(): + split_unit_str = self.table_split.value["splits"].values[self.table_split.selection[0]] + split_unit_id = self.controller.unit_ids.dtype.type(split_unit_str.split(" ")[0]) + self.select_and_notify_split(split_unit_id) + # Defer to avoid nested Bokeh callbacks (especially from ctrl+click) + pn.state.execute(handle_split, schedule=True) + elif len(self.table_delete.selection) == 0 and len(self.table_merge.selection) == 0 and len(self.table_split.selection) == 0: + # All tables are cleared + self.active_table = None + def _conditional_refresh_merge(self): # Check if the view is active before refreshing @@ -656,8 +739,22 @@ def _conditional_refresh_split(self): - **export/download JSON**: Export the current curation state to a JSON file. - **restore**: Restore the selected unit from the deleted units table. - **unmerge**: Unmerge the selected merges from the merged units table. +- **unsplit**: Unsplit the selected split groups from the split units table. - **submit to parent**: Submit the current curation state to the parent window (for use in web applications). - **press 'ctrl+r'**: Restore the selected units from the deleted units table. - **press 'ctrl+u'**: Unmerge the selected merges from the merged units table. - **press 'ctrl+x'**: Unsplit the selected split groups from the split units table. + +### Note +When setting the `iframe_mode` setting to `True` using the `user_settings=dict(curation=dict(iframe_mode=True))`, +the GUI is expected to be used inside an iframe. In this mode, the curation view will include a "Submit to parent" +button that, when clicked, will send the current curation data to the parent window. +In this mode, bi-directional communication is established between the GUI and the parent window using the `postMessage` +API. The GUI listens for incoming messages of this expected shape: + +``` +{ + "payload": {"type": "curation-data", "data": }, +} +``` """ diff --git a/spikeinterface_gui/mergeview.py b/spikeinterface_gui/mergeview.py index 18fc7b13..1a6169f3 100644 --- a/spikeinterface_gui/mergeview.py +++ b/spikeinterface_gui/mergeview.py @@ -438,9 +438,15 @@ def _panel_on_preset_change(self, event): self.layout[layout_index] = self.preset_params_selectors[self.preset] def _panel_on_click(self, event): + import panel as pn + # set unit visibility row = event.row - self.table.selection = [row] + + def _do_update(): + self.table.selection = [row] + + pn.state.execute(_do_update, schedule=True) self._panel_update_visible_pair(row) def _panel_include_deleted_change(self, event): @@ -458,19 +464,40 @@ def _panel_update_visible_pair(self, row): self.notify_unit_visibility_changed() def _panel_handle_shortcut(self, event): + import panel as pn + if event.data == "accept": selected = self.table.selection - for row in selected: - group_ids = self.table.value.iloc[row].group_ids - self.accept_group_merge(group_ids) + if len(selected) == 0: + return + # selected is always 1 + row = selected[0] + group_ids = self.table.value.iloc[row].group_ids + self.accept_group_merge(group_ids) self.notify_manual_curation_updated() + + next_row = min(row + 1, len(self.table.value) - 1) + + def _select_next(): + self.table.selection = [next_row] + + pn.state.execute(_select_next, schedule=True) + self._panel_update_visible_pair(next_row) elif event.data == "next": next_row = min(self.table.selection[0] + 1, len(self.table.value) - 1) - self.table.selection = [next_row] + + def _do_next(): + self.table.selection = [next_row] + + pn.state.execute(_do_next, schedule=True) self._panel_update_visible_pair(next_row) elif event.data == "previous": previous_row = max(self.table.selection[0] - 1, 0) - self.table.selection = [previous_row] + + def _do_prev(): + self.table.selection = [previous_row] + + pn.state.execute(_do_prev, schedule=True) self._panel_update_visible_pair(previous_row) def _panel_on_spike_selection_changed(self): diff --git a/spikeinterface_gui/metricsview.py b/spikeinterface_gui/metricsview.py index cab29432..3c3b12bf 100644 --- a/spikeinterface_gui/metricsview.py +++ b/spikeinterface_gui/metricsview.py @@ -172,7 +172,7 @@ def _panel_make_layout(self): self.empty_plot_pane = pn.pane.Bokeh(empty_fig, sizing_mode="stretch_both") self.layout = pn.Column( - pn.Row(self.metrics_select, sizing_mode="stretch_width", height=160), + pn.Row(self.metrics_select, sizing_mode="stretch_width", height=100), self.empty_plot_pane, sizing_mode="stretch_both" ) @@ -221,6 +221,7 @@ def _panel_refresh(self): outline_line_color="white", toolbar_location=None ) + plot.toolbar.logo = None plot.grid.visible = False if r == c: diff --git a/spikeinterface_gui/ndscatterview.py b/spikeinterface_gui/ndscatterview.py index ec63b712..0b8ba542 100644 --- a/spikeinterface_gui/ndscatterview.py +++ b/spikeinterface_gui/ndscatterview.py @@ -488,20 +488,31 @@ def _panel_refresh(self, update_components=True, update_colors=True): self.scatter_fig.y_range.end = self.limit def _panel_on_spike_selection_changed(self): + import panel as pn + # handle selection with lasso plotted_spike_indices = self.scatter_source.data.get("spike_indices", []) ind_selected, = np.nonzero(np.isin(plotted_spike_indices, self.controller.get_indices_spike_selected())) - self.scatter_source.selected.indices = ind_selected + + def _do_update(): + self.scatter_source.selected.indices = ind_selected + + pn.state.execute(_do_update, schedule=True) def _panel_gain_zoom(self, event): - from bokeh.models import Range1d + import panel as pn factor = 1.3 if event.delta > 0 else 1 / 1.3 self.limit /= factor - self.scatter_fig.x_range.start = -self.limit - self.scatter_fig.x_range.end = self.limit - self.scatter_fig.y_range.start = -self.limit - self.scatter_fig.y_range.end = self.limit + limit = self.limit + + def _do_update(): + self.scatter_fig.x_range.start = -limit + self.scatter_fig.x_range.end = limit + self.scatter_fig.y_range.start = -limit + self.scatter_fig.y_range.end = limit + + pn.state.execute(_do_update, schedule=True) def _panel_next_face(self, event): self.next_face() @@ -522,11 +533,18 @@ def _panel_start_stop_tour(self, event): self.auto_update_limit = True def _panel_on_select_button(self, event): - if self.select_toggle_button.value: - self.scatter_fig.toolbar.active_drag = self.lasso_tool - else: - self.scatter_fig.toolbar.active_drag = None - self.scatter_source.selected.indices = [] + import panel as pn + + value = self.select_toggle_button.value + + def _do_update(): + if value: + self.scatter_fig.toolbar.active_drag = self.lasso_tool + else: + self.scatter_fig.toolbar.active_drag = None + self.scatter_source.selected.indices = [] + + pn.state.execute(_do_update, schedule=True) def _on_panel_selection_geometry(self, event): """ diff --git a/spikeinterface_gui/probeview.py b/spikeinterface_gui/probeview.py index db7a217f..1c880208 100644 --- a/spikeinterface_gui/probeview.py +++ b/spikeinterface_gui/probeview.py @@ -485,55 +485,84 @@ def _panel_make_layout(self): ) def _panel_refresh(self): - # Only update unit positions if they actually changed + import panel as pn + + # Compute everything outside the scheduled callback current_unit_positions = self.controller.unit_positions - if not np.array_equal(current_unit_positions, self._unit_positions): - self._unit_positions = current_unit_positions - # Update positions in data source - self.glyphs_data_source.patch({ + positions_changed = not np.array_equal(current_unit_positions, self._unit_positions) + show_channel_id = self.settings['show_channel_id'] + auto_zoom = self.settings['auto_zoom_on_unit_selection'] + radius_channel = self.settings['radius_channel'] + + selected_unit_indices = self.controller.get_visible_unit_indices() + visible_mask = self.controller.get_units_visibility_mask() + + # Pre-compute unit glyph updates + glyph_patch_dict = self._panel_compute_unit_glyph_patches() + + # Pre-compute position patches + position_patches = None + if positions_changed: + position_patches = { 'x': [(i, pos[0]) for i, pos in enumerate(current_unit_positions)], 'y': [(i, pos[1]) for i, pos in enumerate(current_unit_positions)] - }) - - # Update unit positions - self._panel_update_unit_glyphs() + } - # channel labels - for label in self.channel_labels: - label.visible = self.settings['show_channel_id'] + # Pre-compute zoom bounds + zoom_bounds = None + if auto_zoom and sum(visible_mask) > 0: + visible_pos = self.controller.unit_positions[visible_mask, :] + x_min, x_max = np.min(visible_pos[:, 0]), np.max(visible_pos[:, 0]) + y_min, y_max = np.min(visible_pos[:, 1]), np.max(visible_pos[:, 1]) + margin = 50 + zoom_bounds = (x_min - margin, x_max + margin, y_min - margin, y_max + margin) - # Update selection circles if only one unit is visible - selected_unit_indices = self.controller.get_visible_unit_indices() + # Pre-compute circle updates + circle_update = None if len(selected_unit_indices) == 1: unit_index = selected_unit_indices[0] unit_positions = self.controller.unit_positions - x, y = unit_positions[unit_index, 0], unit_positions[unit_index, 1] - # Update circles position - self.unit_circle.update_position(x, y) + cx, cy = unit_positions[unit_index, 0], unit_positions[unit_index, 1] + visible_channel_inds = self.update_channel_visibility(cx, cy, radius_channel) + circle_update = (cx, cy, visible_channel_inds) - self.channel_circle.update_position(x, y) - # Update channel visibility - visible_channel_inds = self.update_channel_visibility(x, y, self.settings['radius_channel']) + # def _do_update(): + if positions_changed: + self._unit_positions = current_unit_positions + if position_patches: + self.glyphs_data_source.patch(position_patches) + + # Update unit glyphs + if glyph_patch_dict: + self.glyphs_data_source.patch(glyph_patch_dict) + + # Channel labels + for label in self.channel_labels: + label.visible = show_channel_id + + # Update selection circles + if circle_update is not None: + cx, cy, visible_channel_inds = circle_update + self.unit_circle.update_position(cx, cy) + self.channel_circle.update_position(cx, cy) self.controller.set_channel_visibility(visible_channel_inds) - if self.settings['auto_zoom_on_unit_selection']: - visible_mask = self.controller.get_units_visibility_mask() - if sum(visible_mask) > 0: - visible_pos = self.controller.unit_positions[visible_mask, :] - x_min, x_max = np.min(visible_pos[:, 0]), np.max(visible_pos[:, 0]) - y_min, y_max = np.min(visible_pos[:, 1]), np.max(visible_pos[:, 1]) - margin = 50 - self.x_range.start = x_min - margin - self.x_range.end = x_max + margin - self.y_range.start = y_min - margin - self.y_range.end = y_max + margin + # Auto zoom + if zoom_bounds is not None: + self.x_range.start = zoom_bounds[0] + self.x_range.end = zoom_bounds[1] + self.y_range.start = zoom_bounds[2] + self.y_range.end = zoom_bounds[3] - def _panel_update_unit_glyphs(self): - # Get current data from source + # Defer to avoid nested Bokeh callbacks + # pn.state.execute(_do_update, schedule=True) + + def _panel_compute_unit_glyph_patches(self): + """Compute glyph patches without modifying Bokeh models.""" current_alphas = self.glyphs_data_source.data['alpha'] current_sizes = self.glyphs_data_source.data['size'] current_line_colors = self.glyphs_data_source.data['line_color'] - # Prepare patches (only for changed values) + alpha_patches = [] size_patches = [] line_color_patches = [] @@ -542,12 +571,10 @@ def _panel_update_unit_glyphs(self): color = self.get_unit_color(unit_id) is_visible = self.controller.get_unit_visibility(unit_id) - # Compute new values new_alpha = self.alpha_selected if is_visible else self.alpha_unselected new_size = self.unit_marker_size_selected if is_visible else self.unit_marker_size_unselected new_line_color = "black" if is_visible else color - # Only patch if changed if current_alphas[idx] != new_alpha: alpha_patches.append((idx, new_alpha)) if current_sizes[idx] != new_size: @@ -555,22 +582,33 @@ def _panel_update_unit_glyphs(self): if current_line_colors[idx] != new_line_color: line_color_patches.append((idx, new_line_color)) - # Apply patches if any changes detected - if len(alpha_patches) > 0 or len(size_patches) > 0 or len(line_color_patches) > 0: - patch_dict = {} - if alpha_patches: - patch_dict['alpha'] = alpha_patches - if size_patches: - patch_dict['size'] = size_patches - if line_color_patches: - patch_dict['line_color'] = line_color_patches - - self.glyphs_data_source.patch(patch_dict) - + patch_dict = {} + if alpha_patches: + patch_dict['alpha'] = alpha_patches + if size_patches: + patch_dict['size'] = size_patches + if line_color_patches: + patch_dict['line_color'] = line_color_patches + + return patch_dict + + def _panel_update_unit_glyphs(self): + import panel as pn + + patch_dict = self._panel_compute_unit_glyph_patches() + if patch_dict: + pn.state.execute(lambda: self.glyphs_data_source.patch(patch_dict), schedule=True) + def _panel_on_pan_start(self, event): - self.figure.toolbar.active_drag = None + import panel as pn + x, y = event.x, event.y + def _do_update(): + self.figure.toolbar.active_drag = None + + pn.state.execute(_do_update, schedule=True) + if self.unit_circle.is_close_to_diamond(x, y): self.should_resize_unit_circle = [x, y] self.unit_circle.select() @@ -654,6 +692,8 @@ def _panel_on_pan_end(self, event): def _panel_on_tap(self, event): + import panel as pn + x, y = event.x, event.y unit_positions = self.controller.unit_positions distances = np.sqrt(np.sum((unit_positions - np.array([x, y])) ** 2, axis=1)) @@ -683,16 +723,18 @@ def _panel_on_tap(self, event): self._panel_update_unit_glyphs() - if select_only: # Update selection circles - self.unit_circle.update_position(x, y) - self.channel_circle.update_position(x, y) + def _do_update(): + self.unit_circle.update_position(x, y) + self.channel_circle.update_position(x, y) + + pn.state.execute(_do_update, schedule=True) # Update channel visibility visible_channel_inds = self.update_channel_visibility(x, y, self.settings['radius_channel']) self.controller.set_channel_visibility(visible_channel_inds) - self.notify_channel_visibility_changed + self.notify_channel_visibility_changed() self.notify_unit_visibility_changed() diff --git a/spikeinterface_gui/spikelistview.py b/spikeinterface_gui/spikelistview.py index b82986c3..414626de 100644 --- a/spikeinterface_gui/spikelistview.py +++ b/spikeinterface_gui/spikelistview.py @@ -118,6 +118,7 @@ class SpikeListView(ViewBase): ] def __init__(self, controller=None, parent=None, backend="qt"): + self._updating_from_controller = False # Add this guard ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) def handle_selection(self, inds): @@ -276,7 +277,7 @@ def _panel_make_layout(self): conditional_shortcut=self.is_view_active, ) # Add selection event handler - self.table.param.watch(self._panel_on_selection_changed, "selection") + self.table.param.watch(self._panel_on_user_selection_changed, "selection") self.refresh_button = pn.widgets.Button(name="↻ spikes", button_type="default", sizing_mode="stretch_width") self.refresh_button.on_click(self._panel_on_refresh_click) @@ -328,17 +329,22 @@ def _panel_refresh_table(self): 'rand_selected': spikes['rand_selected'] } - # Update table data + # Update table data without replacing entire dataframe df = pd.DataFrame(data) - self.table.value = df + + # Only update if data changed + if not self.table.value.equals(df): + self.table.value = df selected_inds = self.controller.get_indices_spike_selected() + self._updating_from_controller = True if len(selected_inds) == 0: self.table.selection = [] else: # Find the rows corresponding to the selected indices row_selected, = np.nonzero(np.isin(visible_inds, selected_inds)) self.table.selection = [int(r) for r in row_selected] + self._updating_from_controller = False self._panel_refresh_label() @@ -348,21 +354,39 @@ def _panel_on_refresh_click(self, event): self.notify_active_view_updated() def _panel_on_clear_click(self, event): + import panel as pn + self.controller.set_indices_spike_selected([]) - self.table.selection = [] + + def _do_update(): + self.table.selection = [] + self._panel_refresh_label() + + pn.state.execute(_do_update, schedule=True) self.notify_spike_selection_changed() - self._panel_refresh_label() self.notify_active_view_updated() - def _panel_on_selection_changed(self, event=None): - selection = event.new + def _panel_on_user_selection_changed(self, event=None): + import panel as pn + + # Ignore if we're updating from controller + if self._updating_from_controller: + return + + selection = event.new if event else self.table.selection if len(selection) == 0: - self.handle_selection([]) + absolute_indices = [] else: - absolute_indices = self.controller.get_indices_spike_visible()[np.array(selection)] - self.handle_selection(absolute_indices) + visible_inds = self.controller.get_indices_spike_visible() + absolute_indices = visible_inds[np.array(selection)] + + # Defer the entire selection handling to avoid nested Bokeh callbacks + # def update_selection(): + self.handle_selection(absolute_indices) self._panel_refresh_label() + # pn.state.execute(update_selection, schedule=True) + def _panel_refresh_label(self): n1 = self.controller.spikes.size n2 = self.controller.get_indices_spike_visible().size @@ -371,21 +395,35 @@ def _panel_refresh_label(self): self.info_text.object = txt def _panel_on_unit_visibility_changed(self): + import panel as pn import pandas as pd - # Clear the table when visibility changes - self.table.value = pd.DataFrame(columns=_columns, data=[]) - self._panel_refresh_label() + + def _do_update(): + # Clear the table when visibility changes + self.table.value = pd.DataFrame(columns=_columns, data=[]) + self._updating_from_controller = True + self.table.selection = [] + self._updating_from_controller = False + self._panel_refresh_label() + + pn.state.execute(_do_update, schedule=True) def _panel_on_spike_selection_changed(self): + import panel as pn + if len(self.table.value) == 0: return selected_inds = self.controller.get_indices_spike_selected() visible_inds = self.controller.get_indices_spike_visible() row_selected, = np.nonzero(np.isin(visible_inds, selected_inds)) row_selected = [int(r) for r in row_selected] - # Update the selection in the table - self.table.selection = row_selected - self._panel_refresh_label() + + def _do_update(): + # Update the selection in the table + self.table.selection = row_selected + self._panel_refresh_label() + + pn.state.execute(_do_update, schedule=True) SpikeListView._gui_help_txt = """ diff --git a/spikeinterface_gui/tests/iframe/curation.json b/spikeinterface_gui/tests/iframe/curation.json new file mode 100644 index 00000000..db7d689e --- /dev/null +++ b/spikeinterface_gui/tests/iframe/curation.json @@ -0,0 +1,75 @@ +{ + "supported_versions": [ + "1", + "2" + ], + "format_version": "2", + "unit_ids": [ + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "10", + "11", + "12", + "13", + "14", + "15" + ], + "label_definitions": { + "quality": { + "name": "quality", + "label_options": [ + "good", + "noise", + "MUA" + ], + "exclusive": true + } + }, + "manual_labels": [], + "removed": [ + "3", + "8" + ], + "merges": [ + { + "unit_ids": [ + "1", + "2" + ], + "new_unit_id": null + }, + { + "unit_ids": [ + "5", + "6", + "7" + ], + "new_unit_id": null + } + ], + "splits": [ + { + "unit_id": "4", + "mode": "indices", + "indices": [ + [ + 0, + 1, + 2, + 3, + 4 + ] + ], + "labels": null, + "new_unit_ids": null + } + ] +} \ No newline at end of file diff --git a/spikeinterface_gui/tests/iframe/iframe_server.py b/spikeinterface_gui/tests/iframe/iframe_server.py index 8dada7b9..ad1b6e1f 100644 --- a/spikeinterface_gui/tests/iframe/iframe_server.py +++ b/spikeinterface_gui/tests/iframe/iframe_server.py @@ -2,120 +2,142 @@ import webbrowser from pathlib import Path import argparse +import socket +import time + from flask import Flask, send_file, jsonify app = Flask(__name__) + +PANEL_HOST = "localhost" # <- switch back to localhost + panel_server = None panel_url = None panel_thread = None panel_port_global = None +panel_last_error = None + -@app.route('/') +def _wait_for_port(host: str, port: int, timeout_s: float = 20.0) -> bool: + deadline = time.time() + timeout_s + while time.time() < deadline: + try: + with socket.create_connection((host, port), timeout=0.25): + return True + except OSError: + time.sleep(0.1) + return False + + +@app.route("/") def index(): - """Serve the iframe test HTML page""" - return send_file('iframe_test.html') + here = Path(__file__).parent + return send_file(str(here / "iframe_test.html")) + -@app.route('/start_test_server') +@app.route("/curation.json") +def curation_json(): + here = Path(__file__).parent + return send_file(str(here / "curation.json")) + + +@app.route("/start_test_server") def start_test_server(): - """Start the Panel server in a separate thread""" - global panel_server, panel_url, panel_thread, panel_port_global - - # If a server is already running, return its URL + global panel_server, panel_url, panel_thread, panel_port_global, panel_last_error + if panel_url: return jsonify({"success": True, "url": panel_url}) - - # Make sure the test dataset exists + + panel_last_error = None + test_folder = Path(__file__).parent / "my_dataset" if not test_folder.is_dir(): from spikeinterface_gui.tests.testingtools import make_analyzer_folder + make_analyzer_folder(test_folder) - - # Function to run the Panel server in a thread + def run_panel_server(): - global panel_server, panel_url, panel_port_global + global panel_server, panel_url, panel_port_global, panel_last_error try: - # Start the Panel server with curation enabled - # Use a direct import to avoid circular imports + import panel as pn from spikeinterface import load_sorting_analyzer from spikeinterface_gui import run_mainwindow - - # Load the analyzer + + pn.extension("tabulator", "gridstack") + analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer") - - # Start the Panel server directly - win = run_mainwindow( - analyzer, - backend="panel", - start_app=False, - verbose=True, - curation=True, - make_servable=True + + def app_factory(): + win = run_mainwindow( + analyzer, + mode="web", + start_app=False, + verbose=True, + curation=True, + panel_window_servable=True, + ) + return win.main_layout + + allowed = [ + f"localhost:{int(panel_port_global)}", + f"127.0.0.1:{int(panel_port_global)}", + ] + print(panel_port_global) + server = pn.serve( + app_factory, + port=int(panel_port_global), + address=PANEL_HOST, + allow_websocket_origin=allowed, + show=False, + start=False, ) - - # Start the server manually - import panel as pn - pn.serve(win.main_layout, port=panel_port_global, address="localhost", show=False, start=True) - - # Get the server URL - panel_url = f"http://localhost:{panel_port_global}" - panel_server = win - - print(f"Panel server started at {panel_url}") + + panel_server = server + panel_url = f"http://{PANEL_HOST}:{panel_port_global}/" + print(f"Panel server starting at {panel_url} (allow_websocket_origin={allowed})") + + server.start() + server.io_loop.start() + except Exception as e: - print(f"Error starting Panel server: {e}") + panel_last_error = repr(e) + panel_server = None + panel_url = None import traceback + traceback.print_exc() - - # Start the Panel server in a separate thread - panel_thread = threading.Thread(target=run_panel_server) - panel_thread.daemon = True + + panel_thread = threading.Thread(target=run_panel_server, daemon=True) panel_thread.start() - - # Give the server some time to start - import time - time.sleep(5) # Increased wait time - - # Check if the server is actually running - import requests - try: - response = requests.get(f"http://localhost:{panel_port_global}", timeout=2) - if response.status_code == 200: - return jsonify({"success": True, "url": f"http://localhost:{panel_port_global}"}) - else: - return jsonify({"success": False, "error": f"Server returned status code {response.status_code}"}) - except requests.exceptions.RequestException as e: - return jsonify({"success": False, "error": f"Could not connect to Panel server: {str(e)}"}) - -@app.route('/stop_test_server') -def stop_test_server(): - """Stop the Panel server""" - global panel_server, panel_url, panel_thread - - if panel_server: - # Clean up resources - # clean_all(Path(__file__).parent / 'my_dataset') - panel_url = None - panel_server = None - return jsonify({"success": True}) - else: - return jsonify({"success": False, "error": "No server running"}) - -def main(flask_port=5000, panel_port=5006): - """Start the Flask server and open the browser""" + + if not _wait_for_port("127.0.0.1", int(panel_port_global), timeout_s=30.0): + return ( + jsonify( + { + "success": False, + "error": "Panel server did not become ready (port not open).", + "panel_host": PANEL_HOST, + "panel_port": panel_port_global, + "last_error": panel_last_error, + } + ), + 500, + ) + + return jsonify({"success": True, "url": panel_url}) + + +def main(flask_port=5000, panel_port=5007): global panel_port_global panel_port_global = panel_port - # Open the browser - webbrowser.open(f'http://localhost:{flask_port}') - - # Start the Flask server - app.run(debug=False, port=flask_port) + webbrowser.open(f"http://localhost:{flask_port}") + app.run(debug=False, port=flask_port, host="localhost") parser = argparse.ArgumentParser(description="Run the Flask and Panel servers.") -parser.add_argument('--flask-port', type=int, default=5000, help="Port for the Flask server (default: 5000)") -parser.add_argument('--panel-port', type=int, default=5006, help="Port for the Panel server (default: 5006)") +parser.add_argument("--flask-port", type=int, default=5000) +parser.add_argument("--panel-port", type=int, default=5006) -if __name__ == '__main__': +if __name__ == "__main__": args = parser.parse_args() - main(flask_port=int(args.flask_port), panel_port=int(args.panel_port)) diff --git a/spikeinterface_gui/tests/iframe/iframe_test.html b/spikeinterface_gui/tests/iframe/iframe_test.html index d3e0d91a..050b05e4 100644 --- a/spikeinterface_gui/tests/iframe/iframe_test.html +++ b/spikeinterface_gui/tests/iframe/iframe_test.html @@ -51,6 +51,7 @@ display: flex; gap: 10px; margin-bottom: 10px; + flex-wrap: wrap; } button { padding: 8px 16px; @@ -63,22 +64,68 @@ button:hover { background-color: #45a049; } + button:disabled { + background-color: #cccccc; + cursor: not-allowed; + } + .send-controls { + background-color: #fff; + border: 1px solid #ccc; + border-radius: 5px; + padding: 20px; + } + textarea { + width: 100%; + min-height: 150px; + font-family: 'Courier New', monospace; + font-size: 12px; + padding: 10px; + border: 1px solid #ddd; + border-radius: 4px; + box-sizing: border-box; + } + .button-secondary { + background-color: #2196F3; + } + .button-secondary:hover { + background-color: #0b7dda; + } + .button-danger { + background-color: #f44336; + } + .button-danger:hover { + background-color: #da190b; + }
-

SpikeInterface GUI Iframe Test

- +

SpikeInterface GUI Iframe Test - Bidirectional Communication

+
- + +
- +
+ +

Send Curation Data to iframe:

+
+

Edit the curation data below and click "Send to iframe" to update the GUI:

+ +
+ + + +
+
-

Received Data:

+

Received Data from iframe:

No data received yet...
@@ -88,23 +135,88 @@

Received Data:

// DOM elements const iframe = document.getElementById('guiFrame'); const dataOutput = document.getElementById('dataOutput'); + const curationInput = document.getElementById('curationInput'); const clearBtn = document.getElementById('clearBtn'); const startServerBtn = document.getElementById('startServerBtn'); - + const sendDataBtn = document.getElementById('sendDataBtn'); + const resetDataBtn = document.getElementById('resetDataBtn'); + const clearCurationBtn = document.getElementById('clearCurationBtn'); + const autoSendCheckbox = document.getElementById('autoSendCheckbox'); + // Store received data let receivedData = []; - + let autoSendEnabled = false; + let defaultCurationData = null; + + // Update auto-send state when checkbox changes + autoSendCheckbox.addEventListener('change', function() { + autoSendEnabled = autoSendCheckbox.checked; + console.log('Auto-send ' + (autoSendEnabled ? 'enabled' : 'disabled')); + }); + + // Load default curation data from curation.json + fetch('/curation.json') + .then(response => { + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + return response.json(); + }) + .then(data => { + defaultCurationData = data; + // Initialize the textarea with the loaded data + curationInput.value = JSON.stringify(defaultCurationData, null, 2); + console.log('Loaded curation data from curation.json'); + }) + .catch(error => { + console.error('Error loading curation.json:', error); + // Fallback to simple default data + defaultCurationData = { + "merges": [ + {"unit_ids": ["1", "2"]}, + {"unit_ids": ["5", "6", "7"]} + ], + "removed": ["3", "8"], + "splits": [ + { + "unit_id": "4", + "mode": "indices", + "indices": [[0, 1, 2, 3, 4]] + } + ] + }; + curationInput.value = JSON.stringify(defaultCurationData, null, 2); + console.log('Using fallback curation data'); + }); + // Event listener for messages from the iframe function handleMessage(event) { console.log('Message received from iframe:', event.data); - + + // Ignore messages from browser extensions + if (event.data && event.data.source === 'react-devtools-content-script') { + return; + } + // Check if the message is from our Panel application if (event.data && event.data.type === 'panel-data') { const data = event.data.data; receivedData.push(data); // Format and display the data - displayData(); + displayReceivedData(); + + // Check for loaded status and auto-send if enabled + if (data && data.loaded === true) { + console.log('iframe loaded detected, auto-send enabled:', autoSendCheckbox.checked); + if (autoSendCheckbox.checked) { + console.log('Auto-sending curation data...'); + const curationData = curationInput.value; + setTimeout(() => { + sendCurationData(curationData); + }, 100); // Small delay to ensure iframe is fully ready + } + } } } @@ -112,7 +224,7 @@

Received Data:

window.addEventListener('message', handleMessage); // Function to display the received data - function displayData() { + function displayReceivedData() { if (receivedData.length === 0) { dataOutput.textContent = 'No data received yet...'; return; @@ -129,10 +241,33 @@

Received Data:

dataOutput.scrollTop = dataOutput.scrollHeight; } + // Function to send curation data to the iframe + function sendCurationData(curationData) { + try { + // Validate the JSON + const data = typeof curationData === 'string' + ? JSON.parse(curationData) + : curationData; + + console.log('Sending curation data to iframe:', data); + + // Send the message to the iframe + iframe.contentWindow.postMessage({ + type: 'curation-data', + data: data + }, '*'); + + console.log('Curation data sent to iframe successfully!'); + } catch (error) { + console.error('Error sending curation data:', error); + alert('Error sending curation data: ' + error.message); + } + } + // Clear button event listener clearBtn.addEventListener('click', function() { receivedData = []; - displayData(); + displayReceivedData(); // Re-add the event listener to ensure it's working window.removeEventListener('message', handleMessage); @@ -140,6 +275,29 @@

Received Data:

console.log('Event listener reset'); }); + // Send data button event listener + sendDataBtn.addEventListener('click', function() { + const curationData = curationInput.value; + sendCurationData(curationData); + }); + + // Reset data button event listener + resetDataBtn.addEventListener('click', function() { + if (defaultCurationData) { + curationInput.value = JSON.stringify(defaultCurationData, null, 2); + } + }); + + // Clear curation button event listener + clearCurationBtn.addEventListener('click', function() { + const emptyCuration = { + "merges": [], + "removed": [], + "splits": [] + }; + sendCurationData(emptyCuration); + }); + // Start server button event listener startServerBtn.addEventListener('click', function() { // Disable the button while starting the server @@ -177,9 +335,17 @@

Received Data:

// Add instructions for using the curation view dataOutput.textContent = `Instructions: 1. Click "Start Test Server" to launch the SpikeInterface GUI -2. Once loaded, navigate to the "curation" tab -3. Click the "Submit to parent" button to send data to this page -4. The received data will appear here +2. Optionally enable "Auto send when ready" to automatically send curation data when the iframe signals it's loaded +3. Once loaded, navigate to the "curation" tab + +Testing SEND (parent → iframe): +- Edit the curation data in the textarea above +- Click "Send to iframe" to update the GUI's curation data +- The GUI should refresh and show the new curation data + +Testing RECEIVE (iframe → parent): +- Click the "Submit to parent" button in the GUI's curation view +- The received data will appear in this section below Note: The first time you start the server, it may take a minute to generate test data.`; diff --git a/spikeinterface_gui/tests/iframe/iframe_test_README.md b/spikeinterface_gui/tests/iframe/iframe_test_README.md index a84e6092..76db6829 100644 --- a/spikeinterface_gui/tests/iframe/iframe_test_README.md +++ b/spikeinterface_gui/tests/iframe/iframe_test_README.md @@ -1,13 +1,19 @@ -# SpikeInterface GUI Iframe Test +# SpikeInterface GUI Iframe Test - Bidirectional Communication -This is a simple test application that demonstrates how to embed the SpikeInterface GUI in an iframe and receive data from it. +This is a test application that demonstrates bidirectional communication between a parent window and the SpikeInterface GUI embedded in an iframe. ## Overview The test consists of: -1. `iframe_test.html` - A simple HTML page that embeds the SpikeInterface GUI in an iframe and displays data received from it -2. `iframe_server.py` - A Flask server that serves the HTML page and starts the Panel application +1. `iframe_test.html` - An HTML page that: + - Embeds the SpikeInterface GUI in an iframe + - Sends curation data TO the iframe + - Receives curation data FROM the iframe + +2. `iframe_server.py` - A Flask server that: + - Serves the HTML page + - Starts the Panel application ## How to Run @@ -23,33 +29,67 @@ The test consists of: 3. This will open a browser window with the test application. Click the "Start Test Server" button to launch the SpikeInterface GUI. -4. Once the GUI is loaded, navigate to the "curation" tab (you may need to look for it in the different zones of the interface). +## Testing the Bidirectional Communication -5. Click the "Submit to parent" button in the curation view to send data to the parent window. +### Test 1: Send Data to iframe (Parent → iframe) -6. The received data will be displayed in the "Received Data" section of the parent window. +1. Once the GUI is loaded, navigate to the "curation" tab +2. In the parent window, edit the curation data in the textarea +3. Click "Send to iframe" button +4. The GUI should refresh and display the new curation data in its tables +5. The "Auto send when ready" button will automatically send the curation data when the iframe is ready -## How It Works +### Test 2: Receive Data from iframe (iframe → Parent) -1. The Flask server serves the HTML page and provides an endpoint to start the Panel application. +1. In the GUI's curation view, click the "Submit to parent" button +2. The data will appear in the "Received Data from iframe" section of the parent window +3. Multiple submissions will be accumulated and displayed -2. The HTML page embeds the Panel application in an iframe. +### Test 3: Clear All Curation -3. The curationview.py file has been modified to send data to the parent window using the `postMessage` API when the "Submit to parent" button is clicked. +1. Click the "Clear All Curation" button +2. This sends an empty curation object to the iframe +3. All tables in the curation view should become empty -4. The parent window listens for messages from the iframe and displays the received data. +## How It Works -## Troubleshooting +### Parent → iframe Communication -- If the Panel application fails to start, check the console for error messages. -- If no data is received when clicking the "Submit to parent" button, make sure you're using the latest version of the curationview.py file with the fix for the DataCloneError. -- If you're running the Panel application separately, you can manually set the iframe URL using the `setIframeUrl` function in the browser console. +1. Parent window constructs a curation data object +2. The iframe sends a `loaded=true` message when ready to receive data +3. Sends it via `postMessage` with `type: 'curation-data'` +4. The iframe's JavaScript listener receives the message +5. Python updates `controller.set_curation_data` and refreshes the view -## Technical Details +### iframe → Parent Communication -The fix for the DataCloneError involves: +1. User clicks "Submit to parent" in the curation view +2. Python generates a JSON string of the curation model +3. JavaScript code is injected that sends the data via `postMessage` +4. Parent window's message listener receives the data +5. Data is parsed and displayed -1. Converting the complex object to a JSON string before sending it to the parent window -2. Parsing the JSON string back to an object in the parent window -This avoids the issue where the browser tries to clone a complex object that contains non-cloneable properties. +## Technical Details + +The curation data needs to follow the [`CurationModel`](https://spikeinterface.readthedocs.io/en/stable/api.html#curation-model). + +### Message Format (Parent → iframe) +```javascript +{ + type: 'curation-data', + data: { + // Full curation model JSON + } +} +``` + +### Message Format (iframe → Parent) +```javascript +{ + type: 'panel-data', + data: { + // Full curation model JSON + } +} +``` diff --git a/spikeinterface_gui/tracemapview.py b/spikeinterface_gui/tracemapview.py index d5bc07df..2f9e5873 100644 --- a/spikeinterface_gui/tracemapview.py +++ b/spikeinterface_gui/tracemapview.py @@ -310,15 +310,30 @@ def _panel_on_spike_selection_changed(self): self._panel_seek_with_selected_spike() def _panel_gain_zoom(self, event): + import panel as pn + factor_ratio = 1.3 if event.delta > 0 else 1 / 1.3 - self.color_mapper.high = self.color_mapper.high * factor_ratio - self.color_mapper.low = -self.color_mapper.high + new_high = self.color_mapper.high * factor_ratio + new_low = -new_high + + def _do_update(): + self.color_mapper.high = new_high + self.color_mapper.low = new_low + + pn.state.execute(_do_update, schedule=True) def _panel_auto_scale(self, event): + import panel as pn + if self.last_data_curves is not None: self.color_limit = np.max(np.abs(self.last_data_curves)) - self.color_mapper.high = self.color_limit - self.color_mapper.low = -self.color_limit + color_limit = self.color_limit + + def _do_update(): + self.color_mapper.high = color_limit + self.color_mapper.low = -color_limit + + pn.state.execute(_do_update, schedule=True) def _panel_on_time_info_updated(self): # Update segment and time slider range diff --git a/spikeinterface_gui/traceview.py b/spikeinterface_gui/traceview.py index bc4d2166..8a461ffd 100644 --- a/spikeinterface_gui/traceview.py +++ b/spikeinterface_gui/traceview.py @@ -316,6 +316,8 @@ def _panel_on_time_slider_changed(self, event): self.notify_time_info_updated() def _panel_seek_with_selected_spike(self): + import panel as pn + ind_selected = self.controller.get_indices_spike_selected() n_selected = ind_selected.size @@ -340,8 +342,14 @@ def _panel_seek_with_selected_spike(self): # Center view on spike margin = self.xsize / 3 - self.figure.x_range.start = peak_time - margin - self.figure.x_range.end = peak_time + 2 * margin + range_start = peak_time - margin + range_end = peak_time + 2 * margin + + def _do_update(): + self.figure.x_range.start = range_start + self.figure.x_range.end = range_end + + pn.state.execute(_do_update, schedule=True) self._block_auto_refresh_and_notify = False self.refresh() diff --git a/spikeinterface_gui/unitlistview.py b/spikeinterface_gui/unitlistview.py index 5512aadf..466c355c 100644 --- a/spikeinterface_gui/unitlistview.py +++ b/spikeinterface_gui/unitlistview.py @@ -667,11 +667,17 @@ def _panel_on_visible_checkbox_toggled(self, row): self.refresh() def _panel_on_unit_visibility_changed(self): + import panel as pn + # update selection to match visible units visible_units = self.controller.get_visible_unit_ids() unit_ids = list(self.table.value.index.values) rows_to_select = [unit_ids.index(unit_id) for unit_id in visible_units if unit_id in unit_ids] - self.table.selection = rows_to_select + + def _do_update(): + self.table.selection = rows_to_select + + pn.state.execute(_do_update, schedule=True) self.refresh() def _panel_refresh_colors(self): diff --git a/spikeinterface_gui/utils_panel.py b/spikeinterface_gui/utils_panel.py index 7619651e..7a215c49 100644 --- a/spikeinterface_gui/utils_panel.py +++ b/spikeinterface_gui/utils_panel.py @@ -14,6 +14,7 @@ from panel.param import param from panel.custom import ReactComponent from panel.widgets import Tabulator +from panel.reactive import ReactiveHTML from bokeh.models import ColumnDataSource, Patches, HTMLTemplateFormatter @@ -292,7 +293,7 @@ class SelectableTabulator(pn.viewable.Viewer): refresh_table_function: Callable | None A function to call when the table a new selection is made via keyboard shortcuts. on_only_function: Callable | None - A function to call when the table a ctrl+selection is made via keyboard shortcuts or a double-click. + A function to call when a ctrl+selection is made via keyboard shortcuts or a double-click. conditional_shortcut: Callable | None A function that returns True if the shortcuts should be enabled, False otherwise. column_callbacks: dict[Callable] | None @@ -313,6 +314,7 @@ def __init__( self._formatters = kwargs.get("formatters", {}) self._editors = kwargs.get("editors", {}) self._frozen_columns = kwargs.get("frozen_columns", []) + self._selectable = kwargs.get("selectable", True) if "sortable" in kwargs: self._sortable = kwargs.pop("sortable") else: @@ -394,8 +396,9 @@ def selection(self): @selection.setter def selection(self, val): - if isinstance(self.tabulator.selectable, int): - max_selectable = self.tabulator.selectable + # Added this logic to prevent max selection with shift+click / arrows + if isinstance(self._selectable, int): + max_selectable = self._selectable if not isinstance(max_selectable, bool): if len(val) > max_selectable: val = val[-max_selectable:] @@ -437,6 +440,7 @@ def refresh_tabulator_settings(self): self.tabulator.formatters = self._formatters self.tabulator.editors = self._editors self.tabulator.frozen_columns = self._frozen_columns + self.tabulator.selectable = self._selectable self.tabulator.sorters = [] def refresh(self): @@ -485,7 +489,7 @@ def _on_selection_change(self, event): Handle the selection change event. This is called when the selection is changed. """ if self._refresh_table_function is not None: - self._refresh_table_function() + pn.state.execute(self._refresh_table_function, schedule=True) def _on_click(self, event): """ @@ -685,3 +689,58 @@ class KeyboardShortcuts(ReactComponent): return <>; } """ + + +class PostMessageListener(ReactComponent): + """ + Listen to window.postMessage events and forward them to Python via on_msg(). + This avoids ReactiveHTML/Bokeh 'source' linkage issues. + """ + _model_name = "PostMessageListener" + _model_module = "post_message_listener" + _model_module_version = "0.0.1" + + # If set, only forward messages whose event.data.type matches this value. + accept_type = param.String(default="curation-data") + + _esm = """ + export function render({ model }) { + const [accept_type] = model.useState("accept_type"); + + function onMessage(event) { + const data = event.data; + + // Ignore messages from browser extensions + if (data && data.source === "react-devtools-content-script") return; + + if (accept_type && data && data.type !== accept_type) return; + + // Always include a timestamp so repeated sends still look "new" + model.send_msg({ payload: data, _ts: Date.now() }); + } + + React.useEffect(() => { + window.addEventListener("message", onMessage); + return () => window.removeEventListener("message", onMessage); + }, [accept_type]); + + return <>; + } + """ + + +class IFrameDetector(pn.reactive.ReactiveHTML): + """ + Simple component that detects if it is running inside an iframe. + """ + in_iframe = param.Parameter(default=None) + + _template = "
" + + _scripts = { + "render": """ + const val = window.self !== window.top; + console.log("iframe detector JS:", val); + data.in_iframe = val; // reliable sync + """ + } diff --git a/spikeinterface_gui/view_base.py b/spikeinterface_gui/view_base.py index 3fb2507a..52aaebdc 100644 --- a/spikeinterface_gui/view_base.py +++ b/spikeinterface_gui/view_base.py @@ -125,7 +125,8 @@ def _refresh(self, **kwargs): if self.backend == "qt": self._qt_refresh(**kwargs) elif self.backend == "panel": - self._panel_refresh(**kwargs) + import panel as pn + pn.state.execute(lambda: self._panel_refresh(**kwargs), schedule=True) def warning(self, warning_msg): if self.backend == "qt": diff --git a/spikeinterface_gui/waveformheatmapview.py b/spikeinterface_gui/waveformheatmapview.py index 7ea52055..42d0a5ca 100644 --- a/spikeinterface_gui/waveformheatmapview.py +++ b/spikeinterface_gui/waveformheatmapview.py @@ -257,33 +257,45 @@ def _panel_make_layout(self): ) def _panel_refresh(self): + import panel as pn + hist2d = self.get_plotting_data() - if hist2d is None: - self.image_source.data.update({ - "image": [], - "dw": [], - "dh": [] - }) - return + def _do_update(): + if hist2d is None: + self.image_source.data = { + "image": [], + "dw": [], + "dh": [] + } + return + + self.image_source.data = { + "image": [hist2d.T], + "dw": [hist2d.shape[0]], + "dh": [hist2d.shape[1]] + } - self.image_source.data.update({ - "image": [hist2d.T], - "dw": [hist2d.shape[0]], - "dh": [hist2d.shape[1]] - }) + self.color_mapper.low = 0 + self.color_mapper.high = np.max(hist2d) - self.color_mapper.low = 0 - self.color_mapper.high = np.max(hist2d) + self.figure.x_range.start = 0 + self.figure.x_range.end = hist2d.shape[0] + self.figure.y_range.start = 0 + self.figure.y_range.end = hist2d.shape[1] - self.figure.x_range.start = 0 - self.figure.x_range.end = hist2d.shape[0] - self.figure.y_range.start = 0 - self.figure.y_range.end = hist2d.shape[1] + pn.state.execute(_do_update, schedule=True) def _panel_gain_zoom(self, event): + import panel as pn + factor = 1.3 if event.delta > 0 else 1 / 1.3 - self.color_mapper.high = self.color_mapper.high * factor + new_high = self.color_mapper.high * factor + + def _do_update(): + self.color_mapper.high = new_high + + pn.state.execute(_do_update, schedule=True) diff --git a/spikeinterface_gui/waveformview.py b/spikeinterface_gui/waveformview.py index d53e7497..38f07e9e 100644 --- a/spikeinterface_gui/waveformview.py +++ b/spikeinterface_gui/waveformview.py @@ -1022,7 +1022,8 @@ def _panel_on_mode_selector_changed(self, event): self.refresh() def _panel_gain_zoom(self, event): - self.figure_geom.toolbar.active_scroll = None + import panel as pn + current_time = time.perf_counter() if self.last_wheel_event_time is not None: time_elapsed = current_time - self.last_wheel_event_time @@ -1030,8 +1031,14 @@ def _panel_gain_zoom(self, event): time_elapsed = 1000 if time_elapsed > _wheel_refresh_time: modifiers = event.modifiers - if modifiers["shift"] and modifiers["alt"]: + + def _enable_active_scroll(): + self.figure_geom.toolbar.active_scroll = self.zoom_tool + + def _disable_active_scroll(): self.figure_geom.toolbar.active_scroll = None + + if modifiers["shift"] and modifiers["alt"]: if self.mode == "geometry": factor_ratio = 1.3 if event.delta > 0 else 1 / 1.3 # adjust y range and keep center @@ -1040,29 +1047,36 @@ def _panel_gain_zoom(self, event): yrange = ymax - ymin ymid = 0.5 * (ymin + ymax) new_yrange = yrange * factor_ratio - ymin = ymid - new_yrange / 2. - ymax = ymid + new_yrange / 2. - self.figure_geom.y_range.start = ymin - self.figure_geom.y_range.end = ymax + new_ymin = ymid - new_yrange / 2. + new_ymax = ymid + new_yrange / 2. + + def _do_range_update(): + self.figure_geom.toolbar.active_scroll = None + self.figure_geom.y_range.start = new_ymin + self.figure_geom.y_range.end = new_ymax + + pn.state.execute(_do_range_update, schedule=True) + else: + pn.state.execute(_disable_active_scroll, schedule=True) elif modifiers["shift"]: - self.figure_geom.toolbar.active_scroll = self.zoom_tool + pn.state.execute(_enable_active_scroll, schedule=True) elif modifiers["alt"]: - self.figure_geom.toolbar.active_scroll = None if self.mode == "geometry": factor = 1.3 if event.delta > 0 else 1 / 1.3 self.factor_x *= factor self._panel_refresh_mode_geometry(keep_range=True) self._panel_refresh_spikes() + pn.state.execute(_disable_active_scroll, schedule=True) elif not modifiers["ctrl"]: - self.figure_geom.toolbar.active_scroll = None if self.mode == "geometry": factor = 1.3 if event.delta > 0 else 1 / 1.3 self.gain_y *= factor self._panel_refresh_mode_geometry(keep_range=True) self._panel_refresh_spikes() + pn.state.execute(_disable_active_scroll, schedule=True) else: # Ignore the event if it occurs too quickly - self.figure_geom.toolbar.active_scroll = None + pn.state.execute(_disable_active_scroll, schedule=True) self.last_wheel_event_time = current_time def _panel_refresh_mode_geometry(self, dict_visible_units=None, keep_range=False): @@ -1388,11 +1402,14 @@ def _panel_clear_data_sources(self): self.vlines_data_source_std.data = dict(xs=[], ys=[], colors=[]) def _panel_on_spike_selection_changed(self): - self._panel_refresh_one_spike() + import panel as pn + pn.state.execute(self._panel_refresh_one_spike, schedule=True) def _panel_on_channel_visibility_changed(self): + import panel as pn + keep_range = not self.settings["auto_move_on_unit_selection"] - self._panel_refresh(keep_range=keep_range) + pn.state.execute(lambda: self._panel_refresh(keep_range=keep_range), schedule=True) def _panel_handle_shortcut(self, event): if event.data == "overlap":