Skip to content

Commit 224f069

Browse files
authored
perf(shapes): build matplotlib patches once, share across fill/outline (#691)
1 parent 34c23b4 commit 224f069

2 files changed

Lines changed: 106 additions & 94 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
_align_outline_vector_to_length,
6161
_apply_mask_to_outline_vectors,
6262
_ax_show_and_transform,
63+
_build_shape_patches,
6364
_check_obs_var_shadow,
6465
_color_vector_to_rgba,
6566
_convert_shapes,
@@ -905,6 +906,10 @@ def _render_shapes(
905906
cax = _build_ds_colorbar(reduction_bounds, norm, render_params.cmap_params.cmap)
906907

907908
elif method == "matplotlib":
909+
# Build the matplotlib patches once and share them across the fill and outline
910+
# collections; the geometry is identical, only colours/alpha/linewidth differ.
911+
prebuilt_patches = _build_shape_patches(shapes, render_params.scale)
912+
908913
# render outlines separately to ensure they are always underneath the shape
909914
if col_for_outline_color is not None and render_params.outline_alpha[0] > 0:
910915
outline_rgba = _color_vector_to_rgba(
@@ -924,6 +929,7 @@ def _render_shapes(
924929
fill_alpha=0.0,
925930
outline_alpha=render_params.outline_alpha[0],
926931
outline_color=outline_rgba,
932+
prebuilt_patches=prebuilt_patches,
927933
linewidth=render_params.outline_params.outer_outline_linewidth,
928934
zorder=render_params.zorder,
929935
)
@@ -942,6 +948,7 @@ def _render_shapes(
942948
fill_alpha=0.0,
943949
outline_alpha=render_params.outline_alpha[0],
944950
outline_color=render_params.outline_params.outer_outline_color.get_hex(),
951+
prebuilt_patches=prebuilt_patches,
945952
linewidth=render_params.outline_params.outer_outline_linewidth,
946953
zorder=render_params.zorder,
947954
# **kwargs,
@@ -962,6 +969,7 @@ def _render_shapes(
962969
fill_alpha=0.0,
963970
outline_alpha=render_params.outline_alpha[1],
964971
outline_color=render_params.outline_params.inner_outline_color.get_hex(),
972+
prebuilt_patches=prebuilt_patches,
965973
linewidth=render_params.outline_params.inner_outline_linewidth,
966974
zorder=render_params.zorder,
967975
# **kwargs,
@@ -975,6 +983,7 @@ def _render_shapes(
975983
shapes=shapes,
976984
s=render_params.scale,
977985
c=color_vector.copy(), # copy bc c is modified in _get_collection_shape
986+
prebuilt_patches=prebuilt_patches,
978987
render_params=render_params,
979988
rasterized=sc_settings._vector_friendly,
980989
cmap=render_params.cmap_params.cmap,

src/spatialdata_plot/pl/utils.py

Lines changed: 97 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,89 @@ def _color_vector_to_rgba(
598598
return rgba
599599

600600

601+
def _normalize_geom(geom: Any) -> Any:
602+
"""Canonicalize ring orientation so matplotlib's fill rules render holes correctly.
603+
604+
``shapely.normalize`` (shapely>=2) is preferred; falls back to ``geom.normalize()``.
605+
None/empty geometries and geometries that fail to normalize are returned unchanged.
606+
"""
607+
if geom is None or getattr(geom, "is_empty", False):
608+
return geom
609+
normalize_func = getattr(shapely, "normalize", None)
610+
if callable(normalize_func):
611+
try:
612+
return normalize_func(geom)
613+
except (GEOSException, TypeError, ValueError):
614+
return geom
615+
if hasattr(geom, "normalize"):
616+
try:
617+
return geom.normalize()
618+
except (GEOSException, TypeError, ValueError):
619+
return geom
620+
return geom
621+
622+
623+
def _build_shape_patches(
624+
shapes: GeoDataFrame,
625+
scale: float,
626+
) -> tuple[list[mpatches.Patch], list[int], int]:
627+
"""Build matplotlib patches from shape geometries, once.
628+
629+
Patch geometry is independent of colour/alpha, so it can be built a single time and
630+
shared across the fill and outline ``PatchCollection``s in :func:`_render_shapes`
631+
instead of being rebuilt per layer (the dominant cost for shape elements).
632+
633+
Returns
634+
-------
635+
patches
636+
The matplotlib patches (a MultiPolygon expands to several patches).
637+
patch_row_idx
638+
For each patch, the index into the empty-filtered, re-indexed shapes — used to
639+
look up the per-shape colour.
640+
n_shapes
641+
Number of shapes after empty filtering (used for the single-colour broadcast rule).
642+
"""
643+
df: GeoDataFrame | pd.DataFrame = shapes if isinstance(shapes, GeoDataFrame) else pd.DataFrame(shapes)
644+
if "geometry" not in df.columns:
645+
return [], [], 0
646+
647+
# Normalize ring orientation, then drop empty geometries (both vectorized; fall
648+
# back to per-geometry normalization only if the bulk call rejects an input).
649+
geom_array = df["geometry"].to_numpy()
650+
try:
651+
geom_array = shapely.normalize(geom_array)
652+
except (GEOSException, TypeError, ValueError):
653+
geom_array = np.array([_normalize_geom(g) for g in geom_array], dtype=object)
654+
keep = ~shapely.is_empty(geom_array)
655+
geoms = geom_array[keep]
656+
radii = df["radius"].to_numpy()[keep] if "radius" in df.columns else None
657+
658+
# Resolve the scale scalar once instead of per shape.
659+
scale_value = _extract_scalar_value(scale, default=1.0)
660+
661+
patches: list[mpatches.Patch] = []
662+
patch_row_idx: list[int] = []
663+
for i, geom in enumerate(geoms):
664+
geom_type = geom.geom_type
665+
if geom_type == "Polygon":
666+
coords = np.asarray(geom.exterior.coords)
667+
centroid = np.mean(coords, axis=0)
668+
scaled = centroid + (coords - centroid) * scale_value
669+
patches.append(mpatches.Polygon(scaled, closed=True))
670+
patch_row_idx.append(i)
671+
elif geom_type == "MultiPolygon":
672+
for m in _make_patch_from_multipolygon(geom):
673+
_scale_pathpatch_around_centroid(m, scale_value)
674+
patches.append(m)
675+
patch_row_idx.append(i)
676+
elif geom_type == "Point":
677+
radius_value = _extract_scalar_value(radii[i], default=0.0) if radii is not None else 0.0
678+
patches.append(mpatches.Circle((geom.x, geom.y), radius=radius_value * scale_value))
679+
patch_row_idx.append(i)
680+
681+
return patches, patch_row_idx, len(geoms)
682+
683+
601684
def _get_collection_shape(
602685
shapes: list[GeoDataFrame],
603686
c: Any,
@@ -608,6 +691,7 @@ def _get_collection_shape(
608691
outline_alpha: None | float = None,
609692
outline_color: None | str | list[float] | np.ndarray = "white",
610693
linewidth: float = 0.0,
694+
prebuilt_patches: tuple[list[mpatches.Patch], list[int], int] | None = None,
611695
**kwargs: Any,
612696
) -> PatchCollection:
613697
"""
@@ -718,107 +802,26 @@ def _as_rgba_array(x: Any) -> np.ndarray:
718802
else:
719803
outline_c = [None] * fill_c.shape[0]
720804

721-
if isinstance(shapes, GeoDataFrame):
722-
shapes_df: GeoDataFrame | pd.DataFrame = shapes.copy()
723-
else:
724-
shapes_df = pd.DataFrame(shapes, copy=True)
725-
726-
# Robustly normalise geometries to a canonical representation.
727-
# This ensures consistent exterior/interior ring orientation so that
728-
# matplotlib's fill rules handle holes correctly regardless of user input.
729-
if "geometry" in shapes_df.columns:
730-
731-
def _normalize_geom(geom: Any) -> Any:
732-
if geom is None or getattr(geom, "is_empty", False):
733-
return geom
734-
# shapely.normalize is available in shapely>=2; fall back to geom.normalize()
735-
normalize_func = getattr(shapely, "normalize", None)
736-
if callable(normalize_func):
737-
try:
738-
return normalize_func(geom)
739-
except (GEOSException, TypeError, ValueError):
740-
return geom
741-
if hasattr(geom, "normalize"):
742-
try:
743-
return geom.normalize()
744-
except (GEOSException, TypeError, ValueError):
745-
return geom
746-
return geom
747-
748-
shapes_df["geometry"] = shapes_df["geometry"].apply(_normalize_geom)
749-
750-
shapes_df = shapes_df[shapes_df["geometry"].apply(lambda geom: not geom.is_empty)]
751-
shapes_df = shapes_df.reset_index(drop=True)
752-
753-
def _assign_fill_and_outline_to_row(
754-
fill_colors: list[Any],
755-
outline_colors: list[Any],
756-
row: dict[str, Any],
757-
idx: int,
758-
is_multiple_shapes: bool,
759-
) -> None:
760-
if is_multiple_shapes and len(fill_colors) == 1:
761-
row["fill_c"] = fill_colors[0]
762-
row["outline_c"] = outline_colors[0]
763-
else:
764-
row["fill_c"] = fill_colors[idx]
765-
row["outline_c"] = outline_colors[idx]
766-
767-
def _process_polygon(row: pd.Series, scale: float) -> dict[str, Any]:
768-
coords = np.array(row["geometry"].exterior.coords)
769-
centroid = np.mean(coords, axis=0)
770-
scale_value = _extract_scalar_value(scale, default=1.0)
771-
scaled = (centroid + (coords - centroid) * scale_value).tolist()
772-
return {**row.to_dict(), "geometry": mpatches.Polygon(scaled, closed=True)}
773-
774-
def _process_multipolygon(row: pd.Series, scale: float) -> list[dict[str, Any]]:
775-
mp = _make_patch_from_multipolygon(row["geometry"])
776-
row_dict = row.to_dict()
777-
for m in mp:
778-
_scale_pathpatch_around_centroid(m, scale)
779-
return [{**row_dict, "geometry": m} for m in mp]
780-
781-
def _process_point(row: pd.Series, scale: float) -> dict[str, Any]:
782-
radius_value = _extract_scalar_value(row["radius"], default=0.0)
783-
scale_value = _extract_scalar_value(scale, default=1.0)
784-
radius = radius_value * scale_value
785-
786-
return {
787-
**row.to_dict(),
788-
"geometry": mpatches.Circle((row["geometry"].x, row["geometry"].y), radius=radius),
789-
}
790-
791-
def _create_patches(
792-
shapes_df_: GeoDataFrame, fill_colors: list[Any], outline_colors: list[Any], scale: float
793-
) -> pd.DataFrame:
794-
rows: list[dict[str, Any]] = []
795-
is_multiple = len(shapes_df_) > 1
796-
for idx, row in shapes_df_.iterrows():
797-
geom_type = row["geometry"].geom_type
798-
processed: list[dict[str, Any]] = []
799-
if geom_type == "Polygon":
800-
processed.append(_process_polygon(row, scale))
801-
elif geom_type == "MultiPolygon":
802-
processed.extend(_process_multipolygon(row, scale))
803-
elif geom_type == "Point":
804-
processed.append(_process_point(row, scale))
805-
for pr in processed:
806-
_assign_fill_and_outline_to_row(fill_colors, outline_colors, pr, idx, is_multiple)
807-
rows.append(pr)
808-
return pd.DataFrame(rows)
809-
810-
patches = _create_patches(
811-
shapes_df, fill_c.tolist(), outline_c.tolist() if hasattr(outline_c, "tolist") else outline_c, s
805+
# Build (or reuse) the matplotlib patches. Geometry is colour-independent, so the
806+
# caller can build it once via `_build_shape_patches` and share it across the fill
807+
# and outline collections instead of rebuilding it on every call.
808+
patches, patch_row_idx, n_shapes = (
809+
prebuilt_patches if prebuilt_patches is not None else _build_shape_patches(shapes, s)
812810
)
813811

814-
if patches.empty:
812+
if not patches:
815813
return PatchCollection([])
816814

815+
# Expand the per-shape fill colours to per-patch (a MultiPolygon owns several
816+
# patches). Preserve the single-colour broadcast used for multi-shape elements.
817+
broadcast_single = n_shapes > 1 and len(fill_c) == 1
818+
patch_fill = np.repeat(fill_c, len(patches), axis=0) if broadcast_single else fill_c[patch_row_idx]
819+
817820
return PatchCollection(
818-
patches["geometry"].values.tolist(),
821+
patches,
819822
snap=False,
820823
lw=linewidth,
821-
facecolor=patches["fill_c"],
824+
facecolor=patch_fill,
822825
edgecolor=None if all(o is None for o in outline_c) else outline_c,
823826
**kwargs,
824827
)

0 commit comments

Comments
 (0)