diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py
index 2085296e..b99823d6 100644
--- a/spikeinterface_gui/basescatterview.py
+++ b/spikeinterface_gui/basescatterview.py
@@ -314,7 +314,7 @@ def _qt_refresh(self, set_scatter_range=False):
# set x range to time range of the current segment for scatter, and max count for histogram
# set y range to min and max of visible spike amplitudes
- if set_scatter_range or not self._first_refresh_done:
+ if len(ymins) > 0 and (set_scatter_range or not self._first_refresh_done):
ymin = np.min(ymins)
ymax = np.max(ymaxs)
t_start, t_stop = self.controller.get_t_start_t_stop()
@@ -490,6 +490,7 @@ def _panel_make_layout(self):
self.plotted_inds = []
def _panel_refresh(self, set_scatter_range=False):
+ import panel as pn
from bokeh.models import FixedTicker
self.plotted_inds = []
@@ -555,7 +556,8 @@ def _panel_refresh(self, set_scatter_range=False):
# handle selected spikes
self._panel_update_selected_spikes()
- # set y range to min and max of visible spike amplitudes
+ # Defer Range updates to avoid nested document lock issues
+ # def update_ranges():
if set_scatter_range or not self._first_refresh_done:
self.y_range.start = np.min(ymins)
self.y_range.end = np.max(ymaxs)
@@ -563,20 +565,36 @@ def _panel_refresh(self, set_scatter_range=False):
self.hist_fig.x_range.end = max_count
self.hist_fig.xaxis.ticker = FixedTicker(ticks=[0, max_count // 2, max_count])
+ # Schedule the update to run after the current event loop iteration
+ # pn.state.execute(update_ranges, schedule=True)
+
def _panel_on_select_button(self, event):
- if self.select_toggle_button.value:
- self.scatter_fig.toolbar.active_drag = self.lasso_tool
- else:
- self.scatter_fig.toolbar.active_drag = None
- self.scatter_source.selected.indices = []
+ import panel as pn
+
+ value = self.select_toggle_button.value
+
+ def _do_update():
+ if value:
+ self.scatter_fig.toolbar.active_drag = self.lasso_tool
+ else:
+ self.scatter_fig.toolbar.active_drag = None
+ self.scatter_source.selected.indices = []
+
+ pn.state.execute(_do_update, schedule=True)
def _panel_change_segment(self, event):
+ import panel as pn
+
self._current_selected = 0
segment_index = int(self.segment_selector.value.split()[-1])
self.controller.set_time(segment_index=segment_index)
t_start, t_end = self.controller.get_t_start_t_stop()
- self.scatter_fig.x_range.start = t_start
- self.scatter_fig.x_range.end = t_end
+
+ def _do_update():
+ self.scatter_fig.x_range.start = t_start
+ self.scatter_fig.x_range.end = t_end
+
+ pn.state.execute(_do_update, schedule=True)
self.refresh(set_scatter_range=True)
self.notify_time_info_updated()
@@ -618,9 +636,17 @@ def _panel_split(self, event):
self.split()
def _panel_update_selected_spikes(self):
+ import panel as pn
+
# handle selected spikes
selected_spike_indices = self.controller.get_indices_spike_selected()
selected_spike_indices = np.intersect1d(selected_spike_indices, self.plotted_inds)
+ if len(selected_spike_indices) == 1:
+ selected_segment = self.controller.spikes[selected_spike_indices[0]]['segment_index']
+ segment_index = self.controller.get_time()[1]
+ if selected_segment != segment_index:
+ self.segment_selector.value = f"Segment {selected_segment}"
+ self._panel_change_segment(None)
if len(selected_spike_indices) > 0:
# map absolute indices to visible spikes
segment_index = self.controller.get_time()[1]
@@ -634,23 +660,16 @@ def _panel_update_selected_spikes(self):
# set selected spikes in scatter plot
if self.settings["auto_decimate"] and len(selected_indices) > 0:
selected_indices, = np.nonzero(np.isin(self.plotted_inds, selected_spike_indices))
- self.scatter_source.selected.indices = list(selected_indices)
else:
- self.scatter_source.selected.indices = []
+ selected_indices = []
+
+ def _do_update():
+ self.scatter_source.selected.indices = list(selected_indices)
+
+ pn.state.execute(_do_update, schedule=True)
def _panel_on_spike_selection_changed(self):
- # set selection in scatter plot
- selected_indices = self.controller.get_indices_spike_selected()
- if len(selected_indices) == 0:
- self.scatter_source.selected.indices = []
- return
- elif len(selected_indices) == 1:
- selected_segment = self.controller.spikes[selected_indices[0]]['segment_index']
- segment_index = self.controller.get_time()[1]
- if selected_segment != segment_index:
- self.segment_selector.value = f"Segment {selected_segment}"
- self._panel_change_segment(None)
- # update selected spikes
+ # update selected spikes (scheduled via pn.state.execute inside)
self._panel_update_selected_spikes()
def _panel_handle_shortcut(self, event):
diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py
index 3bb7eb94..41b63981 100644
--- a/spikeinterface_gui/controller.py
+++ b/spikeinterface_gui/controller.py
@@ -9,10 +9,7 @@
from spikeinterface.widgets.utils import get_unit_colors
from spikeinterface import compute_sparsity
from spikeinterface.core import get_template_extremum_channel
-import spikeinterface.postprocessing
-import spikeinterface.qualitymetrics
from spikeinterface.core.sorting_tools import spike_vector_to_indices
-from spikeinterface.core.core_tools import check_json
from spikeinterface.curation import validate_curation_dict
from spikeinterface.curation.curation_model import CurationModel
from spikeinterface.widgets.utils import make_units_table_from_analyzer
@@ -345,6 +342,7 @@ def __init__(
if curation_data is not None:
# validate the curation data
+ curation_data = deepcopy(curation_data)
format_version = curation_data.get("format_version", None)
# assume version 2 if not present
if format_version is None:
@@ -354,24 +352,6 @@ def __init__(
except Exception as e:
raise ValueError(f"Invalid curation data.\nError: {e}")
- if curation_data.get("merges") is None:
- curation_data["merges"] = []
- else:
- # here we reset the merges for better formatting (str)
- existing_merges = curation_data["merges"]
- new_merges = []
- for m in existing_merges:
- if "unit_ids" not in m:
- continue
- if len(m["unit_ids"]) < 2:
- continue
- new_merges = add_merge(new_merges, m["unit_ids"])
- curation_data["merges"] = new_merges
- if curation_data.get("splits") is None:
- curation_data["splits"] = []
- if curation_data.get("removed") is None:
- curation_data["removed"] = []
-
elif self.analyzer.format == "binary_folder":
json_file = self.analyzer.folder / "spikeinterface_gui" / "curation_data.json"
if json_file.exists():
@@ -386,24 +366,27 @@ def __init__(
if curation_data is None:
curation_data = deepcopy(empty_curation_data)
+ curation_data["unit_ids"] = self.unit_ids.tolist()
- self.curation_data = curation_data
-
- self.has_default_quality_labels = False
- if "label_definitions" not in self.curation_data:
+ if "label_definitions" not in curation_data:
if label_definitions is not None:
- self.curation_data["label_definitions"] = label_definitions
+ curation_data["label_definitions"] = label_definitions
else:
- self.curation_data["label_definitions"] = default_label_definitions.copy()
+ curation_data["label_definitions"] = default_label_definitions.copy()
- if "quality" in self.curation_data["label_definitions"]:
- curation_dict_quality_labels = self.curation_data["label_definitions"]["quality"]["label_options"]
+ # This will enable the default shortcuts if has default quality labels
+ self.has_default_quality_labels = False
+ if "quality" in curation_data["label_definitions"]:
+ curation_dict_quality_labels = curation_data["label_definitions"]["quality"]["label_options"]
default_quality_labels = default_label_definitions["quality"]["label_options"]
if set(curation_dict_quality_labels) == set(default_quality_labels):
if self.verbose:
print('Curation quality labels are the default ones')
self.has_default_quality_labels = True
+ curation_data = CurationModel(**curation_data).model_dump()
+ self.curation_data = curation_data
+
def check_is_view_possible(self, view_name):
from .viewlist import get_all_possible_views
possible_class_views = get_all_possible_views()
@@ -825,7 +808,24 @@ def compute_auto_merge(self, **params):
)
return merge_unit_groups, extra
-
+
+ def set_curation_data(self, curation_data):
+ print("Setting curation data")
+ new_curation_data = empty_curation_data.copy()
+ new_curation_data.update(curation_data)
+
+ if "unit_ids" not in curation_data:
+ print("Setting unit_ids from controller")
+ new_curation_data["unit_ids"] = self.unit_ids.tolist()
+
+ if "label_definitions" not in curation_data:
+ print("Setting default label definitions")
+ new_curation_data["label_definitions"] = default_label_definitions.copy()
+
+ # validate the curation data
+ model = CurationModel(**new_curation_data)
+ self.curation_data = model.model_dump()
+
def curation_can_be_saved(self):
return self.analyzer.format != "memory"
diff --git a/spikeinterface_gui/curation_tools.py b/spikeinterface_gui/curation_tools.py
index 42b7b06b..1aff9fcd 100644
--- a/spikeinterface_gui/curation_tools.py
+++ b/spikeinterface_gui/curation_tools.py
@@ -38,3 +38,32 @@ def add_merge(previous_merges, new_merge_unit_ids):
# Ensure the uniqueness
new_merges = [{"unit_ids": list(set(gp))} for gp in new_merge_units]
return new_merges
+
+
+# def cast_unit_dtypes_in_curation(curation_data, unit_ids_dtype):
+# """Cast unit ids in curation data to the correct dtype."""
+# if "unit_ids" in curation_data:
+# curation_data["unit_ids"] = [unit_ids_dtype(uid) for uid in curation_data["unit_ids"]]
+
+# if "merges" in curation_data:
+# for merge in curation_data["merges"]:
+# merge["unit_ids"] = [unit_ids_dtype(uid) for uid in merge["unit_ids"]]
+# new_unit_id = merge.get("new_unit_id", None)
+# if new_unit_id is not None:
+# merge["new_unit_id"] = unit_ids_dtype(new_unit_id)
+
+# if "splits" in curation_data:
+# for split in curation_data["splits"]:
+# split["unit_id"] = unit_ids_dtype(split["unit_id"])
+# new_unit_ids = split.get("new_unit_ids", None)
+# if new_unit_ids is not None:
+# split["new_unit_ids"] = [unit_ids_dtype(uid) for uid in new_unit_ids]
+
+# if "removed" in curation_data:
+# curation_data["removed"] = [unit_ids_dtype(uid) for uid in curation_data["removed"]]
+
+# if "manual_labels" in curation_data:
+# for label_entry in curation_data["manual_labels"]:
+# label_entry["unit_id"] = unit_ids_dtype(label_entry["unit_id"])
+
+# return curation_data
\ No newline at end of file
diff --git a/spikeinterface_gui/curationview.py b/spikeinterface_gui/curationview.py
index 5394e72c..53f01786 100644
--- a/spikeinterface_gui/curationview.py
+++ b/spikeinterface_gui/curationview.py
@@ -6,20 +6,18 @@
from spikeinterface.core.core_tools import check_json
-
-
class CurationView(ViewBase):
id = "curation"
- _supported_backend = ['qt', 'panel']
+ _supported_backend = ["qt", "panel"]
_need_compute = False
def __init__(self, controller=None, parent=None, backend="qt"):
self.active_table = "merge"
- ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
+ ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
# TODO: Cast unit ids to the correct type here
def restore_units(self):
- if self.backend == 'qt':
+ if self.backend == "qt":
unit_ids = self._qt_get_delete_table_selection()
else:
unit_ids = self._panel_get_delete_table_selection()
@@ -30,7 +28,7 @@ def restore_units(self):
self.refresh()
def unmerge(self):
- if self.backend == 'qt':
+ if self.backend == "qt":
merge_indices = self._qt_get_merge_table_row()
else:
merge_indices = self._panel_get_merge_table_row()
@@ -40,7 +38,7 @@ def unmerge(self):
self.refresh()
def unsplit(self):
- if self.backend == 'qt':
+ if self.backend == "qt":
split_indices = self._qt_get_split_table_row()
else:
split_indices = self._panel_get_split_table_row()
@@ -55,8 +53,8 @@ def select_and_notify_split(self, split_unit_id):
self.controller.set_visible_unit_ids([split_unit_id])
self.notify_unit_visibility_changed()
spike_inds = self.controller.get_spike_indices(split_unit_id, segment_index=None)
- active_split = [s for s in self.controller.curation_data['splits'] if s['unit_id'] == split_unit_id][0]
- split_indices = active_split['indices'][0]
+ active_split = [s for s in self.controller.curation_data["splits"] if s["unit_id"] == split_unit_id][0]
+ split_indices = active_split["indices"][0]
self.controller.set_indices_spike_selected(spike_inds[split_indices])
self.notify_spike_selection_changed()
@@ -73,25 +71,26 @@ def _qt_make_layout(self):
but = QT.QPushButton("Save in analyzer")
tb.addWidget(but)
but.clicked.connect(self.save_in_analyzer)
+
but = QT.QPushButton("Export JSON")
- but.clicked.connect(self._qt_export_json)
+ but.clicked.connect(self._qt_export_json)
tb.addWidget(but)
h = QT.QHBoxLayout()
self.layout.addLayout(h)
-
v = QT.QVBoxLayout()
h.addLayout(v)
- self.table_delete = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection,
- selectionBehavior=QT.QAbstractItemView.SelectRows)
+ self.table_delete = QT.QTableWidget(
+ selectionMode=QT.QAbstractItemView.SingleSelection, selectionBehavior=QT.QAbstractItemView.SelectRows
+ )
v.addWidget(self.table_delete)
self.table_delete.setContextMenuPolicy(QT.Qt.CustomContextMenu)
self.table_delete.customContextMenuRequested.connect(self._qt_open_context_menu_delete)
self.table_delete.itemSelectionChanged.connect(self._qt_on_item_selection_changed_delete)
self.delete_menu = QT.QMenu()
- act = self.delete_menu.addAction('Restore')
+ act = self.delete_menu.addAction("Restore")
act.triggered.connect(self.restore_units)
shortcut_restore = QT.QShortcut(self.qt_widget)
shortcut_restore.setKey(QT.QKeySequence("ctrl+r"))
@@ -99,8 +98,9 @@ def _qt_make_layout(self):
v = QT.QVBoxLayout()
h.addLayout(v)
- self.table_merge = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection,
- selectionBehavior=QT.QAbstractItemView.SelectRows)
+ self.table_merge = QT.QTableWidget(
+ selectionMode=QT.QAbstractItemView.SingleSelection, selectionBehavior=QT.QAbstractItemView.SelectRows
+ )
# self.table_merge.setContextMenuPolicy(QT.Qt.CustomContextMenu)
v.addWidget(self.table_merge)
@@ -109,7 +109,7 @@ def _qt_make_layout(self):
self.table_merge.itemSelectionChanged.connect(self._qt_on_item_selection_changed_merge)
self.merge_menu = QT.QMenu()
- act = self.merge_menu.addAction('Remove merge')
+ act = self.merge_menu.addAction("Remove merge")
act.triggered.connect(self.unmerge)
shortcut_unmerge = QT.QShortcut(self.qt_widget)
shortcut_unmerge.setKey(QT.QKeySequence("ctrl+u"))
@@ -117,14 +117,15 @@ def _qt_make_layout(self):
v = QT.QVBoxLayout()
h.addLayout(v)
- self.table_split = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection,
- selectionBehavior=QT.QAbstractItemView.SelectRows)
+ self.table_split = QT.QTableWidget(
+ selectionMode=QT.QAbstractItemView.SingleSelection, selectionBehavior=QT.QAbstractItemView.SelectRows
+ )
v.addWidget(self.table_split)
self.table_split.setContextMenuPolicy(QT.Qt.CustomContextMenu)
self.table_split.customContextMenuRequested.connect(self._qt_open_context_menu_split)
self.table_split.itemSelectionChanged.connect(self._qt_on_item_selection_changed_split)
self.split_menu = QT.QMenu()
- act = self.split_menu.addAction('Remove split')
+ act = self.split_menu.addAction("Remove split")
act.triggered.connect(self.unsplit)
shortcut_unsplit = QT.QShortcut(self.qt_widget)
shortcut_unsplit.setKey(QT.QKeySequence("ctrl+x"))
@@ -132,6 +133,7 @@ def _qt_make_layout(self):
def _qt_refresh(self):
from .myqt import QT
+
# Merged
merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]]
self.table_merge.clear()
@@ -141,12 +143,12 @@ def _qt_refresh(self):
self.table_merge.setSortingEnabled(False)
for ix, group in enumerate(merged_units):
item = QT.QTableWidgetItem(str(group))
- item.setFlags(QT.Qt.ItemIsEnabled|QT.Qt.ItemIsSelectable)
+ item.setFlags(QT.Qt.ItemIsEnabled | QT.Qt.ItemIsSelectable)
self.table_merge.setItem(ix, 0, item)
for i in range(self.table_merge.columnCount()):
self.table_merge.resizeColumnToContents(i)
- # Removed
+ # Removed
removed_units = self.controller.curation_data["removed"]
self.table_delete.clear()
self.table_delete.setRowCount(len(removed_units))
@@ -155,12 +157,12 @@ def _qt_refresh(self):
self.table_delete.setSortingEnabled(False)
for i, unit_id in enumerate(removed_units):
color = self.get_unit_color(unit_id)
- pix = QT.QPixmap(16,16)
+ pix = QT.QPixmap(16, 16)
pix.fill(color)
icon = QT.QIcon(pix)
- item = QT.QTableWidgetItem( f'{unit_id}')
- item.setFlags(QT.Qt.ItemIsEnabled|QT.Qt.ItemIsSelectable)
- self.table_delete.setItem(i,0, item)
+ item = QT.QTableWidgetItem(f"{unit_id}")
+ item.setFlags(QT.Qt.ItemIsEnabled | QT.Qt.ItemIsSelectable)
+ self.table_delete.setItem(i, 0, item)
item.setIcon(icon)
item.unit_id = unit_id
self.table_delete.resizeColumnToContents(0)
@@ -178,13 +180,11 @@ def _qt_refresh(self):
num_spikes = self.controller.num_spikes[unit_id]
num_splits = f"({num_indices}-{num_spikes - num_indices})"
item = QT.QTableWidgetItem(f"{unit_id} {num_splits}")
- item.setFlags(QT.Qt.ItemIsEnabled|QT.Qt.ItemIsSelectable)
+ item.setFlags(QT.Qt.ItemIsEnabled | QT.Qt.ItemIsSelectable)
self.table_split.setItem(i, 0, item)
item.unit_id = unit_id
self.table_split.resizeColumnToContents(0)
-
-
def _qt_get_delete_table_selection(self):
selected_items = self.table_delete.selectedItems()
if len(selected_items) == 0:
@@ -198,7 +198,7 @@ def _qt_get_merge_table_row(self):
return None
else:
return [s.row() for s in selected_items]
-
+
def _qt_get_split_table_row(self):
selected_items = self.table_split.selectedItems()
if len(selected_items) == 0:
@@ -214,7 +214,7 @@ def _qt_open_context_menu_merge(self):
def _qt_open_context_menu_split(self):
self.split_menu.popup(self.qt_widget.cursor().pos())
-
+
def _qt_on_item_selection_changed_merge(self):
if len(self.table_merge.selectedIndexes()) == 0:
return
@@ -274,15 +274,16 @@ def _qt_on_unit_visibility_changed(self):
def on_manual_curation_updated(self):
self.refresh()
-
+
def save_in_analyzer(self):
self.controller.save_curation_in_analyzer()
def _qt_export_json(self):
from .myqt import QT
+
fd = QT.QFileDialog(fileMode=QT.QFileDialog.AnyFile, acceptMode=QT.QFileDialog.AcceptSave)
- fd.setNameFilters(['JSON (*.json);'])
- fd.setDefaultSuffix('json')
+ fd.setNameFilters(["JSON (*.json);"])
+ fd.setDefaultSuffix("json")
fd.setViewMode(QT.QFileDialog.Detail)
if fd.exec_():
json_file = Path(fd.selectedFiles()[0])
@@ -296,14 +297,17 @@ def _panel_make_layout(self):
import pandas as pd
import panel as pn
- from .utils_panel import KeyboardShortcut, KeyboardShortcuts, SelectableTabulator
+ from .utils_panel import KeyboardShortcut, KeyboardShortcuts, SelectableTabulator, PostMessageListener, IFrameDetector
pn.extension("tabulator")
+ # Initialize listenet_pane as None
+ self.listener_pane = None
+
# Create dataframe
delete_df = pd.DataFrame({"removed": []})
merge_df = pd.DataFrame({"merges": []})
- split_df = pd.DataFrame({"splits": []})
+ split_df = pd.DataFrame({"splits": []})
# Create tables
self.table_delete = SelectableTabulator(
@@ -311,11 +315,11 @@ def _panel_make_layout(self):
show_index=False,
disabled=True,
sortable=False,
+ selectable=True,
formatters={"removed": "plaintext"},
sizing_mode="stretch_width",
# SelectableTabulator functions
parent_view=self,
- # refresh_table_function=self.refresh,
conditional_shortcut=self._conditional_refresh_delete,
column_callbacks={"removed": self._panel_on_deleted_col},
)
@@ -329,7 +333,6 @@ def _panel_make_layout(self):
sizing_mode="stretch_width",
# SelectableTabulator functions
parent_view=self,
- # refresh_table_function=self.refresh,
conditional_shortcut=self._conditional_refresh_merge,
column_callbacks={"merges": self._panel_on_merged_col},
)
@@ -347,59 +350,43 @@ def _panel_make_layout(self):
column_callbacks={"splits": self._panel_on_split_col},
)
- self.table_delete.param.watch(self._panel_update_unit_visibility, "selection")
- self.table_merge.param.watch(self._panel_update_unit_visibility, "selection")
- self.table_split.param.watch(self._panel_update_unit_visibility, "selection")
+ # Watch selection changes instead of calling from column callbacks
+ self.table_delete.param.watch(self._panel_on_table_selection_changed, "selection")
+ self.table_merge.param.watch(self._panel_on_table_selection_changed, "selection")
+ self.table_split.param.watch(self._panel_on_table_selection_changed, "selection")
# Create buttons
- save_button = pn.widgets.Button(
- name="Save in analyzer",
- button_type="primary",
- height=30
- )
- save_button.on_click(self._panel_save_in_analyzer)
+ buttons_row = []
+ self.save_button = None
+ if self.controller.curation_can_be_saved():
+ self.save_button = pn.widgets.Button(name="Save in analyzer", button_type="primary", height=30)
+ self.save_button.on_click(self._panel_save_in_analyzer)
+ buttons_row.append(self.save_button)
- download_button = pn.widgets.FileDownload(
- button_type="primary",
- filename="curation.json",
- callback=self._panel_generate_json,
- height=30
+ self.download_button = pn.widgets.FileDownload(
+ button_type="primary", filename="curation.json", callback=self._panel_generate_json, height=30
)
+ buttons_row.append(self.download_button)
- restore_button = pn.widgets.Button(
- name="Restore",
- button_type="primary",
- height=30
- )
+ restore_button = pn.widgets.Button(name="Restore", button_type="primary", height=30)
restore_button.on_click(self._panel_restore_units)
- remove_merge_button = pn.widgets.Button(
- name="Unmerge",
- button_type="primary",
- height=30
- )
+ remove_merge_button = pn.widgets.Button(name="Unmerge", button_type="primary", height=30)
remove_merge_button.on_click(self._panel_unmerge)
- submit_button = pn.widgets.Button(
- name="Submit to parent",
- button_type="primary",
- height=30
- )
+ remove_split = pn.widgets.Button(name="Unsplit", button_type="primary", height=30)
+ remove_split.on_click(self._panel_unsplit)
# Create layout
- buttons_save = pn.Row(
- save_button,
- download_button,
- submit_button,
- sizing_mode="stretch_width",
- )
- save_sections = pn.Column(
- buttons_save,
+ self.buttons_save = pn.Row(
+ *buttons_row,
sizing_mode="stretch_width",
)
+
buttons_curate = pn.Row(
restore_button,
remove_merge_button,
+ remove_split,
sizing_mode="stretch_width",
)
@@ -413,29 +400,19 @@ def _panel_make_layout(self):
shortcuts_component.on_msg(self._panel_handle_shortcut)
# Create main layout with proper sizing
- sections = pn.Row(self.table_delete, self.table_merge, self.table_split,
- sizing_mode="stretch_width")
+ sections = pn.Row(self.table_delete, self.table_merge, self.table_split, sizing_mode="stretch_width")
self.layout = pn.Column(
- save_sections,
- buttons_curate,
- sections,
- shortcuts_component,
- scroll=True,
- sizing_mode="stretch_both"
+ self.buttons_save, buttons_curate, sections, shortcuts_component, scroll=True, sizing_mode="stretch_both"
)
- # Add a custom JavaScript callback to the button that doesn't interact with Bokeh models
- submit_button.on_click(self._panel_submit_to_parent)
-
- # Add a hidden div to store the data
- self.data_div = pn.pane.HTML("", width=0, height=0, margin=0, sizing_mode="fixed")
- self.layout.append(self.data_div)
-
+ self.iframe_detector = IFrameDetector()
+ self.iframe_detector.param.watch(self._panel_on_iframe_change, "in_iframe")
+ self.layout.append(self.iframe_detector)
def _panel_refresh(self):
import pandas as pd
- ## deleted
+ ## deleted
removed_units = self.controller.curation_data["removed"]
removed_units = [str(unit_id) for unit_id in removed_units]
df = pd.DataFrame({"removed": removed_units})
@@ -473,7 +450,7 @@ def _panel_refresh(self):
def ensure_save_warning_message(self):
- if self.layout[0].name == 'curation_save_warning':
+ if self.layout[0].name == "curation_save_warning":
return
import panel as pn
@@ -482,13 +459,13 @@ def ensure_save_warning_message(self):
f"""⚠️⚠️⚠️ Your curation is not saved""",
hard_line_break=True,
styles={"color": "red", "font-size": "16px"},
- name="curation_save_warning"
+ name="curation_save_warning",
)
self.layout.insert(0, alert_markdown)
def ensure_no_message(self):
- if self.layout[0].name == 'curation_save_warning':
+ if self.layout[0].name == "curation_save_warning":
self.layout.pop(0)
def _panel_update_unit_visibility(self, event):
@@ -518,13 +495,16 @@ def _panel_restore_units(self, event):
def _panel_unmerge(self, event):
self.unmerge()
+ def _panel_unsplit(self, event):
+ self.unsplit()
+
def _panel_save_in_analyzer(self, event):
self.save_in_analyzer()
self.refresh()
def _panel_generate_json(self):
# Get the path from the text input
- export_path = "curation.json"
+ export_path = Path("curation.json")
# Save the JSON file
curation_model = self.controller.construct_final_curation()
with export_path.open("w") as f:
@@ -536,40 +516,119 @@ def _panel_generate_json(self):
return export_path
- def _panel_submit_to_parent(self, event):
+ def _panel_submit_to_parent(self, event):
"""Send the curation data to the parent window"""
+ import time
+
# Get the curation data and convert it to a JSON string
curation_model = self.controller.construct_final_curation()
+ curation_data = curation_model.model_dump_json()
+ # Trigger the JavaScript function via the TextInput
+ # Update the value to trigger the jscallback
+ self.submit_trigger.value = curation_data + f"_{int(time.time() * 1000)}"
- # Create a JavaScript snippet that will send the data to the parent window
- js_code = f"""
-
- """
-
- # Update the hidden div with the JavaScript code
- self.data_div.object = js_code
# Submitting to parent is a way to "save" the curation (the parent can handle it)
self.controller.current_curation_saved = True
+ self.ensure_no_message()
+ print(f"Curation data sent to parent app!")
+
+ def _panel_set_curation_data(self, event):
+ """
+ Handler for PostMessageListener.on_msg.
+
+ event.data is whatever the JS side passed to model.send_msg(...).
+ Expected shape:
+ {
+ "payload": {"type": "curation-data", "data": },
+ }
+ """
+ msg = event.data
+ payload = (msg or {}).get("payload", {})
+ curation_data = payload.get("data", None)
+
+ if curation_data is None:
+ print("Received message without curation data:", msg)
+ return
+
+ # Optional: validate basic structure
+ if not isinstance(curation_data, dict):
+ print("Invalid curation_data type:", type(curation_data), curation_data)
+ return
+
+ self.controller.set_curation_data(curation_data)
self.refresh()
+ def _panel_on_iframe_change(self, event):
+ import panel as pn
+ from .utils_panel import PostMessageListener
+
+ in_iframe = event.new
+ print(f"CurationView detected iframe mode: {in_iframe}")
+ if in_iframe:
+ # Remove save in analyzer button and add submit to parent button
+ self.submit_button = pn.widgets.Button(name="Submit to parent", button_type="primary", height=30)
+ self.submit_button.on_click(self._panel_submit_to_parent)
+
+ self.buttons_save = pn.Row(
+ self.submit_button,
+ self.download_button,
+ sizing_mode="stretch_width",
+ )
+ self.layout[0] = self.buttons_save
+
+ # Create objects to submit and listen
+ self.submit_trigger = pn.widgets.TextInput(value="", visible=False)
+ # Add JavaScript callback that triggers when the TextInput value changes
+ self.submit_trigger.jscallback(
+ value="""
+ // Extract just the JSON data (remove timestamp suffix)
+ const fullValue = cb_obj.value;
+ const lastUnderscore = fullValue.lastIndexOf('_');
+ const dataStr = lastUnderscore > 0 ? fullValue.substring(0, lastUnderscore) : fullValue;
+
+ if (dataStr && dataStr.length > 0) {
+ try {
+ const data = JSON.parse(dataStr);
+ console.log('Sending data to parent:', data);
+ parent.postMessage({
+ type: 'panel-data',
+ data: data
+ },
+ '*');
+ console.log('Data sent successfully to parent window');
+ } catch (error) {
+ console.error('Error sending data to parent:', error);
+ }
+ }
+ """
+ )
+ self.layout.append(self.submit_trigger)
+ # Set up listener for external curation changes
+ self.listener = PostMessageListener()
+ self.listener.on_msg(self._panel_set_curation_data)
+ self.layout.append(self.listener)
+
+ # Notify parent that panel is ready to receive messages
+ self.ready_trigger = pn.pane.HTML(
+ """
+
+ """
+ )
+ self.layout.append(self.ready_trigger)
+
def _panel_get_delete_table_selection(self):
selected_items = self.table_delete.selection
if len(selected_items) == 0:
@@ -606,22 +665,46 @@ def _panel_on_unit_visibility_changed(self):
def _panel_on_deleted_col(self, row):
self.active_table = "delete"
- self.table_merge.selection = []
- self.table_split.selection = []
def _panel_on_merged_col(self, row):
self.active_table = "merge"
- self.table_delete.selection = []
- self.table_split.selection = []
def _panel_on_split_col(self, row):
self.active_table = "split"
- self.table_delete.selection = []
- self.table_merge.selection = []
- # set split selection
- split_unit_str = self.table_split.value["splits"].values[row]
- split_unit_id = self.controller.unit_ids.dtype.type(split_unit_str.split(" ")[0])
- self.select_and_notify_split(split_unit_id)
+
+ def _panel_on_table_selection_changed(self, event):
+ """
+ Unified handler for all table selection changes.
+ Determines which table was changed and updates visibility accordingly.
+ """
+ import panel as pn
+
+ print(f"Selection changed in table: {self.active_table}")
+ # Determine which table triggered the change
+ if self.active_table == "delete" and len(self.table_delete.selection) > 0:
+ self.table_merge.selection = []
+ self.table_split.selection = []
+ # Defer to avoid nested Bokeh callbacks (especially from ctrl+click)
+ pn.state.execute(lambda: self._panel_update_unit_visibility(None), schedule=True)
+ elif self.active_table == "merge" and len(self.table_merge.selection) > 0:
+ self.table_delete.selection = []
+ self.table_split.selection = []
+ # Defer to avoid nested Bokeh callbacks (especially from ctrl+click)
+ pn.state.execute(lambda: self._panel_update_unit_visibility(None), schedule=True)
+ elif self.active_table == "split" and len(self.table_split.selection) > 0:
+ self.table_delete.selection = []
+ self.table_merge.selection = []
+ # Handle split specially (sets selected spikes) - also deferred
+ def handle_split():
+ split_unit_str = self.table_split.value["splits"].values[self.table_split.selection[0]]
+ split_unit_id = self.controller.unit_ids.dtype.type(split_unit_str.split(" ")[0])
+ self.select_and_notify_split(split_unit_id)
+ # Defer to avoid nested Bokeh callbacks (especially from ctrl+click)
+ pn.state.execute(handle_split, schedule=True)
+ elif len(self.table_delete.selection) == 0 and len(self.table_merge.selection) == 0 and len(self.table_split.selection) == 0:
+ # All tables are cleared
+ self.active_table = None
+
def _conditional_refresh_merge(self):
# Check if the view is active before refreshing
@@ -656,8 +739,22 @@ def _conditional_refresh_split(self):
- **export/download JSON**: Export the current curation state to a JSON file.
- **restore**: Restore the selected unit from the deleted units table.
- **unmerge**: Unmerge the selected merges from the merged units table.
+- **unsplit**: Unsplit the selected split groups from the split units table.
- **submit to parent**: Submit the current curation state to the parent window (for use in web applications).
- **press 'ctrl+r'**: Restore the selected units from the deleted units table.
- **press 'ctrl+u'**: Unmerge the selected merges from the merged units table.
- **press 'ctrl+x'**: Unsplit the selected split groups from the split units table.
+
+### Note
+When setting the `iframe_mode` setting to `True` using the `user_settings=dict(curation=dict(iframe_mode=True))`,
+the GUI is expected to be used inside an iframe. In this mode, the curation view will include a "Submit to parent"
+button that, when clicked, will send the current curation data to the parent window.
+In this mode, bi-directional communication is established between the GUI and the parent window using the `postMessage`
+API. The GUI listens for incoming messages of this expected shape:
+
+```
+{
+ "payload": {"type": "curation-data", "data": },
+}
+```
"""
diff --git a/spikeinterface_gui/mergeview.py b/spikeinterface_gui/mergeview.py
index 18fc7b13..1a6169f3 100644
--- a/spikeinterface_gui/mergeview.py
+++ b/spikeinterface_gui/mergeview.py
@@ -438,9 +438,15 @@ def _panel_on_preset_change(self, event):
self.layout[layout_index] = self.preset_params_selectors[self.preset]
def _panel_on_click(self, event):
+ import panel as pn
+
# set unit visibility
row = event.row
- self.table.selection = [row]
+
+ def _do_update():
+ self.table.selection = [row]
+
+ pn.state.execute(_do_update, schedule=True)
self._panel_update_visible_pair(row)
def _panel_include_deleted_change(self, event):
@@ -458,19 +464,40 @@ def _panel_update_visible_pair(self, row):
self.notify_unit_visibility_changed()
def _panel_handle_shortcut(self, event):
+ import panel as pn
+
if event.data == "accept":
selected = self.table.selection
- for row in selected:
- group_ids = self.table.value.iloc[row].group_ids
- self.accept_group_merge(group_ids)
+ if len(selected) == 0:
+ return
+ # selected is always 1
+ row = selected[0]
+ group_ids = self.table.value.iloc[row].group_ids
+ self.accept_group_merge(group_ids)
self.notify_manual_curation_updated()
+
+ next_row = min(row + 1, len(self.table.value) - 1)
+
+ def _select_next():
+ self.table.selection = [next_row]
+
+ pn.state.execute(_select_next, schedule=True)
+ self._panel_update_visible_pair(next_row)
elif event.data == "next":
next_row = min(self.table.selection[0] + 1, len(self.table.value) - 1)
- self.table.selection = [next_row]
+
+ def _do_next():
+ self.table.selection = [next_row]
+
+ pn.state.execute(_do_next, schedule=True)
self._panel_update_visible_pair(next_row)
elif event.data == "previous":
previous_row = max(self.table.selection[0] - 1, 0)
- self.table.selection = [previous_row]
+
+ def _do_prev():
+ self.table.selection = [previous_row]
+
+ pn.state.execute(_do_prev, schedule=True)
self._panel_update_visible_pair(previous_row)
def _panel_on_spike_selection_changed(self):
diff --git a/spikeinterface_gui/metricsview.py b/spikeinterface_gui/metricsview.py
index cab29432..3c3b12bf 100644
--- a/spikeinterface_gui/metricsview.py
+++ b/spikeinterface_gui/metricsview.py
@@ -172,7 +172,7 @@ def _panel_make_layout(self):
self.empty_plot_pane = pn.pane.Bokeh(empty_fig, sizing_mode="stretch_both")
self.layout = pn.Column(
- pn.Row(self.metrics_select, sizing_mode="stretch_width", height=160),
+ pn.Row(self.metrics_select, sizing_mode="stretch_width", height=100),
self.empty_plot_pane,
sizing_mode="stretch_both"
)
@@ -221,6 +221,7 @@ def _panel_refresh(self):
outline_line_color="white",
toolbar_location=None
)
+ plot.toolbar.logo = None
plot.grid.visible = False
if r == c:
diff --git a/spikeinterface_gui/ndscatterview.py b/spikeinterface_gui/ndscatterview.py
index ec63b712..0b8ba542 100644
--- a/spikeinterface_gui/ndscatterview.py
+++ b/spikeinterface_gui/ndscatterview.py
@@ -488,20 +488,31 @@ def _panel_refresh(self, update_components=True, update_colors=True):
self.scatter_fig.y_range.end = self.limit
def _panel_on_spike_selection_changed(self):
+ import panel as pn
+
# handle selection with lasso
plotted_spike_indices = self.scatter_source.data.get("spike_indices", [])
ind_selected, = np.nonzero(np.isin(plotted_spike_indices, self.controller.get_indices_spike_selected()))
- self.scatter_source.selected.indices = ind_selected
+
+ def _do_update():
+ self.scatter_source.selected.indices = ind_selected
+
+ pn.state.execute(_do_update, schedule=True)
def _panel_gain_zoom(self, event):
- from bokeh.models import Range1d
+ import panel as pn
factor = 1.3 if event.delta > 0 else 1 / 1.3
self.limit /= factor
- self.scatter_fig.x_range.start = -self.limit
- self.scatter_fig.x_range.end = self.limit
- self.scatter_fig.y_range.start = -self.limit
- self.scatter_fig.y_range.end = self.limit
+ limit = self.limit
+
+ def _do_update():
+ self.scatter_fig.x_range.start = -limit
+ self.scatter_fig.x_range.end = limit
+ self.scatter_fig.y_range.start = -limit
+ self.scatter_fig.y_range.end = limit
+
+ pn.state.execute(_do_update, schedule=True)
def _panel_next_face(self, event):
self.next_face()
@@ -522,11 +533,18 @@ def _panel_start_stop_tour(self, event):
self.auto_update_limit = True
def _panel_on_select_button(self, event):
- if self.select_toggle_button.value:
- self.scatter_fig.toolbar.active_drag = self.lasso_tool
- else:
- self.scatter_fig.toolbar.active_drag = None
- self.scatter_source.selected.indices = []
+ import panel as pn
+
+ value = self.select_toggle_button.value
+
+ def _do_update():
+ if value:
+ self.scatter_fig.toolbar.active_drag = self.lasso_tool
+ else:
+ self.scatter_fig.toolbar.active_drag = None
+ self.scatter_source.selected.indices = []
+
+ pn.state.execute(_do_update, schedule=True)
def _on_panel_selection_geometry(self, event):
"""
diff --git a/spikeinterface_gui/probeview.py b/spikeinterface_gui/probeview.py
index db7a217f..1c880208 100644
--- a/spikeinterface_gui/probeview.py
+++ b/spikeinterface_gui/probeview.py
@@ -485,55 +485,84 @@ def _panel_make_layout(self):
)
def _panel_refresh(self):
- # Only update unit positions if they actually changed
+ import panel as pn
+
+ # Compute everything outside the scheduled callback
current_unit_positions = self.controller.unit_positions
- if not np.array_equal(current_unit_positions, self._unit_positions):
- self._unit_positions = current_unit_positions
- # Update positions in data source
- self.glyphs_data_source.patch({
+ positions_changed = not np.array_equal(current_unit_positions, self._unit_positions)
+ show_channel_id = self.settings['show_channel_id']
+ auto_zoom = self.settings['auto_zoom_on_unit_selection']
+ radius_channel = self.settings['radius_channel']
+
+ selected_unit_indices = self.controller.get_visible_unit_indices()
+ visible_mask = self.controller.get_units_visibility_mask()
+
+ # Pre-compute unit glyph updates
+ glyph_patch_dict = self._panel_compute_unit_glyph_patches()
+
+ # Pre-compute position patches
+ position_patches = None
+ if positions_changed:
+ position_patches = {
'x': [(i, pos[0]) for i, pos in enumerate(current_unit_positions)],
'y': [(i, pos[1]) for i, pos in enumerate(current_unit_positions)]
- })
-
- # Update unit positions
- self._panel_update_unit_glyphs()
+ }
- # channel labels
- for label in self.channel_labels:
- label.visible = self.settings['show_channel_id']
+ # Pre-compute zoom bounds
+ zoom_bounds = None
+ if auto_zoom and sum(visible_mask) > 0:
+ visible_pos = self.controller.unit_positions[visible_mask, :]
+ x_min, x_max = np.min(visible_pos[:, 0]), np.max(visible_pos[:, 0])
+ y_min, y_max = np.min(visible_pos[:, 1]), np.max(visible_pos[:, 1])
+ margin = 50
+ zoom_bounds = (x_min - margin, x_max + margin, y_min - margin, y_max + margin)
- # Update selection circles if only one unit is visible
- selected_unit_indices = self.controller.get_visible_unit_indices()
+ # Pre-compute circle updates
+ circle_update = None
if len(selected_unit_indices) == 1:
unit_index = selected_unit_indices[0]
unit_positions = self.controller.unit_positions
- x, y = unit_positions[unit_index, 0], unit_positions[unit_index, 1]
- # Update circles position
- self.unit_circle.update_position(x, y)
+ cx, cy = unit_positions[unit_index, 0], unit_positions[unit_index, 1]
+ visible_channel_inds = self.update_channel_visibility(cx, cy, radius_channel)
+ circle_update = (cx, cy, visible_channel_inds)
- self.channel_circle.update_position(x, y)
- # Update channel visibility
- visible_channel_inds = self.update_channel_visibility(x, y, self.settings['radius_channel'])
+ # def _do_update():
+ if positions_changed:
+ self._unit_positions = current_unit_positions
+ if position_patches:
+ self.glyphs_data_source.patch(position_patches)
+
+ # Update unit glyphs
+ if glyph_patch_dict:
+ self.glyphs_data_source.patch(glyph_patch_dict)
+
+ # Channel labels
+ for label in self.channel_labels:
+ label.visible = show_channel_id
+
+ # Update selection circles
+ if circle_update is not None:
+ cx, cy, visible_channel_inds = circle_update
+ self.unit_circle.update_position(cx, cy)
+ self.channel_circle.update_position(cx, cy)
self.controller.set_channel_visibility(visible_channel_inds)
- if self.settings['auto_zoom_on_unit_selection']:
- visible_mask = self.controller.get_units_visibility_mask()
- if sum(visible_mask) > 0:
- visible_pos = self.controller.unit_positions[visible_mask, :]
- x_min, x_max = np.min(visible_pos[:, 0]), np.max(visible_pos[:, 0])
- y_min, y_max = np.min(visible_pos[:, 1]), np.max(visible_pos[:, 1])
- margin = 50
- self.x_range.start = x_min - margin
- self.x_range.end = x_max + margin
- self.y_range.start = y_min - margin
- self.y_range.end = y_max + margin
+ # Auto zoom
+ if zoom_bounds is not None:
+ self.x_range.start = zoom_bounds[0]
+ self.x_range.end = zoom_bounds[1]
+ self.y_range.start = zoom_bounds[2]
+ self.y_range.end = zoom_bounds[3]
- def _panel_update_unit_glyphs(self):
- # Get current data from source
+ # Defer to avoid nested Bokeh callbacks
+ # pn.state.execute(_do_update, schedule=True)
+
+ def _panel_compute_unit_glyph_patches(self):
+ """Compute glyph patches without modifying Bokeh models."""
current_alphas = self.glyphs_data_source.data['alpha']
current_sizes = self.glyphs_data_source.data['size']
current_line_colors = self.glyphs_data_source.data['line_color']
- # Prepare patches (only for changed values)
+
alpha_patches = []
size_patches = []
line_color_patches = []
@@ -542,12 +571,10 @@ def _panel_update_unit_glyphs(self):
color = self.get_unit_color(unit_id)
is_visible = self.controller.get_unit_visibility(unit_id)
- # Compute new values
new_alpha = self.alpha_selected if is_visible else self.alpha_unselected
new_size = self.unit_marker_size_selected if is_visible else self.unit_marker_size_unselected
new_line_color = "black" if is_visible else color
- # Only patch if changed
if current_alphas[idx] != new_alpha:
alpha_patches.append((idx, new_alpha))
if current_sizes[idx] != new_size:
@@ -555,22 +582,33 @@ def _panel_update_unit_glyphs(self):
if current_line_colors[idx] != new_line_color:
line_color_patches.append((idx, new_line_color))
- # Apply patches if any changes detected
- if len(alpha_patches) > 0 or len(size_patches) > 0 or len(line_color_patches) > 0:
- patch_dict = {}
- if alpha_patches:
- patch_dict['alpha'] = alpha_patches
- if size_patches:
- patch_dict['size'] = size_patches
- if line_color_patches:
- patch_dict['line_color'] = line_color_patches
-
- self.glyphs_data_source.patch(patch_dict)
-
+ patch_dict = {}
+ if alpha_patches:
+ patch_dict['alpha'] = alpha_patches
+ if size_patches:
+ patch_dict['size'] = size_patches
+ if line_color_patches:
+ patch_dict['line_color'] = line_color_patches
+
+ return patch_dict
+
+ def _panel_update_unit_glyphs(self):
+ import panel as pn
+
+ patch_dict = self._panel_compute_unit_glyph_patches()
+ if patch_dict:
+ pn.state.execute(lambda: self.glyphs_data_source.patch(patch_dict), schedule=True)
+
def _panel_on_pan_start(self, event):
- self.figure.toolbar.active_drag = None
+ import panel as pn
+
x, y = event.x, event.y
+ def _do_update():
+ self.figure.toolbar.active_drag = None
+
+ pn.state.execute(_do_update, schedule=True)
+
if self.unit_circle.is_close_to_diamond(x, y):
self.should_resize_unit_circle = [x, y]
self.unit_circle.select()
@@ -654,6 +692,8 @@ def _panel_on_pan_end(self, event):
def _panel_on_tap(self, event):
+ import panel as pn
+
x, y = event.x, event.y
unit_positions = self.controller.unit_positions
distances = np.sqrt(np.sum((unit_positions - np.array([x, y])) ** 2, axis=1))
@@ -683,16 +723,18 @@ def _panel_on_tap(self, event):
self._panel_update_unit_glyphs()
-
if select_only:
# Update selection circles
- self.unit_circle.update_position(x, y)
- self.channel_circle.update_position(x, y)
+ def _do_update():
+ self.unit_circle.update_position(x, y)
+ self.channel_circle.update_position(x, y)
+
+ pn.state.execute(_do_update, schedule=True)
# Update channel visibility
visible_channel_inds = self.update_channel_visibility(x, y, self.settings['radius_channel'])
self.controller.set_channel_visibility(visible_channel_inds)
- self.notify_channel_visibility_changed
+ self.notify_channel_visibility_changed()
self.notify_unit_visibility_changed()
diff --git a/spikeinterface_gui/spikelistview.py b/spikeinterface_gui/spikelistview.py
index b82986c3..414626de 100644
--- a/spikeinterface_gui/spikelistview.py
+++ b/spikeinterface_gui/spikelistview.py
@@ -118,6 +118,7 @@ class SpikeListView(ViewBase):
]
def __init__(self, controller=None, parent=None, backend="qt"):
+ self._updating_from_controller = False # Add this guard
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
def handle_selection(self, inds):
@@ -276,7 +277,7 @@ def _panel_make_layout(self):
conditional_shortcut=self.is_view_active,
)
# Add selection event handler
- self.table.param.watch(self._panel_on_selection_changed, "selection")
+ self.table.param.watch(self._panel_on_user_selection_changed, "selection")
self.refresh_button = pn.widgets.Button(name="↻ spikes", button_type="default", sizing_mode="stretch_width")
self.refresh_button.on_click(self._panel_on_refresh_click)
@@ -328,17 +329,22 @@ def _panel_refresh_table(self):
'rand_selected': spikes['rand_selected']
}
- # Update table data
+ # Update table data without replacing entire dataframe
df = pd.DataFrame(data)
- self.table.value = df
+
+ # Only update if data changed
+ if not self.table.value.equals(df):
+ self.table.value = df
selected_inds = self.controller.get_indices_spike_selected()
+ self._updating_from_controller = True
if len(selected_inds) == 0:
self.table.selection = []
else:
# Find the rows corresponding to the selected indices
row_selected, = np.nonzero(np.isin(visible_inds, selected_inds))
self.table.selection = [int(r) for r in row_selected]
+ self._updating_from_controller = False
self._panel_refresh_label()
@@ -348,21 +354,39 @@ def _panel_on_refresh_click(self, event):
self.notify_active_view_updated()
def _panel_on_clear_click(self, event):
+ import panel as pn
+
self.controller.set_indices_spike_selected([])
- self.table.selection = []
+
+ def _do_update():
+ self.table.selection = []
+ self._panel_refresh_label()
+
+ pn.state.execute(_do_update, schedule=True)
self.notify_spike_selection_changed()
- self._panel_refresh_label()
self.notify_active_view_updated()
- def _panel_on_selection_changed(self, event=None):
- selection = event.new
+ def _panel_on_user_selection_changed(self, event=None):
+ import panel as pn
+
+ # Ignore if we're updating from controller
+ if self._updating_from_controller:
+ return
+
+ selection = event.new if event else self.table.selection
if len(selection) == 0:
- self.handle_selection([])
+ absolute_indices = []
else:
- absolute_indices = self.controller.get_indices_spike_visible()[np.array(selection)]
- self.handle_selection(absolute_indices)
+ visible_inds = self.controller.get_indices_spike_visible()
+ absolute_indices = visible_inds[np.array(selection)]
+
+ # Defer the entire selection handling to avoid nested Bokeh callbacks
+ # def update_selection():
+ self.handle_selection(absolute_indices)
self._panel_refresh_label()
+ # pn.state.execute(update_selection, schedule=True)
+
def _panel_refresh_label(self):
n1 = self.controller.spikes.size
n2 = self.controller.get_indices_spike_visible().size
@@ -371,21 +395,35 @@ def _panel_refresh_label(self):
self.info_text.object = txt
def _panel_on_unit_visibility_changed(self):
+ import panel as pn
import pandas as pd
- # Clear the table when visibility changes
- self.table.value = pd.DataFrame(columns=_columns, data=[])
- self._panel_refresh_label()
+
+ def _do_update():
+ # Clear the table when visibility changes
+ self.table.value = pd.DataFrame(columns=_columns, data=[])
+ self._updating_from_controller = True
+ self.table.selection = []
+ self._updating_from_controller = False
+ self._panel_refresh_label()
+
+ pn.state.execute(_do_update, schedule=True)
def _panel_on_spike_selection_changed(self):
+ import panel as pn
+
if len(self.table.value) == 0:
return
selected_inds = self.controller.get_indices_spike_selected()
visible_inds = self.controller.get_indices_spike_visible()
row_selected, = np.nonzero(np.isin(visible_inds, selected_inds))
row_selected = [int(r) for r in row_selected]
- # Update the selection in the table
- self.table.selection = row_selected
- self._panel_refresh_label()
+
+ def _do_update():
+ # Update the selection in the table
+ self.table.selection = row_selected
+ self._panel_refresh_label()
+
+ pn.state.execute(_do_update, schedule=True)
SpikeListView._gui_help_txt = """
diff --git a/spikeinterface_gui/tests/iframe/curation.json b/spikeinterface_gui/tests/iframe/curation.json
new file mode 100644
index 00000000..db7d689e
--- /dev/null
+++ b/spikeinterface_gui/tests/iframe/curation.json
@@ -0,0 +1,75 @@
+{
+ "supported_versions": [
+ "1",
+ "2"
+ ],
+ "format_version": "2",
+ "unit_ids": [
+ "0",
+ "1",
+ "2",
+ "3",
+ "4",
+ "5",
+ "6",
+ "7",
+ "8",
+ "9",
+ "10",
+ "11",
+ "12",
+ "13",
+ "14",
+ "15"
+ ],
+ "label_definitions": {
+ "quality": {
+ "name": "quality",
+ "label_options": [
+ "good",
+ "noise",
+ "MUA"
+ ],
+ "exclusive": true
+ }
+ },
+ "manual_labels": [],
+ "removed": [
+ "3",
+ "8"
+ ],
+ "merges": [
+ {
+ "unit_ids": [
+ "1",
+ "2"
+ ],
+ "new_unit_id": null
+ },
+ {
+ "unit_ids": [
+ "5",
+ "6",
+ "7"
+ ],
+ "new_unit_id": null
+ }
+ ],
+ "splits": [
+ {
+ "unit_id": "4",
+ "mode": "indices",
+ "indices": [
+ [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4
+ ]
+ ],
+ "labels": null,
+ "new_unit_ids": null
+ }
+ ]
+}
\ No newline at end of file
diff --git a/spikeinterface_gui/tests/iframe/iframe_server.py b/spikeinterface_gui/tests/iframe/iframe_server.py
index 8dada7b9..ad1b6e1f 100644
--- a/spikeinterface_gui/tests/iframe/iframe_server.py
+++ b/spikeinterface_gui/tests/iframe/iframe_server.py
@@ -2,120 +2,142 @@
import webbrowser
from pathlib import Path
import argparse
+import socket
+import time
+
from flask import Flask, send_file, jsonify
app = Flask(__name__)
+
+PANEL_HOST = "localhost" # <- switch back to localhost
+
panel_server = None
panel_url = None
panel_thread = None
panel_port_global = None
+panel_last_error = None
+
-@app.route('/')
+def _wait_for_port(host: str, port: int, timeout_s: float = 20.0) -> bool:
+ deadline = time.time() + timeout_s
+ while time.time() < deadline:
+ try:
+ with socket.create_connection((host, port), timeout=0.25):
+ return True
+ except OSError:
+ time.sleep(0.1)
+ return False
+
+
+@app.route("/")
def index():
- """Serve the iframe test HTML page"""
- return send_file('iframe_test.html')
+ here = Path(__file__).parent
+ return send_file(str(here / "iframe_test.html"))
+
-@app.route('/start_test_server')
+@app.route("/curation.json")
+def curation_json():
+ here = Path(__file__).parent
+ return send_file(str(here / "curation.json"))
+
+
+@app.route("/start_test_server")
def start_test_server():
- """Start the Panel server in a separate thread"""
- global panel_server, panel_url, panel_thread, panel_port_global
-
- # If a server is already running, return its URL
+ global panel_server, panel_url, panel_thread, panel_port_global, panel_last_error
+
if panel_url:
return jsonify({"success": True, "url": panel_url})
-
- # Make sure the test dataset exists
+
+ panel_last_error = None
+
test_folder = Path(__file__).parent / "my_dataset"
if not test_folder.is_dir():
from spikeinterface_gui.tests.testingtools import make_analyzer_folder
+
make_analyzer_folder(test_folder)
-
- # Function to run the Panel server in a thread
+
def run_panel_server():
- global panel_server, panel_url, panel_port_global
+ global panel_server, panel_url, panel_port_global, panel_last_error
try:
- # Start the Panel server with curation enabled
- # Use a direct import to avoid circular imports
+ import panel as pn
from spikeinterface import load_sorting_analyzer
from spikeinterface_gui import run_mainwindow
-
- # Load the analyzer
+
+ pn.extension("tabulator", "gridstack")
+
analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer")
-
- # Start the Panel server directly
- win = run_mainwindow(
- analyzer,
- backend="panel",
- start_app=False,
- verbose=True,
- curation=True,
- make_servable=True
+
+ def app_factory():
+ win = run_mainwindow(
+ analyzer,
+ mode="web",
+ start_app=False,
+ verbose=True,
+ curation=True,
+ panel_window_servable=True,
+ )
+ return win.main_layout
+
+ allowed = [
+ f"localhost:{int(panel_port_global)}",
+ f"127.0.0.1:{int(panel_port_global)}",
+ ]
+ print(panel_port_global)
+ server = pn.serve(
+ app_factory,
+ port=int(panel_port_global),
+ address=PANEL_HOST,
+ allow_websocket_origin=allowed,
+ show=False,
+ start=False,
)
-
- # Start the server manually
- import panel as pn
- pn.serve(win.main_layout, port=panel_port_global, address="localhost", show=False, start=True)
-
- # Get the server URL
- panel_url = f"http://localhost:{panel_port_global}"
- panel_server = win
-
- print(f"Panel server started at {panel_url}")
+
+ panel_server = server
+ panel_url = f"http://{PANEL_HOST}:{panel_port_global}/"
+ print(f"Panel server starting at {panel_url} (allow_websocket_origin={allowed})")
+
+ server.start()
+ server.io_loop.start()
+
except Exception as e:
- print(f"Error starting Panel server: {e}")
+ panel_last_error = repr(e)
+ panel_server = None
+ panel_url = None
import traceback
+
traceback.print_exc()
-
- # Start the Panel server in a separate thread
- panel_thread = threading.Thread(target=run_panel_server)
- panel_thread.daemon = True
+
+ panel_thread = threading.Thread(target=run_panel_server, daemon=True)
panel_thread.start()
-
- # Give the server some time to start
- import time
- time.sleep(5) # Increased wait time
-
- # Check if the server is actually running
- import requests
- try:
- response = requests.get(f"http://localhost:{panel_port_global}", timeout=2)
- if response.status_code == 200:
- return jsonify({"success": True, "url": f"http://localhost:{panel_port_global}"})
- else:
- return jsonify({"success": False, "error": f"Server returned status code {response.status_code}"})
- except requests.exceptions.RequestException as e:
- return jsonify({"success": False, "error": f"Could not connect to Panel server: {str(e)}"})
-
-@app.route('/stop_test_server')
-def stop_test_server():
- """Stop the Panel server"""
- global panel_server, panel_url, panel_thread
-
- if panel_server:
- # Clean up resources
- # clean_all(Path(__file__).parent / 'my_dataset')
- panel_url = None
- panel_server = None
- return jsonify({"success": True})
- else:
- return jsonify({"success": False, "error": "No server running"})
-
-def main(flask_port=5000, panel_port=5006):
- """Start the Flask server and open the browser"""
+
+ if not _wait_for_port("127.0.0.1", int(panel_port_global), timeout_s=30.0):
+ return (
+ jsonify(
+ {
+ "success": False,
+ "error": "Panel server did not become ready (port not open).",
+ "panel_host": PANEL_HOST,
+ "panel_port": panel_port_global,
+ "last_error": panel_last_error,
+ }
+ ),
+ 500,
+ )
+
+ return jsonify({"success": True, "url": panel_url})
+
+
+def main(flask_port=5000, panel_port=5007):
global panel_port_global
panel_port_global = panel_port
- # Open the browser
- webbrowser.open(f'http://localhost:{flask_port}')
-
- # Start the Flask server
- app.run(debug=False, port=flask_port)
+ webbrowser.open(f"http://localhost:{flask_port}")
+ app.run(debug=False, port=flask_port, host="localhost")
parser = argparse.ArgumentParser(description="Run the Flask and Panel servers.")
-parser.add_argument('--flask-port', type=int, default=5000, help="Port for the Flask server (default: 5000)")
-parser.add_argument('--panel-port', type=int, default=5006, help="Port for the Panel server (default: 5006)")
+parser.add_argument("--flask-port", type=int, default=5000)
+parser.add_argument("--panel-port", type=int, default=5006)
-if __name__ == '__main__':
+if __name__ == "__main__":
args = parser.parse_args()
-
main(flask_port=int(args.flask_port), panel_port=int(args.panel_port))
diff --git a/spikeinterface_gui/tests/iframe/iframe_test.html b/spikeinterface_gui/tests/iframe/iframe_test.html
index d3e0d91a..050b05e4 100644
--- a/spikeinterface_gui/tests/iframe/iframe_test.html
+++ b/spikeinterface_gui/tests/iframe/iframe_test.html
@@ -51,6 +51,7 @@
display: flex;
gap: 10px;
margin-bottom: 10px;
+ flex-wrap: wrap;
}
button {
padding: 8px 16px;
@@ -63,22 +64,68 @@
button:hover {
background-color: #45a049;
}
+ button:disabled {
+ background-color: #cccccc;
+ cursor: not-allowed;
+ }
+ .send-controls {
+ background-color: #fff;
+ border: 1px solid #ccc;
+ border-radius: 5px;
+ padding: 20px;
+ }
+ textarea {
+ width: 100%;
+ min-height: 150px;
+ font-family: 'Courier New', monospace;
+ font-size: 12px;
+ padding: 10px;
+ border: 1px solid #ddd;
+ border-radius: 4px;
+ box-sizing: border-box;
+ }
+ .button-secondary {
+ background-color: #2196F3;
+ }
+ .button-secondary:hover {
+ background-color: #0b7dda;
+ }
+ .button-danger {
+ background-color: #f44336;
+ }
+ .button-danger:hover {
+ background-color: #da190b;
+ }
-
SpikeInterface GUI Iframe Test
-
+
SpikeInterface GUI Iframe Test - Bidirectional Communication
+
-
+
+
-
+
+
+
Send Curation Data to iframe:
+
+
Edit the curation data below and click "Send to iframe" to update the GUI:
+
+
+
+
+
+
+
-
Received Data:
+
Received Data from iframe:
No data received yet...
@@ -88,23 +135,88 @@
Received Data:
// DOM elements
const iframe = document.getElementById('guiFrame');
const dataOutput = document.getElementById('dataOutput');
+ const curationInput = document.getElementById('curationInput');
const clearBtn = document.getElementById('clearBtn');
const startServerBtn = document.getElementById('startServerBtn');
-
+ const sendDataBtn = document.getElementById('sendDataBtn');
+ const resetDataBtn = document.getElementById('resetDataBtn');
+ const clearCurationBtn = document.getElementById('clearCurationBtn');
+ const autoSendCheckbox = document.getElementById('autoSendCheckbox');
+
// Store received data
let receivedData = [];
-
+ let autoSendEnabled = false;
+ let defaultCurationData = null;
+
+ // Update auto-send state when checkbox changes
+ autoSendCheckbox.addEventListener('change', function() {
+ autoSendEnabled = autoSendCheckbox.checked;
+ console.log('Auto-send ' + (autoSendEnabled ? 'enabled' : 'disabled'));
+ });
+
+ // Load default curation data from curation.json
+ fetch('/curation.json')
+ .then(response => {
+ if (!response.ok) {
+ throw new Error(`HTTP error! status: ${response.status}`);
+ }
+ return response.json();
+ })
+ .then(data => {
+ defaultCurationData = data;
+ // Initialize the textarea with the loaded data
+ curationInput.value = JSON.stringify(defaultCurationData, null, 2);
+ console.log('Loaded curation data from curation.json');
+ })
+ .catch(error => {
+ console.error('Error loading curation.json:', error);
+ // Fallback to simple default data
+ defaultCurationData = {
+ "merges": [
+ {"unit_ids": ["1", "2"]},
+ {"unit_ids": ["5", "6", "7"]}
+ ],
+ "removed": ["3", "8"],
+ "splits": [
+ {
+ "unit_id": "4",
+ "mode": "indices",
+ "indices": [[0, 1, 2, 3, 4]]
+ }
+ ]
+ };
+ curationInput.value = JSON.stringify(defaultCurationData, null, 2);
+ console.log('Using fallback curation data');
+ });
+
// Event listener for messages from the iframe
function handleMessage(event) {
console.log('Message received from iframe:', event.data);
-
+
+ // Ignore messages from browser extensions
+ if (event.data && event.data.source === 'react-devtools-content-script') {
+ return;
+ }
+
// Check if the message is from our Panel application
if (event.data && event.data.type === 'panel-data') {
const data = event.data.data;
receivedData.push(data);
// Format and display the data
- displayData();
+ displayReceivedData();
+
+ // Check for loaded status and auto-send if enabled
+ if (data && data.loaded === true) {
+ console.log('iframe loaded detected, auto-send enabled:', autoSendCheckbox.checked);
+ if (autoSendCheckbox.checked) {
+ console.log('Auto-sending curation data...');
+ const curationData = curationInput.value;
+ setTimeout(() => {
+ sendCurationData(curationData);
+ }, 100); // Small delay to ensure iframe is fully ready
+ }
+ }
}
}
@@ -112,7 +224,7 @@
Received Data:
window.addEventListener('message', handleMessage);
// Function to display the received data
- function displayData() {
+ function displayReceivedData() {
if (receivedData.length === 0) {
dataOutput.textContent = 'No data received yet...';
return;
@@ -129,10 +241,33 @@
Received Data:
dataOutput.scrollTop = dataOutput.scrollHeight;
}
+ // Function to send curation data to the iframe
+ function sendCurationData(curationData) {
+ try {
+ // Validate the JSON
+ const data = typeof curationData === 'string'
+ ? JSON.parse(curationData)
+ : curationData;
+
+ console.log('Sending curation data to iframe:', data);
+
+ // Send the message to the iframe
+ iframe.contentWindow.postMessage({
+ type: 'curation-data',
+ data: data
+ }, '*');
+
+ console.log('Curation data sent to iframe successfully!');
+ } catch (error) {
+ console.error('Error sending curation data:', error);
+ alert('Error sending curation data: ' + error.message);
+ }
+ }
+
// Clear button event listener
clearBtn.addEventListener('click', function() {
receivedData = [];
- displayData();
+ displayReceivedData();
// Re-add the event listener to ensure it's working
window.removeEventListener('message', handleMessage);
@@ -140,6 +275,29 @@
// Add instructions for using the curation view
dataOutput.textContent = `Instructions:
1. Click "Start Test Server" to launch the SpikeInterface GUI
-2. Once loaded, navigate to the "curation" tab
-3. Click the "Submit to parent" button to send data to this page
-4. The received data will appear here
+2. Optionally enable "Auto send when ready" to automatically send curation data when the iframe signals it's loaded
+3. Once loaded, navigate to the "curation" tab
+
+Testing SEND (parent → iframe):
+- Edit the curation data in the textarea above
+- Click "Send to iframe" to update the GUI's curation data
+- The GUI should refresh and show the new curation data
+
+Testing RECEIVE (iframe → parent):
+- Click the "Submit to parent" button in the GUI's curation view
+- The received data will appear in this section below
Note: The first time you start the server, it may take a minute to generate test data.`;
diff --git a/spikeinterface_gui/tests/iframe/iframe_test_README.md b/spikeinterface_gui/tests/iframe/iframe_test_README.md
index a84e6092..76db6829 100644
--- a/spikeinterface_gui/tests/iframe/iframe_test_README.md
+++ b/spikeinterface_gui/tests/iframe/iframe_test_README.md
@@ -1,13 +1,19 @@
-# SpikeInterface GUI Iframe Test
+# SpikeInterface GUI Iframe Test - Bidirectional Communication
-This is a simple test application that demonstrates how to embed the SpikeInterface GUI in an iframe and receive data from it.
+This is a test application that demonstrates bidirectional communication between a parent window and the SpikeInterface GUI embedded in an iframe.
## Overview
The test consists of:
-1. `iframe_test.html` - A simple HTML page that embeds the SpikeInterface GUI in an iframe and displays data received from it
-2. `iframe_server.py` - A Flask server that serves the HTML page and starts the Panel application
+1. `iframe_test.html` - An HTML page that:
+ - Embeds the SpikeInterface GUI in an iframe
+ - Sends curation data TO the iframe
+ - Receives curation data FROM the iframe
+
+2. `iframe_server.py` - A Flask server that:
+ - Serves the HTML page
+ - Starts the Panel application
## How to Run
@@ -23,33 +29,67 @@ The test consists of:
3. This will open a browser window with the test application. Click the "Start Test Server" button to launch the SpikeInterface GUI.
-4. Once the GUI is loaded, navigate to the "curation" tab (you may need to look for it in the different zones of the interface).
+## Testing the Bidirectional Communication
-5. Click the "Submit to parent" button in the curation view to send data to the parent window.
+### Test 1: Send Data to iframe (Parent → iframe)
-6. The received data will be displayed in the "Received Data" section of the parent window.
+1. Once the GUI is loaded, navigate to the "curation" tab
+2. In the parent window, edit the curation data in the textarea
+3. Click "Send to iframe" button
+4. The GUI should refresh and display the new curation data in its tables
+5. The "Auto send when ready" button will automatically send the curation data when the iframe is ready
-## How It Works
+### Test 2: Receive Data from iframe (iframe → Parent)
-1. The Flask server serves the HTML page and provides an endpoint to start the Panel application.
+1. In the GUI's curation view, click the "Submit to parent" button
+2. The data will appear in the "Received Data from iframe" section of the parent window
+3. Multiple submissions will be accumulated and displayed
-2. The HTML page embeds the Panel application in an iframe.
+### Test 3: Clear All Curation
-3. The curationview.py file has been modified to send data to the parent window using the `postMessage` API when the "Submit to parent" button is clicked.
+1. Click the "Clear All Curation" button
+2. This sends an empty curation object to the iframe
+3. All tables in the curation view should become empty
-4. The parent window listens for messages from the iframe and displays the received data.
+## How It Works
-## Troubleshooting
+### Parent → iframe Communication
-- If the Panel application fails to start, check the console for error messages.
-- If no data is received when clicking the "Submit to parent" button, make sure you're using the latest version of the curationview.py file with the fix for the DataCloneError.
-- If you're running the Panel application separately, you can manually set the iframe URL using the `setIframeUrl` function in the browser console.
+1. Parent window constructs a curation data object
+2. The iframe sends a `loaded=true` message when ready to receive data
+3. Sends it via `postMessage` with `type: 'curation-data'`
+4. The iframe's JavaScript listener receives the message
+5. Python updates `controller.set_curation_data` and refreshes the view
-## Technical Details
+### iframe → Parent Communication
-The fix for the DataCloneError involves:
+1. User clicks "Submit to parent" in the curation view
+2. Python generates a JSON string of the curation model
+3. JavaScript code is injected that sends the data via `postMessage`
+4. Parent window's message listener receives the data
+5. Data is parsed and displayed
-1. Converting the complex object to a JSON string before sending it to the parent window
-2. Parsing the JSON string back to an object in the parent window
-This avoids the issue where the browser tries to clone a complex object that contains non-cloneable properties.
+## Technical Details
+
+The curation data needs to follow the [`CurationModel`](https://spikeinterface.readthedocs.io/en/stable/api.html#curation-model).
+
+### Message Format (Parent → iframe)
+```javascript
+{
+ type: 'curation-data',
+ data: {
+ // Full curation model JSON
+ }
+}
+```
+
+### Message Format (iframe → Parent)
+```javascript
+{
+ type: 'panel-data',
+ data: {
+ // Full curation model JSON
+ }
+}
+```
diff --git a/spikeinterface_gui/tracemapview.py b/spikeinterface_gui/tracemapview.py
index d5bc07df..2f9e5873 100644
--- a/spikeinterface_gui/tracemapview.py
+++ b/spikeinterface_gui/tracemapview.py
@@ -310,15 +310,30 @@ def _panel_on_spike_selection_changed(self):
self._panel_seek_with_selected_spike()
def _panel_gain_zoom(self, event):
+ import panel as pn
+
factor_ratio = 1.3 if event.delta > 0 else 1 / 1.3
- self.color_mapper.high = self.color_mapper.high * factor_ratio
- self.color_mapper.low = -self.color_mapper.high
+ new_high = self.color_mapper.high * factor_ratio
+ new_low = -new_high
+
+ def _do_update():
+ self.color_mapper.high = new_high
+ self.color_mapper.low = new_low
+
+ pn.state.execute(_do_update, schedule=True)
def _panel_auto_scale(self, event):
+ import panel as pn
+
if self.last_data_curves is not None:
self.color_limit = np.max(np.abs(self.last_data_curves))
- self.color_mapper.high = self.color_limit
- self.color_mapper.low = -self.color_limit
+ color_limit = self.color_limit
+
+ def _do_update():
+ self.color_mapper.high = color_limit
+ self.color_mapper.low = -color_limit
+
+ pn.state.execute(_do_update, schedule=True)
def _panel_on_time_info_updated(self):
# Update segment and time slider range
diff --git a/spikeinterface_gui/traceview.py b/spikeinterface_gui/traceview.py
index bc4d2166..8a461ffd 100644
--- a/spikeinterface_gui/traceview.py
+++ b/spikeinterface_gui/traceview.py
@@ -316,6 +316,8 @@ def _panel_on_time_slider_changed(self, event):
self.notify_time_info_updated()
def _panel_seek_with_selected_spike(self):
+ import panel as pn
+
ind_selected = self.controller.get_indices_spike_selected()
n_selected = ind_selected.size
@@ -340,8 +342,14 @@ def _panel_seek_with_selected_spike(self):
# Center view on spike
margin = self.xsize / 3
- self.figure.x_range.start = peak_time - margin
- self.figure.x_range.end = peak_time + 2 * margin
+ range_start = peak_time - margin
+ range_end = peak_time + 2 * margin
+
+ def _do_update():
+ self.figure.x_range.start = range_start
+ self.figure.x_range.end = range_end
+
+ pn.state.execute(_do_update, schedule=True)
self._block_auto_refresh_and_notify = False
self.refresh()
diff --git a/spikeinterface_gui/unitlistview.py b/spikeinterface_gui/unitlistview.py
index 5512aadf..466c355c 100644
--- a/spikeinterface_gui/unitlistview.py
+++ b/spikeinterface_gui/unitlistview.py
@@ -667,11 +667,17 @@ def _panel_on_visible_checkbox_toggled(self, row):
self.refresh()
def _panel_on_unit_visibility_changed(self):
+ import panel as pn
+
# update selection to match visible units
visible_units = self.controller.get_visible_unit_ids()
unit_ids = list(self.table.value.index.values)
rows_to_select = [unit_ids.index(unit_id) for unit_id in visible_units if unit_id in unit_ids]
- self.table.selection = rows_to_select
+
+ def _do_update():
+ self.table.selection = rows_to_select
+
+ pn.state.execute(_do_update, schedule=True)
self.refresh()
def _panel_refresh_colors(self):
diff --git a/spikeinterface_gui/utils_panel.py b/spikeinterface_gui/utils_panel.py
index 7619651e..7a215c49 100644
--- a/spikeinterface_gui/utils_panel.py
+++ b/spikeinterface_gui/utils_panel.py
@@ -14,6 +14,7 @@
from panel.param import param
from panel.custom import ReactComponent
from panel.widgets import Tabulator
+from panel.reactive import ReactiveHTML
from bokeh.models import ColumnDataSource, Patches, HTMLTemplateFormatter
@@ -292,7 +293,7 @@ class SelectableTabulator(pn.viewable.Viewer):
refresh_table_function: Callable | None
A function to call when the table a new selection is made via keyboard shortcuts.
on_only_function: Callable | None
- A function to call when the table a ctrl+selection is made via keyboard shortcuts or a double-click.
+ A function to call when a ctrl+selection is made via keyboard shortcuts or a double-click.
conditional_shortcut: Callable | None
A function that returns True if the shortcuts should be enabled, False otherwise.
column_callbacks: dict[Callable] | None
@@ -313,6 +314,7 @@ def __init__(
self._formatters = kwargs.get("formatters", {})
self._editors = kwargs.get("editors", {})
self._frozen_columns = kwargs.get("frozen_columns", [])
+ self._selectable = kwargs.get("selectable", True)
if "sortable" in kwargs:
self._sortable = kwargs.pop("sortable")
else:
@@ -394,8 +396,9 @@ def selection(self):
@selection.setter
def selection(self, val):
- if isinstance(self.tabulator.selectable, int):
- max_selectable = self.tabulator.selectable
+ # Added this logic to prevent max selection with shift+click / arrows
+ if isinstance(self._selectable, int):
+ max_selectable = self._selectable
if not isinstance(max_selectable, bool):
if len(val) > max_selectable:
val = val[-max_selectable:]
@@ -437,6 +440,7 @@ def refresh_tabulator_settings(self):
self.tabulator.formatters = self._formatters
self.tabulator.editors = self._editors
self.tabulator.frozen_columns = self._frozen_columns
+ self.tabulator.selectable = self._selectable
self.tabulator.sorters = []
def refresh(self):
@@ -485,7 +489,7 @@ def _on_selection_change(self, event):
Handle the selection change event. This is called when the selection is changed.
"""
if self._refresh_table_function is not None:
- self._refresh_table_function()
+ pn.state.execute(self._refresh_table_function, schedule=True)
def _on_click(self, event):
"""
@@ -685,3 +689,58 @@ class KeyboardShortcuts(ReactComponent):
return <>>;
}
"""
+
+
+class PostMessageListener(ReactComponent):
+ """
+ Listen to window.postMessage events and forward them to Python via on_msg().
+ This avoids ReactiveHTML/Bokeh 'source' linkage issues.
+ """
+ _model_name = "PostMessageListener"
+ _model_module = "post_message_listener"
+ _model_module_version = "0.0.1"
+
+ # If set, only forward messages whose event.data.type matches this value.
+ accept_type = param.String(default="curation-data")
+
+ _esm = """
+ export function render({ model }) {
+ const [accept_type] = model.useState("accept_type");
+
+ function onMessage(event) {
+ const data = event.data;
+
+ // Ignore messages from browser extensions
+ if (data && data.source === "react-devtools-content-script") return;
+
+ if (accept_type && data && data.type !== accept_type) return;
+
+ // Always include a timestamp so repeated sends still look "new"
+ model.send_msg({ payload: data, _ts: Date.now() });
+ }
+
+ React.useEffect(() => {
+ window.addEventListener("message", onMessage);
+ return () => window.removeEventListener("message", onMessage);
+ }, [accept_type]);
+
+ return <>>;
+ }
+ """
+
+
+class IFrameDetector(pn.reactive.ReactiveHTML):
+ """
+ Simple component that detects if it is running inside an iframe.
+ """
+ in_iframe = param.Parameter(default=None)
+
+ _template = ""
+
+ _scripts = {
+ "render": """
+ const val = window.self !== window.top;
+ console.log("iframe detector JS:", val);
+ data.in_iframe = val; // reliable sync
+ """
+ }
diff --git a/spikeinterface_gui/view_base.py b/spikeinterface_gui/view_base.py
index 3fb2507a..52aaebdc 100644
--- a/spikeinterface_gui/view_base.py
+++ b/spikeinterface_gui/view_base.py
@@ -125,7 +125,8 @@ def _refresh(self, **kwargs):
if self.backend == "qt":
self._qt_refresh(**kwargs)
elif self.backend == "panel":
- self._panel_refresh(**kwargs)
+ import panel as pn
+ pn.state.execute(lambda: self._panel_refresh(**kwargs), schedule=True)
def warning(self, warning_msg):
if self.backend == "qt":
diff --git a/spikeinterface_gui/waveformheatmapview.py b/spikeinterface_gui/waveformheatmapview.py
index 7ea52055..42d0a5ca 100644
--- a/spikeinterface_gui/waveformheatmapview.py
+++ b/spikeinterface_gui/waveformheatmapview.py
@@ -257,33 +257,45 @@ def _panel_make_layout(self):
)
def _panel_refresh(self):
+ import panel as pn
+
hist2d = self.get_plotting_data()
- if hist2d is None:
- self.image_source.data.update({
- "image": [],
- "dw": [],
- "dh": []
- })
- return
+ def _do_update():
+ if hist2d is None:
+ self.image_source.data = {
+ "image": [],
+ "dw": [],
+ "dh": []
+ }
+ return
+
+ self.image_source.data = {
+ "image": [hist2d.T],
+ "dw": [hist2d.shape[0]],
+ "dh": [hist2d.shape[1]]
+ }
- self.image_source.data.update({
- "image": [hist2d.T],
- "dw": [hist2d.shape[0]],
- "dh": [hist2d.shape[1]]
- })
+ self.color_mapper.low = 0
+ self.color_mapper.high = np.max(hist2d)
- self.color_mapper.low = 0
- self.color_mapper.high = np.max(hist2d)
+ self.figure.x_range.start = 0
+ self.figure.x_range.end = hist2d.shape[0]
+ self.figure.y_range.start = 0
+ self.figure.y_range.end = hist2d.shape[1]
- self.figure.x_range.start = 0
- self.figure.x_range.end = hist2d.shape[0]
- self.figure.y_range.start = 0
- self.figure.y_range.end = hist2d.shape[1]
+ pn.state.execute(_do_update, schedule=True)
def _panel_gain_zoom(self, event):
+ import panel as pn
+
factor = 1.3 if event.delta > 0 else 1 / 1.3
- self.color_mapper.high = self.color_mapper.high * factor
+ new_high = self.color_mapper.high * factor
+
+ def _do_update():
+ self.color_mapper.high = new_high
+
+ pn.state.execute(_do_update, schedule=True)
diff --git a/spikeinterface_gui/waveformview.py b/spikeinterface_gui/waveformview.py
index d53e7497..38f07e9e 100644
--- a/spikeinterface_gui/waveformview.py
+++ b/spikeinterface_gui/waveformview.py
@@ -1022,7 +1022,8 @@ def _panel_on_mode_selector_changed(self, event):
self.refresh()
def _panel_gain_zoom(self, event):
- self.figure_geom.toolbar.active_scroll = None
+ import panel as pn
+
current_time = time.perf_counter()
if self.last_wheel_event_time is not None:
time_elapsed = current_time - self.last_wheel_event_time
@@ -1030,8 +1031,14 @@ def _panel_gain_zoom(self, event):
time_elapsed = 1000
if time_elapsed > _wheel_refresh_time:
modifiers = event.modifiers
- if modifiers["shift"] and modifiers["alt"]:
+
+ def _enable_active_scroll():
+ self.figure_geom.toolbar.active_scroll = self.zoom_tool
+
+ def _disable_active_scroll():
self.figure_geom.toolbar.active_scroll = None
+
+ if modifiers["shift"] and modifiers["alt"]:
if self.mode == "geometry":
factor_ratio = 1.3 if event.delta > 0 else 1 / 1.3
# adjust y range and keep center
@@ -1040,29 +1047,36 @@ def _panel_gain_zoom(self, event):
yrange = ymax - ymin
ymid = 0.5 * (ymin + ymax)
new_yrange = yrange * factor_ratio
- ymin = ymid - new_yrange / 2.
- ymax = ymid + new_yrange / 2.
- self.figure_geom.y_range.start = ymin
- self.figure_geom.y_range.end = ymax
+ new_ymin = ymid - new_yrange / 2.
+ new_ymax = ymid + new_yrange / 2.
+
+ def _do_range_update():
+ self.figure_geom.toolbar.active_scroll = None
+ self.figure_geom.y_range.start = new_ymin
+ self.figure_geom.y_range.end = new_ymax
+
+ pn.state.execute(_do_range_update, schedule=True)
+ else:
+ pn.state.execute(_disable_active_scroll, schedule=True)
elif modifiers["shift"]:
- self.figure_geom.toolbar.active_scroll = self.zoom_tool
+ pn.state.execute(_enable_active_scroll, schedule=True)
elif modifiers["alt"]:
- self.figure_geom.toolbar.active_scroll = None
if self.mode == "geometry":
factor = 1.3 if event.delta > 0 else 1 / 1.3
self.factor_x *= factor
self._panel_refresh_mode_geometry(keep_range=True)
self._panel_refresh_spikes()
+ pn.state.execute(_disable_active_scroll, schedule=True)
elif not modifiers["ctrl"]:
- self.figure_geom.toolbar.active_scroll = None
if self.mode == "geometry":
factor = 1.3 if event.delta > 0 else 1 / 1.3
self.gain_y *= factor
self._panel_refresh_mode_geometry(keep_range=True)
self._panel_refresh_spikes()
+ pn.state.execute(_disable_active_scroll, schedule=True)
else:
# Ignore the event if it occurs too quickly
- self.figure_geom.toolbar.active_scroll = None
+ pn.state.execute(_disable_active_scroll, schedule=True)
self.last_wheel_event_time = current_time
def _panel_refresh_mode_geometry(self, dict_visible_units=None, keep_range=False):
@@ -1388,11 +1402,14 @@ def _panel_clear_data_sources(self):
self.vlines_data_source_std.data = dict(xs=[], ys=[], colors=[])
def _panel_on_spike_selection_changed(self):
- self._panel_refresh_one_spike()
+ import panel as pn
+ pn.state.execute(self._panel_refresh_one_spike, schedule=True)
def _panel_on_channel_visibility_changed(self):
+ import panel as pn
+
keep_range = not self.settings["auto_move_on_unit_selection"]
- self._panel_refresh(keep_range=keep_range)
+ pn.state.execute(lambda: self._panel_refresh(keep_range=keep_range), schedule=True)
def _panel_handle_shortcut(self, event):
if event.data == "overlap":