Skip to content

Commit 55eab64

Browse files
authored
Add interactive annotation ability via .pl.annotate() (#684)
1 parent 1d3e78d commit 55eab64

9 files changed

Lines changed: 746 additions & 21 deletions

File tree

.gitignore

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,11 @@ _version.py
4444

4545
# pixi
4646
pixi.lock
47+
48+
# Local sandbox data + historical prototype notebooks (Sandbox.ipynb is the
49+
# active one and is tracked; the others are kept locally for reference only).
50+
/sandbox_data/
51+
/Sandbox.anywidget-v0.ipynb
52+
/Sandbox.ipympl-v0.ipynb
53+
/verify_ssh_annotate.ipynb
54+
Sandbox.ipynb

pyproject.toml

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ dependencies = [
3131
"scikit-learn",
3232
"spatialdata>=0.3",
3333
]
34+
optional-dependencies.interactive = [
35+
"anybioimage>=0.3,<0.4",
36+
"anywidget",
37+
"ipykernel",
38+
"ipywidgets",
39+
]
3440
urls.Documentation = "https://spatialdata.scverse.org/projects/plot/en/latest/index.html"
3541
urls.Home-page = "https://github.com/scverse/spatialdata-plot.git"
3642
urls.Source = "https://github.com/scverse/spatialdata-plot.git"
@@ -61,7 +67,6 @@ doc = [
6167
"sphinxcontrib-katex",
6268
"sphinxext-opengraph",
6369
]
64-
6570
[tool.hatch]
6671
build.hooks.vcs.version-file = "_version.py"
6772
build.targets.wheel.packages = [ "src/spatialdata_plot" ]
@@ -86,29 +91,49 @@ envs.hatch-test.scripts.cov-report = [ "coverage report", "coverage xml -o cover
8691
metadata.allow-direct-references = true
8792
version.source = "vcs"
8893

89-
[tool.pixi]
90-
workspace.channels = [ "conda-forge" ]
91-
workspace.platforms = [ "linux-64", "osx-arm64" ]
92-
dependencies.python = ">=3.11"
93-
pypi-dependencies.spatialdata-plot = { path = ".", editable = true }
94-
tasks.format = "ruff format ."
95-
tasks.kernel-install = 'python -m ipykernel install --user --name pixi-dev --display-name "sdata-plot (dev)"'
96-
tasks.lab = "jupyter lab"
97-
tasks.lint = "ruff check ."
98-
tasks.pre-commit-install = "pre-commit install"
99-
tasks.pre-commit-run = "pre-commit run --all-files"
100-
tasks.test = "pytest -v --color=yes --tb=short --durations=10"
94+
[tool.pixi.workspace]
95+
channels = [ "conda-forge" ]
96+
platforms = [ "linux-64", "osx-arm64" ]
97+
98+
[tool.pixi.dependencies]
99+
python = ">=3.11"
100+
101+
[tool.pixi.pypi-dependencies]
102+
spatialdata-plot = { path = ".", editable = true }
103+
104+
# When the `interactive` feature is active, install the package with the
105+
# `interactive` PyPI extra (anywidget, ipykernel, ipywidgets) so the pixi
106+
# env mirrors what `pip install spatialdata-plot[interactive]` would give.
107+
[tool.pixi.feature.interactive.pypi-dependencies]
108+
spatialdata-plot = { path = ".", editable = true, extras = [ "interactive" ] }
109+
110+
[tool.pixi.tasks]
111+
format = "ruff format ."
112+
kernel-install = 'python -m ipykernel install --user --name pixi-dev --display-name "sdata-plot (dev)"'
113+
kernel-install-interactive = 'python -m ipykernel install --user --name sdata-plot-interactive --display-name "sdata-plot (interactive)"'
114+
lab = "jupyter lab"
115+
lint = "ruff check ."
116+
pre-commit-install = "pre-commit install"
117+
pre-commit-run = "pre-commit run --all-files"
118+
test = "pytest -v --color=yes --tb=short --durations=10"
119+
101120
# for gh-actions
102-
feature.py311.dependencies.python = "3.11.*"
103-
feature.py313.dependencies.python = "3.13.*"
121+
[tool.pixi.feature.py311.dependencies]
122+
python = "3.11.*"
123+
124+
[tool.pixi.feature.py313.dependencies]
125+
python = "3.13.*"
126+
127+
[tool.pixi.environments]
104128
# 3.13 lane
105-
environments.default = { features = [ "py313" ], solve-group = "py313" }
129+
default = { features = [ "py313" ], solve-group = "py313" }
106130
# 3.11 lane (for gh-actions)
107-
environments.dev-py311 = { features = [ "dev", "test", "py311" ], solve-group = "py311" }
108-
environments.dev-py313 = { features = [ "dev", "test", "py313" ], solve-group = "py313" }
109-
environments.docs-py311 = { features = [ "doc", "py311" ], solve-group = "py311" }
110-
environments.docs-py313 = { features = [ "doc", "py313" ], solve-group = "py313" }
111-
environments.test-py313 = { features = [ "test", "py313" ], solve-group = "py313" }
131+
dev-py311 = { features = [ "dev", "test", "py311" ], solve-group = "py311" }
132+
dev-py313 = { features = [ "dev", "test", "py313" ], solve-group = "py313" }
133+
dev-interactive-py313 = { features = [ "dev", "test", "interactive", "py313" ], solve-group = "py313" }
134+
docs-py311 = { features = [ "doc", "py311" ], solve-group = "py311" }
135+
docs-py313 = { features = [ "doc", "py313" ], solve-group = "py313" }
136+
test-py313 = { features = [ "test", "py313" ], solve-group = "py313" }
112137

113138
[tool.ruff]
114139
line-length = 120

src/spatialdata_plot/pl/basic.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,141 @@ def _copy(
170170
tables=self._sdata.tables if tables is None else tables,
171171
)
172172
sdata.plotting_tree = self._sdata.plotting_tree if hasattr(self._sdata, "plotting_tree") else OrderedDict()
173+
sdata._source_sdata = getattr(self._sdata, "_source_sdata", self._sdata)
173174

174175
return sdata
175176

177+
def annotate(
178+
self,
179+
*,
180+
coordinate_systems: str | None = None,
181+
point_radius_frac: float = 0.005,
182+
figsize: tuple[float, float] = (7, 7),
183+
dpi: int = 120,
184+
) -> Any:
185+
"""Terminal step on a render chain: drop the plot into an interactive annotator.
186+
187+
Renders the accumulated ``plotting_tree`` (so any ``render_images`` /
188+
``render_shapes`` / ``render_points`` / ``render_labels`` overlays composed
189+
upstream of this call appear in the annotation canvas), then hands the
190+
rasterised figure to a ``BioImageViewer`` widget. The user draws
191+
rectangles, polygons, and points on the canvas, types a name, and clicks
192+
*Save* — the shapes are converted from canvas-pixel space to the chosen
193+
coordinate system and stored in ``sdata.shapes[<name>]`` with an
194+
``Identity`` transformation in that CS. Points are stored as small
195+
circle polygons (radius = ``point_radius_frac`` of the rendered image's
196+
CS extent) so the resulting ``ShapesModel`` is uniform-type.
197+
198+
Single coordinate system only. If the chain spans more than one CS, or
199+
none can be inferred, raises ``ValueError``.
200+
201+
Requires the ``interactive`` extra: ``pip install 'spatialdata-plot[interactive]'``.
202+
203+
Parameters
204+
----------
205+
coordinate_systems :
206+
Coordinate system to render and resolve drawn shapes against.
207+
Drawn shapes are stored with an ``Identity`` transformation in this
208+
CS. If ``None`` and the SpatialData has exactly one CS, that one is
209+
used; otherwise this argument is required.
210+
point_radius_frac :
211+
Radius of the circle polygon used to store each point, expressed as
212+
a fraction of the rendered image's CS extent. Default 0.005 (0.5%).
213+
figsize :
214+
Matplotlib figure size used for the underlying rasterisation. The
215+
same value affects the canvas resolution alongside ``dpi``.
216+
dpi :
217+
DPI of the rasterised figure. Combined with ``figsize`` this sets
218+
the pixel resolution the annotator works in.
219+
220+
Returns
221+
-------
222+
InteractiveSession
223+
The session object, with the widget already displayed. Holding the
224+
reference keeps the underlying ``BioImageViewer`` alive across cell
225+
re-runs; usually you can ignore the return value.
226+
227+
Raises
228+
------
229+
ValueError
230+
If no single coordinate system can be resolved.
231+
ImportError
232+
If the ``interactive`` extra is not installed.
233+
234+
Examples
235+
--------
236+
>>> import spatialdata_plot # noqa: F401 registers .pl
237+
>>> (
238+
... sdata.pl
239+
... .render_images(element="he")
240+
... .pl.render_shapes(element="cells", outline_color="red")
241+
... .pl.annotate()
242+
... )
243+
>>> # ... user draws and clicks Save with name "tumor" ...
244+
>>> sdata.shapes["tumor"]
245+
"""
246+
try:
247+
from spatialdata_plot.pl.interactive._session import _InteractiveSession
248+
except ImportError as exc:
249+
raise ImportError(
250+
"sdata.pl.annotate() requires the `interactive` extra. "
251+
"Install with: pip install 'spatialdata-plot[interactive]'"
252+
) from exc
253+
254+
import io as _io
255+
256+
from PIL import Image as _Image
257+
258+
available_cs = list(self._sdata.coordinate_systems)
259+
if coordinate_systems is None:
260+
if len(available_cs) != 1:
261+
raise ValueError(
262+
"annotate() needs exactly one coordinate system. "
263+
f"SpatialData has {len(available_cs)}: {available_cs!r}. "
264+
"Pass coordinate_systems=<name> explicitly."
265+
)
266+
cs = available_cs[0]
267+
else:
268+
if isinstance(coordinate_systems, list):
269+
if len(coordinate_systems) != 1:
270+
raise ValueError(f"annotate() supports a single coordinate system; got {coordinate_systems!r}.")
271+
cs = coordinate_systems[0]
272+
else:
273+
cs = coordinate_systems
274+
if cs not in available_cs:
275+
raise ValueError(f"Unknown coordinate system {cs!r}. Available: {available_cs!r}")
276+
277+
fig = plt.figure(figsize=figsize, dpi=dpi)
278+
try:
279+
ax = fig.add_axes([0, 0, 1, 1])
280+
self.show(coordinate_systems=cs, ax=ax)
281+
xlim = ax.get_xlim()
282+
ylim = ax.get_ylim()
283+
ax.set_axis_off()
284+
# set_aspect("equal") inside show() can shrink the axes box so the
285+
# figure has blank padding around the data. Crop the saved PNG to
286+
# the axes bbox so PNG pixels map 1:1 to (xlim, ylim) and the
287+
# px→cs transform in _commit.py stays correct.
288+
fig.canvas.draw()
289+
bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
290+
buf = _io.BytesIO()
291+
fig.savefig(buf, format="png", dpi=dpi, bbox_inches=bbox, pad_inches=0)
292+
finally:
293+
plt.close(fig)
294+
rgb = np.asarray(_Image.open(buf).convert("RGB"))
295+
296+
target_sdata = getattr(self._sdata, "_source_sdata", self._sdata)
297+
session = _InteractiveSession(
298+
sdata=target_sdata,
299+
coordinate_system=cs,
300+
rgb=rgb,
301+
xlim=tuple(xlim),
302+
ylim=tuple(ylim),
303+
point_radius_frac=point_radius_frac,
304+
)
305+
session.show()
306+
return session
307+
176308
@_deprecation_alias(elements="element", version="0.3.0")
177309
def render_shapes(
178310
self,
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Interactive region selection on top of a rendered spatialdata-plot figure.
2+
3+
Use via :meth:`spatialdata_plot.pl.basic.PlotAccessor.annotate`:
4+
5+
>>> import spatialdata_plot # noqa: F401 registers .pl
6+
>>> sdata.pl.render_images(element="he").pl.annotate()
7+
"""
8+
9+
from __future__ import annotations
10+
11+
__all__: list[str] = []
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""Convert anybioimage canvas shapes into CS-coord shapely geometries."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Callable
6+
from typing import Any
7+
8+
from shapely.geometry import Point, Polygon, box
9+
10+
PxToCs = Callable[[float, float], tuple[float, float]]
11+
12+
13+
def _make_px_to_cs(xmin: float, xmax: float, y_lo: float, y_hi: float, image_w: int, image_h: int) -> PxToCs:
14+
"""Build an affine mapping (px_x, px_y) → (cs_x, cs_y).
15+
16+
The y_lo/y_hi are the sorted ylim values; image_h pixels map linearly
17+
between them. matplotlib image axes with ``origin='upper'`` return
18+
reversed ylim — sorting normalises that.
19+
"""
20+
dx = xmax - xmin
21+
dy = y_hi - y_lo
22+
23+
def px_to_cs(x_px: float, y_px: float) -> tuple[float, float]:
24+
return (xmin + (x_px / image_w) * dx, y_lo + (y_px / image_h) * dy)
25+
26+
return px_to_cs
27+
28+
29+
def _roi_to_polygon(roi: dict[str, Any], px_to_cs: PxToCs) -> Polygon | None:
30+
"""ROI dict ``{x, y, width, height}`` → axis-aligned rectangle Polygon."""
31+
try:
32+
x0, y0 = px_to_cs(float(roi["x"]), float(roi["y"]))
33+
x1, y1 = px_to_cs(float(roi["x"]) + float(roi["width"]), float(roi["y"]) + float(roi["height"]))
34+
except (KeyError, TypeError, ValueError):
35+
return None
36+
poly = box(min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1))
37+
return poly if not poly.is_empty else None
38+
39+
40+
def _polygon_to_polygon(poly: dict[str, Any], px_to_cs: PxToCs) -> Polygon | None:
41+
"""Polygon dict ``{id, points: [{x, y}, ...]}`` → shapely Polygon (≥3 verts)."""
42+
pts = poly.get("points") or []
43+
try:
44+
cs_verts = [px_to_cs(float(p["x"]), float(p["y"])) for p in pts]
45+
except (KeyError, TypeError, ValueError):
46+
return None
47+
if len(cs_verts) < 3:
48+
return None
49+
geom = Polygon(cs_verts)
50+
return geom if not geom.is_empty else None
51+
52+
53+
def _point_to_circle(pt: dict[str, Any], px_to_cs: PxToCs, radius: float) -> Polygon | None:
54+
"""Point dict ``{x, y}`` → circle Polygon of the given CS-units radius.
55+
56+
Stored as a polygon so the resulting ShapesModel is uniform-type and
57+
doesn't need a ``radius`` column.
58+
"""
59+
try:
60+
cx, cy = px_to_cs(float(pt["x"]), float(pt["y"]))
61+
except (KeyError, TypeError, ValueError):
62+
return None
63+
geom = Point(cx, cy).buffer(radius)
64+
return geom if not geom.is_empty else None
65+
66+
67+
def collect_geoms_from_viewer(
68+
viewer: Any,
69+
*,
70+
xmin: float,
71+
xmax: float,
72+
y_lo: float,
73+
y_hi: float,
74+
image_w: int,
75+
image_h: int,
76+
point_radius: float,
77+
) -> list[Polygon]:
78+
"""Read the viewer's three shape stores and convert each entry to a CS-coord Polygon.
79+
80+
Order of returned geometries: ROIs first, then polygons, then points. Invalid
81+
entries (missing keys, degenerate geometry) are silently skipped.
82+
"""
83+
px_to_cs = _make_px_to_cs(xmin, xmax, y_lo, y_hi, image_w, image_h)
84+
geoms: list[Polygon] = []
85+
for roi in viewer._rois_data or []:
86+
g = _roi_to_polygon(roi, px_to_cs)
87+
if g is not None:
88+
geoms.append(g)
89+
for poly in viewer._polygons_data or []:
90+
g = _polygon_to_polygon(poly, px_to_cs)
91+
if g is not None:
92+
geoms.append(g)
93+
for pt in viewer._points_data or []:
94+
g = _point_to_circle(pt, px_to_cs, point_radius)
95+
if g is not None:
96+
geoms.append(g)
97+
return geoms

0 commit comments

Comments
 (0)