diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 2b7401c1..3567efa7 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -950,6 +950,10 @@ def show( if not all(isinstance(t, str) for t in title): raise TypeError("All titles must be strings.") + # Track whether the caller supplied their own axes so we can skip + # plt.show() later (ax is reassigned inside the rendering loop). + user_supplied_ax = ax is not None + # get original axis extent for later comparison ax_x_min, ax_x_max = (np.inf, -np.inf) ax_y_min, ax_y_max = (np.inf, -np.inf) @@ -1273,8 +1277,11 @@ def _draw_colorbar( # Default (show=None): display in non-interactive mode (scripts), suppress in interactive # sessions. We check both sys.ps1 (standard REPL) and matplotlib.is_interactive() # (covers IPython, Jupyter, plt.ion(), and IDE consoles like PyCharm). + # When the user supplies their own axes, they manage the figure lifecycle, so we + # default to not calling plt.show(). This allows multiple .pl.show(ax=...) calls + # to accumulate content on the same axes (see #362, #71). if show is None: - show = not hasattr(sys, "ps1") and not matplotlib.is_interactive() + show = False if user_supplied_ax else (not hasattr(sys, "ps1") and not matplotlib.is_interactive()) if show: plt.show() return (fig_params.ax if fig_params.axs is None else fig_params.axs) if return_ax else None # shuts up ruff diff --git a/tests/pl/test_show.py b/tests/pl/test_show.py index 2b7e444d..a4033d2b 100644 --- a/tests/pl/test_show.py +++ b/tests/pl/test_show.py @@ -1,4 +1,7 @@ +from unittest.mock import patch + import matplotlib +import matplotlib.pyplot as plt import scanpy as sc from spatialdata import SpatialData @@ -21,3 +24,19 @@ class TestShow(PlotTester, metaclass=PlotTesterMeta): def test_plot_pad_extent_adds_padding(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_images(element="blobs_image").pl.show(pad_extent=100) + + def test_no_plt_show_when_ax_provided(self, sdata_blobs: SpatialData): + """plt.show() must not be called when the user supplies ax= (regression for #362).""" + _, ax = plt.subplots() + with patch("spatialdata_plot.pl.basic.plt.show") as mock_show: + sdata_blobs.pl.render_images(element="blobs_image").pl.show(ax=ax) + mock_show.assert_not_called() + plt.close("all") + + def test_plt_show_when_ax_provided_and_show_true(self, sdata_blobs: SpatialData): + """Explicit show=True still calls plt.show() even with ax=.""" + _, ax = plt.subplots() + with patch("spatialdata_plot.pl.basic.plt.show") as mock_show: + sdata_blobs.pl.render_images(element="blobs_image").pl.show(ax=ax, show=True) + mock_show.assert_called_once() + plt.close("all")