diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index f0e7a1e8..af312595 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -375,7 +375,7 @@ def _render_shapes( # When groups are specified, filter out non-matching elements by default. # Only show non-matching elements if the user explicitly sets na_color. _na = render_params.cmap_params.na_color - if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"): + if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.is_fully_transparent()): keep, color_source_vector, color_vector = _filter_groups_transparent_na( groups, color_source_vector, color_vector ) @@ -530,6 +530,8 @@ def _render_shapes( agg, color_span = _apply_ds_norm(agg, norm) na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) + if render_params.cmap_params.na_color.is_fully_transparent(): + nan_agg = None color_key = _build_color_key( transformed_element, col_for_color, @@ -832,7 +834,7 @@ def _render_points( # When groups are specified, filter out non-matching elements by default. # Only show non-matching elements if the user explicitly sets na_color. _na = render_params.cmap_params.na_color - if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"): + if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.is_fully_transparent()): keep, color_source_vector, color_vector = _filter_groups_transparent_na( groups, color_source_vector, color_vector ) @@ -925,6 +927,8 @@ def _render_points( agg, color_span = _apply_ds_norm(agg, norm) na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) + if render_params.cmap_params.na_color.is_fully_transparent(): + nan_agg = None color_key = _build_color_key( transformed_element, col_for_color, @@ -1385,7 +1389,7 @@ def _render_labels( groups is not None and categorical and color_source_vector is not None - and (_na.default_color_set or _na.alpha == "00") + and (_na.default_color_set or _na.is_fully_transparent()) ): keep_vec = color_source_vector.isin(groups) matching_ids = instance_id[keep_vec] diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index 24ba1d4f..8f7977be 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -138,6 +138,10 @@ def alpha_is_user_defined(self) -> bool: """Get whether an alpha was set during object creation.""" return self.user_defined_alpha + def is_fully_transparent(self) -> bool: + """Check whether this color is fully transparent (alpha == 0).""" + return self.alpha == "00" + @dataclass class CmapParams: diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index 0cd7a40d..eba38fbd 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -913,3 +913,22 @@ def test_shade_categorical_cmap_used_when_no_color_key(): shaded_blue = _ds_shade_categorical(agg, None, np.array(["#0000ff"] * 100), alpha=1.0) # Different color_vector[0] values should produce different shaded output assert not np.array_equal(np.asarray(shaded_red), np.asarray(shaded_blue)) + + +def test_datashader_na_color_none_no_nan_overlay_points(sdata_blobs: SpatialData): + """NaN overlay is skipped when na_color is fully transparent (#565).""" + pts = sdata_blobs.points["blobs_points"].compute() + n = len(pts) + values = np.full(n, np.nan) + values[: n // 2] = np.random.default_rng(0).uniform(0, 100, n // 2) + pts["val"] = values + sdata_blobs.points["blobs_points"] = PointsModel.parse(pts) + + fig, ax = plt.subplots() + sdata_blobs.pl.render_points("blobs_points", color="val", na_color=None, method="datashader").pl.show(ax=ax) + + assert len(ax.get_images()) == 1, ( + f"Expected 1 image (no NaN overlay), got {len(ax.get_images())}; " + "datashader is still rendering an opaque NaN overlay despite na_color=None" + ) + plt.close(fig) diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 409d4487..87fa2949 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -1210,3 +1210,24 @@ def test_datashader_colorbar_range_matches_data(sdata_blobs: SpatialData): ) assert cbar_vmin >= data_min * 0.99 - 0.01, f"Colorbar min ({cbar_vmin:.2f}) is below data min ({data_min:.2f})" plt.close(fig) + + +@pytest.mark.parametrize( + ("na_color", "expected_images"), + [(None, 1), ("red", 2)], + ids=["transparent_skips_overlay", "opaque_renders_overlay"], +) +def test_datashader_na_color_nan_overlay(sdata_blobs: SpatialData, na_color: str | None, expected_images: int): + """NaN overlay is rendered only when na_color is opaque (#565).""" + n = len(sdata_blobs.shapes["blobs_circles"]) + values = np.full(n, np.nan) + values[: n // 2] = np.random.default_rng(0).uniform(0, 100, n // 2) + sdata_blobs.shapes["blobs_circles"]["val"] = values + + fig, ax = plt.subplots() + sdata_blobs.pl.render_shapes("blobs_circles", color="val", na_color=na_color, method="datashader").pl.show(ax=ax) + + assert len(ax.get_images()) == expected_images, ( + f"Expected {expected_images} image(s), got {len(ax.get_images())} for na_color={na_color!r}" + ) + plt.close(fig)