Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 79 additions & 114 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
from dask.dataframe import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from matplotlib.path import Path
from shapely.geometry import MultiPolygon, Point, Polygon
from xarray import DataArray, DataTree

Expand Down Expand Up @@ -386,51 +387,61 @@ def _bounding_box_mask_points(
axes: tuple[str, ...],
min_coordinate: list[Number] | ArrayLike,
max_coordinate: list[Number] | ArrayLike,
points_df: pd.DataFrame | None = None,
polygon_corners: np.ndarray | None = None, # shape: (n_boxes, n_corners, 2)
) -> list[ArrayLike]:
"""Compute a mask that is true for the points inside axis-aligned bounding boxes.

Parameters
----------
points
The points element to perform the query on.
axes
The axes that min_coordinate and max_coordinate refer to.
min_coordinate
PLACEHOLDER
The upper left hand corners of the bounding boxes (i.e., minimum coordinates along all dimensions).
Shape: (n_boxes, n_axes) or (n_axes,) for a single box.
{min_coordinate_docs}
max_coordinate
The lower right hand corners of the bounding boxes (i.e., the maximum coordinates along all dimensions).
Shape: (n_boxes, n_axes) or (n_axes,) for a single box.
{max_coordinate_docs}

Returns
-------
The masks for the points inside the bounding boxes.
"""
element_axes = get_axes_names(points)

min_coordinate = _parse_list_into_array(min_coordinate)
max_coordinate = _parse_list_into_array(max_coordinate)

# Ensure min_coordinate and max_coordinate are 2D arrays
min_coordinate = min_coordinate[np.newaxis, :] if min_coordinate.ndim == 1 else min_coordinate
max_coordinate = max_coordinate[np.newaxis, :] if max_coordinate.ndim == 1 else max_coordinate

if points_df is None:
points_df = points.compute()

relevant_axes = [ax for ax in axes if ax in element_axes]
point_coords = points_df[relevant_axes].values # (n_points, 2)

n_boxes = min_coordinate.shape[0]
in_bounding_box_masks = []

for box in range(n_boxes):
box_masks = []
for axis_index, axis_name in enumerate(axes):
if axis_name not in element_axes:
continue
min_value = min_coordinate[box, axis_index]
max_value = max_coordinate[box, axis_index]
box_masks.append(points[axis_name].gt(min_value).compute() & points[axis_name].lt(max_value).compute())
bounding_box_mask = np.stack(box_masks, axis=-1)
in_bounding_box_masks.append(np.all(bounding_box_mask, axis=1))
if polygon_corners is not None and len(relevant_axes) == 2:
# Exact path from the (potentially rotated/sheared) corners —
# close the polygon by appending the first corner again
corners = polygon_corners[box] # (n_corners, 2)
closed = np.vstack([corners, corners[0]])
mask = Path(closed).contains_points(point_coords)
elif len(relevant_axes) == 2:
# Axis-aligned rectangle — still faster than per-axis boolean ops
axis_indices = [list(axes).index(ax) for ax in relevant_axes]
mins = min_coordinate[box, axis_indices]
maxs = max_coordinate[box, axis_indices]
x_min, y_min = mins
x_max, y_max = maxs
box_path = Path(
[
(x_min, y_min),
(x_max, y_min),
(x_max, y_max),
(x_min, y_max),
(x_min, y_min),
]
)
mask = box_path.contains_points(point_coords)
else:
# Fallback for 1D or >2D
axis_indices = [list(axes).index(ax) for ax in relevant_axes]
box_masks = []
for i, axis_name in enumerate(relevant_axes):
col = points_df[axis_name].values
box_masks.append(
(col > min_coordinate[box, axis_indices[i]]) & (col < max_coordinate[box, axis_indices[i]])
)
mask = np.all(np.stack(box_masks, axis=-1), axis=1)

in_bounding_box_masks.append(mask)

return in_bounding_box_masks


Expand Down Expand Up @@ -630,118 +641,72 @@ def _(
max_coordinate: list[Number] | ArrayLike,
target_coordinate_system: str,
) -> DaskDataFrame | list[DaskDataFrame] | None:
from spatialdata import transform
from spatialdata.transformations import get_transformation

min_coordinate = _parse_list_into_array(min_coordinate)
max_coordinate = _parse_list_into_array(max_coordinate)

# Ensure min_coordinate and max_coordinate are 2D arrays
min_coordinate = min_coordinate[np.newaxis, :] if min_coordinate.ndim == 1 else min_coordinate
max_coordinate = max_coordinate[np.newaxis, :] if max_coordinate.ndim == 1 else max_coordinate

# for triggering validation
_ = BoundingBoxRequest(
target_coordinate_system=target_coordinate_system,
axes=axes,
min_coordinate=min_coordinate,
max_coordinate=max_coordinate,
)

