Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 42 additions & 23 deletions spikeinterface_gui/basescatterview.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def _qt_refresh(self, set_scatter_range=False):

# set x range to time range of the current segment for scatter, and max count for histogram
# set y range to min and max of visible spike amplitudes
if set_scatter_range or not self._first_refresh_done:
if len(ymins) > 0 and (set_scatter_range or not self._first_refresh_done):
ymin = np.min(ymins)
ymax = np.max(ymaxs)
t_start, t_stop = self.controller.get_t_start_t_stop()
Expand Down Expand Up @@ -490,6 +490,7 @@ def _panel_make_layout(self):
self.plotted_inds = []

def _panel_refresh(self, set_scatter_range=False):
import panel as pn
from bokeh.models import FixedTicker

self.plotted_inds = []
Expand Down Expand Up @@ -555,28 +556,45 @@ def _panel_refresh(self, set_scatter_range=False):
# handle selected spikes
self._panel_update_selected_spikes()

# set y range to min and max of visible spike amplitudes
# Defer Range updates to avoid nested document lock issues
# def update_ranges():
if set_scatter_range or not self._first_refresh_done:
self.y_range.start = np.min(ymins)
self.y_range.end = np.max(ymaxs)
self._first_refresh_done = True
self.hist_fig.x_range.end = max_count
self.hist_fig.xaxis.ticker = FixedTicker(ticks=[0, max_count // 2, max_count])

# Schedule the update to run after the current event loop iteration
# pn.state.execute(update_ranges, schedule=True)

def _panel_on_select_button(self, event):
if self.select_toggle_button.value:
self.scatter_fig.toolbar.active_drag = self.lasso_tool
else:
self.scatter_fig.toolbar.active_drag = None
self.scatter_source.selected.indices = []
import panel as pn

value = self.select_toggle_button.value

def _do_update():
if value:
self.scatter_fig.toolbar.active_drag = self.lasso_tool
else:
self.scatter_fig.toolbar.active_drag = None
self.scatter_source.selected.indices = []

pn.state.execute(_do_update, schedule=True)

def _panel_change_segment(self, event):
import panel as pn

self._current_selected = 0
segment_index = int(self.segment_selector.value.split()[-1])
self.controller.set_time(segment_index=segment_index)
t_start, t_end = self.controller.get_t_start_t_stop()
self.scatter_fig.x_range.start = t_start
self.scatter_fig.x_range.end = t_end

def _do_update():
self.scatter_fig.x_range.start = t_start
self.scatter_fig.x_range.end = t_end

pn.state.execute(_do_update, schedule=True)
self.refresh(set_scatter_range=True)
self.notify_time_info_updated()

Expand Down Expand Up @@ -618,9 +636,17 @@ def _panel_split(self, event):
self.split()

def _panel_update_selected_spikes(self):
import panel as pn

# handle selected spikes
selected_spike_indices = self.controller.get_indices_spike_selected()
selected_spike_indices = np.intersect1d(selected_spike_indices, self.plotted_inds)
if len(selected_spike_indices) == 1:
selected_segment = self.controller.spikes[selected_spike_indices[0]]['segment_index']
segment_index = self.controller.get_time()[1]
if selected_segment != segment_index:
self.segment_selector.value = f"Segment {selected_segment}"
self._panel_change_segment(None)
if len(selected_spike_indices) > 0:
# map absolute indices to visible spikes
segment_index = self.controller.get_time()[1]
Expand All @@ -634,23 +660,16 @@ def _panel_update_selected_spikes(self):
# set selected spikes in scatter plot
if self.settings["auto_decimate"] and len(selected_indices) > 0:
selected_indices, = np.nonzero(np.isin(self.plotted_inds, selected_spike_indices))
self.scatter_source.selected.indices = list(selected_indices)
else:
self.scatter_source.selected.indices = []
selected_indices = []

def _do_update():
self.scatter_source.selected.indices = list(selected_indices)

pn.state.execute(_do_update, schedule=True)

def _panel_on_spike_selection_changed(self):
# set selection in scatter plot
selected_indices = self.controller.get_indices_spike_selected()
if len(selected_indices) == 0:
self.scatter_source.selected.indices = []
return
elif len(selected_indices) == 1:
selected_segment = self.controller.spikes[selected_indices[0]]['segment_index']
segment_index = self.controller.get_time()[1]
if selected_segment != segment_index:
self.segment_selector.value = f"Segment {selected_segment}"
self._panel_change_segment(None)
# update selected spikes
# update selected spikes (scheduled via pn.state.execute inside)
self._panel_update_selected_spikes()

def _panel_handle_shortcut(self, event):
Expand Down
60 changes: 30 additions & 30 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
from spikeinterface.widgets.utils import get_unit_colors
from spikeinterface import compute_sparsity
from spikeinterface.core import get_template_extremum_channel
import spikeinterface.postprocessing
import spikeinterface.qualitymetrics
from spikeinterface.core.sorting_tools import spike_vector_to_indices
from spikeinterface.core.core_tools import check_json
from spikeinterface.curation import validate_curation_dict
from spikeinterface.curation.curation_model import CurationModel
from spikeinterface.widgets.utils import make_units_table_from_analyzer
Expand Down Expand Up @@ -345,6 +342,7 @@ def __init__(

if curation_data is not None:
# validate the curation data
curation_data = deepcopy(curation_data)
format_version = curation_data.get("format_version", None)
# assume version 2 if not present
if format_version is None:
Expand All @@ -354,24 +352,6 @@ def __init__(
except Exception as e:
raise ValueError(f"Invalid curation data.\nError: {e}")

if curation_data.get("merges") is None:
curation_data["merges"] = []
else:
# here we reset the merges for better formatting (str)
existing_merges = curation_data["merges"]
new_merges = []
for m in existing_merges:
if "unit_ids" not in m:
continue
if len(m["unit_ids"]) < 2:
continue
new_merges = add_merge(new_merges, m["unit_ids"])
curation_data["merges"] = new_merges
if curation_data.get("splits") is None:
curation_data["splits"] = []
if curation_data.get("removed") is None:
curation_data["removed"] = []

elif self.analyzer.format == "binary_folder":
json_file = self.analyzer.folder / "spikeinterface_gui" / "curation_data.json"
if json_file.exists():
Expand All @@ -386,24 +366,27 @@ def __init__(

if curation_data is None:
curation_data = deepcopy(empty_curation_data)
curation_data["unit_ids"] = self.unit_ids.tolist()

self.curation_data = curation_data

self.has_default_quality_labels = False
if "label_definitions" not in self.curation_data:
if "label_definitions" not in curation_data:
if label_definitions is not None:
self.curation_data["label_definitions"] = label_definitions
curation_data["label_definitions"] = label_definitions
else:
self.curation_data["label_definitions"] = default_label_definitions.copy()
curation_data["label_definitions"] = default_label_definitions.copy()

if "quality" in self.curation_data["label_definitions"]:
curation_dict_quality_labels = self.curation_data["label_definitions"]["quality"]["label_options"]
# This will enable the default shortcuts if has default quality labels
self.has_default_quality_labels = False
if "quality" in curation_data["label_definitions"]:
curation_dict_quality_labels = curation_data["label_definitions"]["quality"]["label_options"]
default_quality_labels = default_label_definitions["quality"]["label_options"]
if set(curation_dict_quality_labels) == set(default_quality_labels):
if self.verbose:
print('Curation quality labels are the default ones')
self.has_default_quality_labels = True

curation_data = CurationModel(**curation_data).model_dump()
self.curation_data = curation_data

def check_is_view_possible(self, view_name):
from .viewlist import get_all_possible_views
possible_class_views = get_all_possible_views()
Expand Down Expand Up @@ -825,7 +808,24 @@ def compute_auto_merge(self, **params):
)

return merge_unit_groups, extra


def set_curation_data(self, curation_data):
print("Setting curation data")
new_curation_data = empty_curation_data.copy()
new_curation_data.update(curation_data)

if "unit_ids" not in curation_data:
print("Setting unit_ids from controller")
new_curation_data["unit_ids"] = self.unit_ids.tolist()

if "label_definitions" not in curation_data:
print("Setting default label definitions")
new_curation_data["label_definitions"] = default_label_definitions.copy()

# validate the curation data
model = CurationModel(**new_curation_data)
self.curation_data = model.model_dump()

def curation_can_be_saved(self):
return self.analyzer.format != "memory"

Expand Down
29 changes: 29 additions & 0 deletions spikeinterface_gui/curation_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,32 @@ def add_merge(previous_merges, new_merge_unit_ids):
# Ensure the uniqueness
new_merges = [{"unit_ids": list(set(gp))} for gp in new_merge_units]
return new_merges


# def cast_unit_dtypes_in_curation(curation_data, unit_ids_dtype):
# """Cast unit ids in curation data to the correct dtype."""
# if "unit_ids" in curation_data:
# curation_data["unit_ids"] = [unit_ids_dtype(uid) for uid in curation_data["unit_ids"]]

# if "merges" in curation_data:
# for merge in curation_data["merges"]:
# merge["unit_ids"] = [unit_ids_dtype(uid) for uid in merge["unit_ids"]]
# new_unit_id = merge.get("new_unit_id", None)
# if new_unit_id is not None:
# merge["new_unit_id"] = unit_ids_dtype(new_unit_id)

# if "splits" in curation_data:
# for split in curation_data["splits"]:
# split["unit_id"] = unit_ids_dtype(split["unit_id"])
# new_unit_ids = split.get("new_unit_ids", None)
# if new_unit_ids is not None:
# split["new_unit_ids"] = [unit_ids_dtype(uid) for uid in new_unit_ids]

# if "removed" in curation_data:
# curation_data["removed"] = [unit_ids_dtype(uid) for uid in curation_data["removed"]]

# if "manual_labels" in curation_data:
# for label_entry in curation_data["manual_labels"]:
# label_entry["unit_id"] = unit_ids_dtype(label_entry["unit_id"])

# return curation_data
Loading