Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def _render_shapes(
# When groups are specified, filter out non-matching elements by default.
# Only show non-matching elements if the user explicitly sets na_color.
_na = render_params.cmap_params.na_color
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"):
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.is_fully_transparent()):
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
groups, color_source_vector, color_vector
)
Expand Down Expand Up @@ -530,6 +530,8 @@ def _render_shapes(

agg, color_span = _apply_ds_norm(agg, norm)
na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex())
if render_params.cmap_params.na_color.is_fully_transparent():
nan_agg = None
color_key = _build_color_key(
transformed_element,
col_for_color,
Expand Down Expand Up @@ -832,7 +834,7 @@ def _render_points(
# When groups are specified, filter out non-matching elements by default.
# Only show non-matching elements if the user explicitly sets na_color.
_na = render_params.cmap_params.na_color
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"):
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.is_fully_transparent()):
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
groups, color_source_vector, color_vector
)
Expand Down Expand Up @@ -925,6 +927,8 @@ def _render_points(

agg, color_span = _apply_ds_norm(agg, norm)
na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex())
if render_params.cmap_params.na_color.is_fully_transparent():
nan_agg = None
color_key = _build_color_key(
transformed_element,
col_for_color,
Expand Down Expand Up @@ -1385,7 +1389,7 @@ def _render_labels(
groups is not None
and categorical
and color_source_vector is not None
and (_na.default_color_set or _na.alpha == "00")
and (_na.default_color_set or _na.is_fully_transparent())
):
keep_vec = color_source_vector.isin(groups)
matching_ids = instance_id[keep_vec]
Expand Down
4 changes: 4 additions & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ def alpha_is_user_defined(self) -> bool:
"""Get whether an alpha was set during object creation."""
return self.user_defined_alpha

def is_fully_transparent(self) -> bool:
"""Check whether this color is fully transparent (alpha == 0)."""
return self.alpha == "00"


@dataclass
class CmapParams:
Expand Down
19 changes: 19 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,3 +913,22 @@ def test_shade_categorical_cmap_used_when_no_color_key():
shaded_blue = _ds_shade_categorical(agg, None, np.array(["#0000ff"] * 100), alpha=1.0)
# Different color_vector[0] values should produce different shaded output
assert not np.array_equal(np.asarray(shaded_red), np.asarray(shaded_blue))


def test_datashader_na_color_none_no_nan_overlay_points(sdata_blobs: SpatialData):
"""NaN overlay is skipped when na_color is fully transparent (#565)."""
pts = sdata_blobs.points["blobs_points"].compute()
n = len(pts)
values = np.full(n, np.nan)
values[: n // 2] = np.random.default_rng(0).uniform(0, 100, n // 2)
pts["val"] = values
sdata_blobs.points["blobs_points"] = PointsModel.parse(pts)

fig, ax = plt.subplots()
sdata_blobs.pl.render_points("blobs_points", color="val", na_color=None, method="datashader").pl.show(ax=ax)

assert len(ax.get_images()) == 1, (
f"Expected 1 image (no NaN overlay), got {len(ax.get_images())}; "
"datashader is still rendering an opaque NaN overlay despite na_color=None"
)
plt.close(fig)
21 changes: 21 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,3 +1210,24 @@ def test_datashader_colorbar_range_matches_data(sdata_blobs: SpatialData):
)
assert cbar_vmin >= data_min * 0.99 - 0.01, f"Colorbar min ({cbar_vmin:.2f}) is below data min ({data_min:.2f})"
plt.close(fig)


@pytest.mark.parametrize(
("na_color", "expected_images"),
[(None, 1), ("red", 2)],
ids=["transparent_skips_overlay", "opaque_renders_overlay"],
)
def test_datashader_na_color_nan_overlay(sdata_blobs: SpatialData, na_color: str | None, expected_images: int):
"""NaN overlay is rendered only when na_color is opaque (#565)."""
n = len(sdata_blobs.shapes["blobs_circles"])
values = np.full(n, np.nan)
values[: n // 2] = np.random.default_rng(0).uniform(0, 100, n // 2)
sdata_blobs.shapes["blobs_circles"]["val"] = values

fig, ax = plt.subplots()
sdata_blobs.pl.render_shapes("blobs_circles", color="val", na_color=na_color, method="datashader").pl.show(ax=ax)

assert len(ax.get_images()) == expected_images, (
f"Expected {expected_images} image(s), got {len(ax.get_images())} for na_color={na_color!r}"
)
plt.close(fig)
Loading