diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index c29e3bc0e..35cf38b36 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -11,6 +11,7 @@ import numpy as np from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame +from matplotlib.path import Path from shapely.geometry import MultiPolygon, Point, Polygon from xarray import DataArray, DataTree @@ -386,51 +387,61 @@ def _bounding_box_mask_points( axes: tuple[str, ...], min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, + points_df: pd.DataFrame | None = None, + polygon_corners: np.ndarray | None = None, # shape: (n_boxes, n_corners, 2) ) -> list[ArrayLike]: - """Compute a mask that is true for the points inside axis-aligned bounding boxes. - - Parameters - ---------- - points - The points element to perform the query on. - axes - The axes that min_coordinate and max_coordinate refer to. - min_coordinate - PLACEHOLDER - The upper left hand corners of the bounding boxes (i.e., minimum coordinates along all dimensions). - Shape: (n_boxes, n_axes) or (n_axes,) for a single box. - {min_coordinate_docs} - max_coordinate - The lower right hand corners of the bounding boxes (i.e., the maximum coordinates along all dimensions). - Shape: (n_boxes, n_axes) or (n_axes,) for a single box. - {max_coordinate_docs} - - Returns - ------- - The masks for the points inside the bounding boxes. - """ element_axes = get_axes_names(points) - min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) - - # Ensure min_coordinate and max_coordinate are 2D arrays min_coordinate = min_coordinate[np.newaxis, :] if min_coordinate.ndim == 1 else min_coordinate max_coordinate = max_coordinate[np.newaxis, :] if max_coordinate.ndim == 1 else max_coordinate + if points_df is None: + points_df = points.compute() + + relevant_axes = [ax for ax in axes if ax in element_axes] + point_coords = points_df[relevant_axes].values # (n_points, 2) + n_boxes = min_coordinate.shape[0] in_bounding_box_masks = [] for box in range(n_boxes): - box_masks = [] - for axis_index, axis_name in enumerate(axes): - if axis_name not in element_axes: - continue - min_value = min_coordinate[box, axis_index] - max_value = max_coordinate[box, axis_index] - box_masks.append(points[axis_name].gt(min_value).compute() & points[axis_name].lt(max_value).compute()) - bounding_box_mask = np.stack(box_masks, axis=-1) - in_bounding_box_masks.append(np.all(bounding_box_mask, axis=1)) + if polygon_corners is not None and len(relevant_axes) == 2: + # Exact path from the (potentially rotated/sheared) corners — + # close the polygon by appending the first corner again + corners = polygon_corners[box] # (n_corners, 2) + closed = np.vstack([corners, corners[0]]) + mask = Path(closed).contains_points(point_coords) + elif len(relevant_axes) == 2: + # Axis-aligned rectangle — still faster than per-axis boolean ops + axis_indices = [list(axes).index(ax) for ax in relevant_axes] + mins = min_coordinate[box, axis_indices] + maxs = max_coordinate[box, axis_indices] + x_min, y_min = mins + x_max, y_max = maxs + box_path = Path( + [ + (x_min, y_min), + (x_max, y_min), + (x_max, y_max), + (x_min, y_max), + (x_min, y_min), + ] + ) + mask = box_path.contains_points(point_coords) + else: + # Fallback for 1D or >2D + axis_indices = [list(axes).index(ax) for ax in relevant_axes] + box_masks = [] + for i, axis_name in enumerate(relevant_axes): + col = points_df[axis_name].values + box_masks.append( + (col > min_coordinate[box, axis_indices[i]]) & (col < max_coordinate[box, axis_indices[i]]) + ) + mask = np.all(np.stack(box_masks, axis=-1), axis=1) + + in_bounding_box_masks.append(mask) + return in_bounding_box_masks @@ -630,17 +641,13 @@ def _( max_coordinate: list[Number] | ArrayLike, target_coordinate_system: str, ) -> DaskDataFrame | list[DaskDataFrame] | None: - from spatialdata import transform from spatialdata.transformations import get_transformation min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) - - # Ensure min_coordinate and max_coordinate are 2D arrays min_coordinate = min_coordinate[np.newaxis, :] if min_coordinate.ndim == 1 else min_coordinate max_coordinate = max_coordinate[np.newaxis, :] if max_coordinate.ndim == 1 else max_coordinate - # for triggering validation _ = BoundingBoxRequest( target_coordinate_system=target_coordinate_system, axes=axes, @@ -648,100 +655,58 @@ def _( max_coordinate=max_coordinate, ) - # get the four corners of the bounding box (2D case), or the 8 corners of the "3D bounding box" (3D case) - (intrinsic_bounding_box_corners, intrinsic_axes) = _get_bounding_box_corners_in_intrinsic_coordinates( + (intrinsic_corners, intrinsic_axes) = _get_bounding_box_corners_in_intrinsic_coordinates( element=points, axes=axes, min_coordinate=min_coordinate, max_coordinate=max_coordinate, target_coordinate_system=target_coordinate_system, ) - min_coordinate_intrinsic = intrinsic_bounding_box_corners.min(dim="corner") - max_coordinate_intrinsic = intrinsic_bounding_box_corners.max(dim="corner") - - min_coordinate_intrinsic = min_coordinate_intrinsic.data - max_coordinate_intrinsic = max_coordinate_intrinsic.data - # get the points in the intrinsic coordinate bounding box - in_intrinsic_bounding_box = _bounding_box_mask_points( + # intrinsic_corners has shape (n_boxes, n_corners, n_axes) — extract the + # two spatial axes and pass the exact corner geometry to the mask function + axis_names = list(intrinsic_axes) + xy_indices = [axis_names.index("x"), axis_names.index("y")] + corners_np = intrinsic_corners.data + if corners_np.ndim == 2: + corners_np = corners_np[np.newaxis, ...] # add box dim → (1, n_corners, n_axes) + polygon_corners = corners_np[:, :, xy_indices] # (n_boxes, n_corners, 2) + + points_pd = points.compute() # single .compute() for the whole function + masks = _bounding_box_mask_points( points=points, axes=intrinsic_axes, - min_coordinate=min_coordinate_intrinsic, - max_coordinate=max_coordinate_intrinsic, + min_coordinate=intrinsic_corners.data.min(axis=1), # still needed for the fallback path + max_coordinate=intrinsic_corners.data.max(axis=1), + points_df=points_pd, + polygon_corners=polygon_corners, ) - if not (len_df := len(in_intrinsic_bounding_box)) == (len_bb := len(min_coordinate)): + if len(masks) != len(min_coordinate): raise ValueError( - f"Length of list of dataframes `{len_df}` is not equal to the number of bounding boxes axes `{len_bb}`." + f"Length of list of dataframes `{len(masks)}` is not equal to " + f"the number of bounding boxes `{len(min_coordinate)}`." ) - points_in_intrinsic_bounding_box: list[DaskDataFrame | None] = [] - points_pd = points.compute() + attrs = points.attrs.copy() - for mask_np in in_intrinsic_bounding_box: - if mask_np.sum() == 0: - points_in_intrinsic_bounding_box.append(None) - else: - # TODO there is a problem when mixing dask dataframe graph with dask array graph. Need to compute for now. - # we can't compute either mask or points as when we calculate either one of them - # test_query_points_multiple_partitions will fail as the mask will be used to index each partition. - # However, if we compute and then create the dask array again we get the mixed dask graph problem. - filtered_pd = points_pd[mask_np] - points_filtered = dd.from_pandas(filtered_pd, npartitions=points.npartitions) - points_filtered.attrs.update(attrs) - points_in_intrinsic_bounding_box.append(points_filtered) - if len(points_in_intrinsic_bounding_box) == 0: - return None + old_transformations = get_transformation(points, get_all=True) + assert isinstance(old_transformations, dict) + feature_key = points.attrs.get(ATTRS_KEY, {}).get(PointsModel.FEATURE_KEY) - # assert that the number of queried points is correct - assert len(points_in_intrinsic_bounding_box) == len(min_coordinate) - - # # we have to reset the index since we have subset - # # https://stackoverflow.com/questions/61395351/how-to-reset-index-on-concatenated-dataframe-in-dask - # points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.assign(idx=1) - # points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.set_index( - # points_in_intrinsic_bounding_box.idx.cumsum() - 1 - # ) - # points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.map_partitions( - # lambda df: df.rename(index={"idx": None}) - # ) - # points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.drop(columns=["idx"]) - - # transform the element to the query coordinate system output: list[DaskDataFrame | None] = [] - for p, min_c, max_c in zip(points_in_intrinsic_bounding_box, min_coordinate, max_coordinate, strict=True): - if p is None: + for mask_np in masks: + if mask_np.sum() == 0: output.append(None) else: - points_query_coordinate_system = transform( - p, to_coordinate_system=target_coordinate_system, maintain_positioning=False + filtered_pd = points_pd[mask_np] + output.append( + PointsModel.parse( + dd.from_pandas(filtered_pd, npartitions=1), + transformations=old_transformations.copy(), + feature_key=feature_key, + ) ) - # get a mask for the points in the bounding box - bounding_box_mask = _bounding_box_mask_points( - points=points_query_coordinate_system, - axes=axes, - min_coordinate=min_c, # type: ignore[arg-type] - max_coordinate=max_c, # type: ignore[arg-type] - ) - if len(bounding_box_mask) != 1: - raise ValueError(f"Expected a single mask, got {len(bounding_box_mask)} masks. Please report this bug.") - bounding_box_indices = np.where(bounding_box_mask[0])[0] - - if len(bounding_box_indices) == 0: - output.append(None) - else: - points_df = p.compute().iloc[bounding_box_indices] - old_transformations = get_transformation(p, get_all=True) - assert isinstance(old_transformations, dict) - feature_key = p.attrs.get(ATTRS_KEY, {}).get(PointsModel.FEATURE_KEY) - - output.append( - PointsModel.parse( - dd.from_pandas(points_df, npartitions=1), - transformations=old_transformations.copy(), - feature_key=feature_key, - ) - ) if len(output) == 0: return None if len(output) == 1: @@ -791,8 +756,8 @@ def _( ) for box_corners in intrinsic_bounding_box_corners: bounding_box_non_axes_aligned = Polygon(box_corners.data) - indices = polygons.geometry.intersects(bounding_box_non_axes_aligned) - queried = polygons[indices] + candidate_idx = polygons.sindex.query(bounding_box_non_axes_aligned, predicate="intersects") + queried = polygons.iloc[candidate_idx] if len(queried) == 0: queried_polygon = None else: