diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index cfd7dafa..4c3fcb0e 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -743,6 +743,9 @@ def _render_points( ) added_color_from_table = True + # Reset to sequential index so row order matches after _reparse_points round-trip (#358). + points = points.reset_index(drop=True) + n_points = len(points) points_pd_with_color = points # When we pull colors from a table, keep the raw points (with color) for later, @@ -758,7 +761,7 @@ def _render_points( if table_name is None: adata = AnnData( X=points[["x", "y"]].values, - obs=points[coords].reset_index(), + obs=points[coords], dtype=points[["x", "y"]].values.dtype, ) else: diff --git a/tests/_images/Points_sampled_points_categorical_color_datashader.png b/tests/_images/Points_sampled_points_categorical_color_datashader.png new file mode 100644 index 00000000..ad2f9ac8 Binary files /dev/null and b/tests/_images/Points_sampled_points_categorical_color_datashader.png differ diff --git a/tests/_images/Points_sampled_points_categorical_color_matplotlib.png b/tests/_images/Points_sampled_points_categorical_color_matplotlib.png new file mode 100644 index 00000000..0511ea80 Binary files /dev/null and b/tests/_images/Points_sampled_points_categorical_color_matplotlib.png differ diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index eba38fbd..107c1efb 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -606,6 +606,27 @@ def test_plot_groups_na_color_none_filters_points_datashader(self, sdata_blobs: "blobs_points", color="cat_color", groups=["a"], size=30, method="datashader" ).pl.show(ax=axs[1], title="default (filtered)") + @staticmethod + def _make_sampled_sdata() -> SpatialData: + """Points with two spatially separated clusters, shuffled via .sample() (#358).""" + rng = get_standard_RNG() + n = 100 + x = np.concatenate([rng.uniform(0, 10, n // 2), rng.uniform(90, 100, n // 2)]) + y = np.concatenate([rng.uniform(0, 10, n // 2), rng.uniform(90, 100, n // 2)]) + df = pd.DataFrame({"x": x, "y": y, "cluster": pd.Categorical(["A"] * (n // 2) + ["B"] * (n // 2))}) + sdata = SpatialData(points={"pts": PointsModel.parse(df)}) + sampled = sdata.points["pts"].compute().sample(frac=0.8, random_state=42) + sdata.points["pts"] = PointsModel.parse(sampled) + return sdata + + def test_plot_sampled_points_categorical_color_matplotlib(self): + """Regression test for #358: .sample() must not shuffle categorical colors.""" + self._make_sampled_sdata().pl.render_points("pts", color="cluster", method="matplotlib").pl.show() + + def test_plot_sampled_points_categorical_color_datashader(self): + """Regression test for #358: .sample() must not shuffle categorical colors.""" + self._make_sampled_sdata().pl.render_points("pts", color="cluster", method="datashader").pl.show() + def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData): """When no elements match the groups, the plot should render without error."""