# get the four corners of the bounding box (2D case), or the 8 corners of the "3D bounding box" (3D case)
(intrinsic_bounding_box_corners, intrinsic_axes) = _get_bounding_box_corners_in_intrinsic_coordinates(
(intrinsic_corners, intrinsic_axes) = _get_bounding_box_corners_in_intrinsic_coordinates(
element=points,
axes=axes,
min_coordinate=min_coordinate,
max_coordinate=max_coordinate,
target_coordinate_system=target_coordinate_system,
)
min_coordinate_intrinsic = intrinsic_bounding_box_corners.min(dim="corner")
max_coordinate_intrinsic = intrinsic_bounding_box_corners.max(dim="corner")

min_coordinate_intrinsic = min_coordinate_intrinsic.data
max_coordinate_intrinsic = max_coordinate_intrinsic.data

# get the points in the intrinsic coordinate bounding box
in_intrinsic_bounding_box = _bounding_box_mask_points(
# intrinsic_corners has shape (n_boxes, n_corners, n_axes) — extract the
# two spatial axes and pass the exact corner geometry to the mask function
axis_names = list(intrinsic_axes)
xy_indices = [axis_names.index("x"), axis_names.index("y")]
corners_np = intrinsic_corners.data
if corners_np.ndim == 2:
corners_np = corners_np[np.newaxis, ...] # add box dim → (1, n_corners, n_axes)
polygon_corners = corners_np[:, :, xy_indices] # (n_boxes, n_corners, 2)

points_pd = points.compute() # single .compute() for the whole function
masks = _bounding_box_mask_points(
points=points,
axes=intrinsic_axes,
min_coordinate=min_coordinate_intrinsic,
max_coordinate=max_coordinate_intrinsic,
min_coordinate=intrinsic_corners.data.min(axis=1), # still needed for the fallback path
max_coordinate=intrinsic_corners.data.max(axis=1),
points_df=points_pd,
polygon_corners=polygon_corners,
)

if not (len_df := len(in_intrinsic_bounding_box)) == (len_bb := len(min_coordinate)):
if len(masks) != len(min_coordinate):
raise ValueError(
f"Length of list of dataframes `{len_df}` is not equal to the number of bounding boxes axes `{len_bb}`."
f"Length of list of dataframes `{len(masks)}` is not equal to "
f"the number of bounding boxes `{len(min_coordinate)}`."
)
points_in_intrinsic_bounding_box: list[DaskDataFrame | None] = []
points_pd = points.compute()

attrs = points.attrs.copy()
for mask_np in in_intrinsic_bounding_box:
if mask_np.sum() == 0:
points_in_intrinsic_bounding_box.append(None)
else:
# TODO there is a problem when mixing dask dataframe graph with dask array graph. Need to compute for now.
# we can't compute either mask or points as when we calculate either one of them
# test_query_points_multiple_partitions will fail as the mask will be used to index each partition.
# However, if we compute and then create the dask array again we get the mixed dask graph problem.
filtered_pd = points_pd[mask_np]
points_filtered = dd.from_pandas(filtered_pd, npartitions=points.npartitions)
points_filtered.attrs.update(attrs)
points_in_intrinsic_bounding_box.append(points_filtered)
if len(points_in_intrinsic_bounding_box) == 0:
return None
old_transformations = get_transformation(points, get_all=True)
assert isinstance(old_transformations, dict)
feature_key = points.attrs.get(ATTRS_KEY, {}).get(PointsModel.FEATURE_KEY)

# assert that the number of queried points is correct
assert len(points_in_intrinsic_bounding_box) == len(min_coordinate)

# # we have to reset the index since we have subset
# # https://stackoverflow.com/questions/61395351/how-to-reset-index-on-concatenated-dataframe-in-dask
# points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.assign(idx=1)
# points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.set_index(
# points_in_intrinsic_bounding_box.idx.cumsum() - 1
# )
# points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.map_partitions(
# lambda df: df.rename(index={"idx": None})
# )
# points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.drop(columns=["idx"])

# transform the element to the query coordinate system
output: list[DaskDataFrame | None] = []
for p, min_c, max_c in zip(points_in_intrinsic_bounding_box, min_coordinate, max_coordinate, strict=True):
if p is None:
for mask_np in masks:
if mask_np.sum() == 0:
output.append(None)
else:
points_query_coordinate_system = transform(
p, to_coordinate_system=target_coordinate_system, maintain_positioning=False
filtered_pd = points_pd[mask_np]
output.append(
PointsModel.parse(
dd.from_pandas(filtered_pd, npartitions=1),
transformations=old_transformations.copy(),
feature_key=feature_key,
)
)

# get a mask for the points in the bounding box
bounding_box_mask = _bounding_box_mask_points(
points=points_query_coordinate_system,
axes=axes,
min_coordinate=min_c, # type: ignore[arg-type]
max_coordinate=max_c, # type: ignore[arg-type]
)
if len(bounding_box_mask) != 1:
raise ValueError(f"Expected a single mask, got {len(bounding_box_mask)} masks. Please report this bug.")
bounding_box_indices = np.where(bounding_box_mask[0])[0]

if len(bounding_box_indices) == 0:
output.append(None)
else:
points_df = p.compute().iloc[bounding_box_indices]
old_transformations = get_transformation(p, get_all=True)
assert isinstance(old_transformations, dict)
feature_key = p.attrs.get(ATTRS_KEY, {}).get(PointsModel.FEATURE_KEY)

output.append(
PointsModel.parse(
dd.from_pandas(points_df, npartitions=1),
transformations=old_transformations.copy(),
feature_key=feature_key,
)
)
if len(output) == 0:
return None
if len(output) == 1:
Expand Down Expand Up @@ -791,8 +756,8 @@ def _(
)
for box_corners in intrinsic_bounding_box_corners:
bounding_box_non_axes_aligned = Polygon(box_corners.data)
indices = polygons.geometry.intersects(bounding_box_non_axes_aligned)
queried = polygons[indices]
candidate_idx = polygons.sindex.query(bounding_box_non_axes_aligned, predicate="intersects")
queried = polygons.iloc[candidate_idx]
if len(queried) == 0:
queried_polygon = None
else:
Expand Down
Loading