diff --git a/doc/changes/dev/13765.newfeature.rst b/doc/changes/dev/13765.newfeature.rst new file mode 100644 index 00000000000..306b6b351a2 --- /dev/null +++ b/doc/changes/dev/13765.newfeature.rst @@ -0,0 +1 @@ +Allow per-channel color overrides in :func:`mne.viz.plot_raw` via channel name keys in the ``color`` dict, by :newcontrib:`Hansuja Budhiraja`. \ No newline at end of file diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index 47a26047768..99ba7405788 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -2006,8 +2006,13 @@ def _draw_traces(self): ) offsets = self.mne.trace_offsets[offset_ixs] bad_bool = np.isin(ch_names, self.mne.info["bads"]) - # colors - good_ch_colors = [self.mne.ch_color_dict[_type] for _type in ch_types] + # colors: allow overrides by channel name, then by channel type + good_ch_colors = [] + for _name, _type in zip(ch_names, ch_types): + if _name in self.mne.ch_color_dict: + good_ch_colors.append(self.mne.ch_color_dict[_name]) + else: + good_ch_colors.append(self.mne.ch_color_dict[_type]) ch_colors = to_rgba_array( [ self.mne.ch_color_bad if _bad else _color diff --git a/mne/viz/raw.py b/mne/viz/raw.py index eecad33ad55..317b8bdc87f 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -96,6 +96,10 @@ def plot_raw( emg='k', ref_meg='steelblue', misc='k', stim='k', resp='k', chpi='k') + If a dict, keys can be channel *types* (e.g., ``'eeg'``) and/or + channel *names* (e.g., ``'SFG, Left'``); name-based entries + take precedence over type-based ones. + bad_color : color object Color to make bad channels. %(event_color)s diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index 8554f12a82b..53f7d82e214 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -7,6 +7,7 @@ from copy import deepcopy from pathlib import Path +import matplotlib.colors as mcolors import matplotlib.pyplot as plt import numpy as np import pytest @@ -317,6 +318,31 @@ def test_scale_bar(browser_backend): bar_lims = bar.get_ydata() assert_allclose(y_lims, bar_lims, atol=1e-4) + # Per-channel color overrides via channel names (matplotlib only). + if ismpl: + sfreq = 100.0 + ch_names = ["SFG, Left", "SFG, Right", "MFG, Left"] + info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg") + data = np.zeros((len(ch_names), int(sfreq))) # 1 second of zeros + raw2 = RawArray(data, info) + + color = {"eeg": "k", "SFG, Left": "red"} + browser_backend._close_all() + fig2 = plot_raw(raw2, color=color, show=False) + + # ch_colors stores the "good" (non-bad) colors, in visible channel order + assert fig2.mne.ch_colors[0] == "red" + assert fig2.mne.ch_colors[1] == "k" + assert fig2.mne.ch_colors[2] == "k" + + # check colours on the plot are also correct + for trace, ch_color in zip(fig2.mne.traces, fig2.mne.ch_colors): + assert np.allclose( + mcolors.to_rgba(trace.get_color()), mcolors.to_rgba(ch_color) + ), f"Expected {ch_color}, got {trace.get_color()}" + + browser_backend._close_all() + def test_plot_raw_selection(raw, browser_backend): """Test selection mode of plot_raw()."""