From 21ef7845885973d23bbb01572b7360b20922ddd3 Mon Sep 17 00:00:00 2001 From: LouiseDck Date: Tue, 18 Mar 2025 10:24:29 +0100 Subject: [PATCH 1/4] Start scatter --- src/anndata_plot/pl/__init__.py | 4 ++++ src/anndata_plot/pl/_utils.py | 24 ++++++++++++++++++++++++ src/anndata_plot/pl/scatter.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+) create mode 100644 src/anndata_plot/pl/_utils.py create mode 100644 src/anndata_plot/pl/scatter.py diff --git a/src/anndata_plot/pl/__init__.py b/src/anndata_plot/pl/__init__.py index c2315dd..1686783 100644 --- a/src/anndata_plot/pl/__init__.py +++ b/src/anndata_plot/pl/__init__.py @@ -1 +1,5 @@ from .basic import BasicClass, basic_plot + +from .scatter import scatter + +__all__ = ["scatter"] diff --git a/src/anndata_plot/pl/_utils.py b/src/anndata_plot/pl/_utils.py new file mode 100644 index 0000000..73a45a7 --- /dev/null +++ b/src/anndata_plot/pl/_utils.py @@ -0,0 +1,24 @@ +from typing import Literal + +_FontWeight = Literal["light", "normal", "medium", "semibold", "bold", "heavy", "black"] +_FontSize = Literal[ + "xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large" +] +_LegendLoc = Literal[ + "none", + "right margin", + "on data", + "on data export", + "best", + "upper right", + "upper left", + "lower left", + "lower right", + "right", + "center left", + "center right", + "lower center", + "upper center", + "center", +] +ColorLike = str | tuple[float, ...] diff --git a/src/anndata_plot/pl/scatter.py b/src/anndata_plot/pl/scatter.py new file mode 100644 index 0000000..a3d4feb --- /dev/null +++ b/src/anndata_plot/pl/scatter.py @@ -0,0 +1,32 @@ +from ._utils import ColorLike +from typing import Literal, Sequence +import numpy as np + +import holoviews as hv + +def scatter( + adata, #Y: np.ndarray, + Y: Sequence[str], + *, + colors: str | Sequence[ColorLike | np.ndarray] = "blue", # should probably be a colormapping? + sort_order=True, + alpha=None, + highlights=(), + right_margin=None, + left_margin=None, + projection: Literal["2d", "3d"] = "2d", + title=None, + component_name="DC", + component_indexnames=(1, 2, 3), + axis_labels=None, + colorbars=(False,), + sizes=(1,), + markers=".", + color_map="viridis", + show_ticks=True, + ax=None, +): + print("HV scatter") + hv.Points(adata, Y) + return None + From 3ea688a24c448b72ecf04b3e1062495606995919 Mon Sep 17 00:00:00 2001 From: LouiseDck Date: Tue, 18 Mar 2025 15:49:00 +0100 Subject: [PATCH 2/4] PoC scatter plot & render utils --- src/anndata_plot/pl/render_utils.py | 15 ++++++++++++++ src/anndata_plot/pl/scatter.py | 31 +++++++++++++++++++++++++---- 2 files changed, 42 insertions(+), 4 deletions(-) create mode 100644 src/anndata_plot/pl/render_utils.py diff --git a/src/anndata_plot/pl/render_utils.py b/src/anndata_plot/pl/render_utils.py new file mode 100644 index 0000000..f66f459 --- /dev/null +++ b/src/anndata_plot/pl/render_utils.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass + +@dataclass +class LegendOpts: + legend_position: str = "right" + +@dataclass +class ColorOpts: + color: str = "blue" + cmap: str = "viridis" + +@dataclass +class SizeOpts: + height: int = 400 + width: int = 600 diff --git a/src/anndata_plot/pl/scatter.py b/src/anndata_plot/pl/scatter.py index a3d4feb..748402b 100644 --- a/src/anndata_plot/pl/scatter.py +++ b/src/anndata_plot/pl/scatter.py @@ -2,13 +2,19 @@ from typing import Literal, Sequence import numpy as np +from .render_utils import ColorOpts +from .render_utils import LegendOpts +from .render_utils import SizeOpts + +from dataclasses import asdict + import holoviews as hv def scatter( adata, #Y: np.ndarray, Y: Sequence[str], *, - colors: str | Sequence[ColorLike | np.ndarray] = "blue", # should probably be a colormapping? + colors: str | Sequence[ColorLike | np.ndarray] = "blue", #should probably be a colormapping? sort_order=True, alpha=None, highlights=(), @@ -26,7 +32,24 @@ def scatter( show_ticks=True, ax=None, ): - print("HV scatter") - hv.Points(adata, Y) - return None + # fig = hv.render(hv.Points(adata, Y), backend="matplotlib") + # points = hv.Points(adata, Y, colors).opts(color=colors) + # points.opts(color = colors) + + # if colors is a column in obs --> hv.Points(adata, Y, colors).opts(color=colors) + # if colors is just one color --> hv.Points(adata, Y).opts(color=colors) --> but is this needed? + # if colors is a var_name --> hv.Points(adata, Y, colors).opts(color=colors) + + if color_map is None: + color_map = "viridis" + + legend_opts = LegendOpts() + color_opts = ColorOpts(color = colors, cmap = color_map) + size_opts = SizeOpts() + + # merge opts dicts + opts = {**asdict(legend_opts), + **asdict(color_opts), + **asdict(size_opts)} + return (hv.Points(adata, Y, colors).opts(**opts)) From 315960faa7552e85d16f0ff9db7665df35c2d448 Mon Sep 17 00:00:00 2001 From: LouiseDck Date: Tue, 18 Mar 2025 23:24:12 +0100 Subject: [PATCH 3/4] update scatterplot --- src/anndata_plot/pl/__init__.py | 13 +- src/anndata_plot/pl/render_utils.py | 32 ++++- src/anndata_plot/pl/scatter.py | 193 ++++++++++++++++++++++------ 3 files changed, 199 insertions(+), 39 deletions(-) diff --git a/src/anndata_plot/pl/__init__.py b/src/anndata_plot/pl/__init__.py index 1686783..8835536 100644 --- a/src/anndata_plot/pl/__init__.py +++ b/src/anndata_plot/pl/__init__.py @@ -2,4 +2,15 @@ from .scatter import scatter -__all__ = ["scatter"] +from .render_utils import ColorOpts +from .render_utils import LegendOpts +from .render_utils import SizeOpts +from .render_utils import AxisOpts + +__all__ = [ + "scatter", + "ColorOpts", + "LegendOpts", + "SizeOpts", + "AxisOpts", +] diff --git a/src/anndata_plot/pl/render_utils.py b/src/anndata_plot/pl/render_utils.py index f66f459..577c2fe 100644 --- a/src/anndata_plot/pl/render_utils.py +++ b/src/anndata_plot/pl/render_utils.py @@ -1,5 +1,29 @@ from dataclasses import dataclass +from typing import Literal + +_LegendLoc = Literal[ + "none", + "right margin", + "on data", + "on data export", + "best", + "upper right", + "upper left", + "lower left", + "lower right", + "right", + "center left", + "center right", + "lower center", + "upper center", + "center", +] +_FontWeight = Literal["light", "normal", "medium", "semibold", "bold", "heavy", "black"] +_FontSize = Literal[ + "xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large" +] + @dataclass class LegendOpts: legend_position: str = "right" @@ -11,5 +35,9 @@ class ColorOpts: @dataclass class SizeOpts: - height: int = 400 - width: int = 600 + size: int = 100 + +@dataclass +class AxisOpts: + xlabel: str = "" + ylabel: str = "" diff --git a/src/anndata_plot/pl/scatter.py b/src/anndata_plot/pl/scatter.py index 748402b..c13ca65 100644 --- a/src/anndata_plot/pl/scatter.py +++ b/src/anndata_plot/pl/scatter.py @@ -1,55 +1,176 @@ from ._utils import ColorLike from typing import Literal, Sequence import numpy as np +import pandas as pd + +from typing import Collection +from numpy.typing import NDArray from .render_utils import ColorOpts from .render_utils import LegendOpts from .render_utils import SizeOpts +from .render_utils import AxisOpts +from .render_utils import _LegendLoc +from .render_utils import _FontWeight +from .render_utils import _FontSize from dataclasses import asdict +from matplotlib.colors import is_color_like + +from anndata import AnnData + import holoviews as hv +# def scatter( +# adata, #Y: np.ndarray, +# Y: Sequence[str], +# title: str | None = None, +# color_by: str | None = None, +# legend_opts: LegendOpts = LegendOpts(), +# color_opts: ColorOpts = ColorOpts(), +# size_opts: SizeOpts = SizeOpts(), +# aixs_opts: AxisOpts = AxisOpts(), +# backend_opts: dict = None, + +# # *, +# # colors: str | Sequence[ColorLike | np.ndarray] = "blue", #should probably be a colormapping? +# # sort_order=True, +# # alpha=None, +# # highlights=(), +# # right_margin=None, +# # left_margin=None, +# # projection: Literal["2d", "3d"] = "2d", +# # title=None, +# # component_name="DC", +# # component_indexnames=(1, 2, 3), +# # axis_labels=None, +# # colorbars=(False,), +# # sizes=(1,), +# # markers=".", +# # color_map="viridis", +# # show_ticks=True, +# # ax=None, +# ): +# # fig = hv.render(hv.Points(adata, Y), backend="matplotlib") +# # points = hv.Points(adata, Y, colors).opts(color=colors) +# # points.opts(color = colors) + +# # if colors is a column in obs --> hv.Points(adata, Y, colors).opts(color=colors) +# # if colors is just one color --> hv.Points(adata, Y).opts(color=colors) --> but is this needed? +# # if colors is a var_name --> hv.Points(adata, Y, colors).opts(color=colors) + +# if title is None: +# title = f"Scatter plot of {Y[0]} and {Y[1]}" + +# # merge opts dicts +# opts = {**asdict(legend_opts), +# **asdict(color_opts), +# **asdict(size_opts), +# "title": title, +# } + +# if color_by is None: +# return (hv.Points(adata, Y).opts(**opts)) + +# return (hv.Points(adata, Y, color_by).opts(**opts)) + +# copied from scanpy, pl._anndata line +- 235 +def _check_if_annotations( + adata: AnnData, + axis_name: Literal["obs", "var"], + x: str | None = None, + y: str | None = None, + color: Collection[str | ColorLike] | None = None, + use_raw: bool | None = None, +) -> bool: + """Check if `x`, `y`, and `colors` are annotations of `adata`. + + If `axis_name` is `obs`, checks in `adata.obs.columns` and `adata.var_names`, + if `axis_name` is `var`, checks in `adata.var.columns` and `adata.obs_names`. + """ + annotations: pd.Index[str] = getattr(adata, axis_name).columns + other_ax_obj = ( + adata.raw if use_raw and axis_name == "obs" else adata + ) + names: pd.Index[str] = getattr( + other_ax_obj, "var" if axis_name == "obs" else "obs" + ).index + + def is_annotation(needle: pd.Index) -> NDArray[np.bool_]: + return needle.isin({None}) | needle.isin(annotations) | needle.isin(names) + + if not is_annotation(pd.Index([x, y])).all(): + return False + + return bool(is_annotation(pd.Index([color])).all()) + def scatter( - adata, #Y: np.ndarray, - Y: Sequence[str], - *, - colors: str | Sequence[ColorLike | np.ndarray] = "blue", #should probably be a colormapping? - sort_order=True, - alpha=None, - highlights=(), - right_margin=None, - left_margin=None, - projection: Literal["2d", "3d"] = "2d", - title=None, - component_name="DC", - component_indexnames=(1, 2, 3), - axis_labels=None, - colorbars=(False,), - sizes=(1,), - markers=".", - color_map="viridis", - show_ticks=True, - ax=None, + adata: AnnData, + x: str | None = None, + y: str | None = None, + basis: str | None = None, + color_by: str | None = None, + title: str | None = None, + color_opts: ColorOpts | dict | None = None, + cmap: str | None = None, + palette: Sequence[ColorLike] | None = None, + legend_opts: LegendOpts | dict | None = None, + legend_loc: _LegendLoc | None = "right margin", ): - # fig = hv.render(hv.Points(adata, Y), backend="matplotlib") - # points = hv.Points(adata, Y, colors).opts(color=colors) - # points.opts(color = colors) - # if colors is a column in obs --> hv.Points(adata, Y, colors).opts(color=colors) - # if colors is just one color --> hv.Points(adata, Y).opts(color=colors) --> but is this needed? - # if colors is a var_name --> hv.Points(adata, Y, colors).opts(color=colors) + # determine which dims to use + if basis is not None: + kdims = [f"obsm.X_{basis}.0", f"obsm.X_{basis}.1"] + vdims = [] + if color_by is not None and color_by in adata.obs.columns: + vdims = [f"obs.{color_by}"] if color_by is not None else [] + elif color_by is not None and color_by in adata.var_names: + vdims = [color_by] + elif _check_if_annotations(adata, "obs", x=x, y=y, color = color_by): + kdims = [f"obs.{x}", f"obs.{y}"] + vdims = [f"obs.{color_by}"] if color_by is not None else [] + elif _check_if_annotations(adata, "var", x=x, y=y, color = color_by): + kdims = [f"var.{x}", f"var.{y}"] + vdims = [f"var.{color_by}"] if color_by is not None else [] + else: + msg = ( + "`x`, `y`, and potential `color` inputs must all " + "come from either `.obs` or `.var`" + ) + raise ValueError(msg) + + if title is None and color_by is not None: + title = color_by.replace("_", " ") + + title_opts = {"title": title} + + # check if color_opts is a dict + if isinstance(color_opts, dict): + color_opts = get_color_opts(kdims, vdims, color_by, **color_opts) + elif color_opts is None: + color_opts = get_color_opts(kdims, vdims, color_by, cmap, palette) + + all_opts = { + **title_opts, + **asdict(color_opts), + } + + print(locals()) - if color_map is None: - color_map = "viridis" + print(kdims) + print(vdims) + print(all_opts) - legend_opts = LegendOpts() - color_opts = ColorOpts(color = colors, cmap = color_map) - size_opts = SizeOpts() + return hv.Points(adata, kdims, vdims).opts(**all_opts) - # merge opts dicts - opts = {**asdict(legend_opts), - **asdict(color_opts), - **asdict(size_opts)} +def get_color_opts(kdims, vdims, color_by = None, cmap = None, palette = None): + args = {} - return (hv.Points(adata, Y, colors).opts(**opts)) + if len(vdims) != 0: + args["color"] = vdims[0] + if palette is not None: + args["cmap"] = palette + if cmap is not None: + args["cmap"] = cmap + return ColorOpts(**args) From baefef02a71064e3c157b35dd6918362c78b2669 Mon Sep 17 00:00:00 2001 From: LouiseDck Date: Wed, 19 Mar 2025 10:23:09 +0100 Subject: [PATCH 4/4] add defaults, not great --- src/anndata_plot/pl/render_utils.py | 11 ++- src/anndata_plot/pl/scatter.py | 100 +++++++++------------------- 2 files changed, 41 insertions(+), 70 deletions(-) diff --git a/src/anndata_plot/pl/render_utils.py b/src/anndata_plot/pl/render_utils.py index 577c2fe..74d0d47 100644 --- a/src/anndata_plot/pl/render_utils.py +++ b/src/anndata_plot/pl/render_utils.py @@ -26,7 +26,16 @@ @dataclass class LegendOpts: - legend_position: str = "right" + legend_position: str = "inner" + legend_cols: int = 1 + show_legend: bool = True + +@dataclass +class LegendOptsMpl(LegendOpts): + legend_font_weight: _FontWeight = "normal" + legend_font_size: _FontSize = "medium" + + @dataclass class ColorOpts: diff --git a/src/anndata_plot/pl/scatter.py b/src/anndata_plot/pl/scatter.py index c13ca65..8aa9984 100644 --- a/src/anndata_plot/pl/scatter.py +++ b/src/anndata_plot/pl/scatter.py @@ -22,60 +22,8 @@ import holoviews as hv -# def scatter( -# adata, #Y: np.ndarray, -# Y: Sequence[str], -# title: str | None = None, -# color_by: str | None = None, -# legend_opts: LegendOpts = LegendOpts(), -# color_opts: ColorOpts = ColorOpts(), -# size_opts: SizeOpts = SizeOpts(), -# aixs_opts: AxisOpts = AxisOpts(), -# backend_opts: dict = None, - -# # *, -# # colors: str | Sequence[ColorLike | np.ndarray] = "blue", #should probably be a colormapping? -# # sort_order=True, -# # alpha=None, -# # highlights=(), -# # right_margin=None, -# # left_margin=None, -# # projection: Literal["2d", "3d"] = "2d", -# # title=None, -# # component_name="DC", -# # component_indexnames=(1, 2, 3), -# # axis_labels=None, -# # colorbars=(False,), -# # sizes=(1,), -# # markers=".", -# # color_map="viridis", -# # show_ticks=True, -# # ax=None, -# ): -# # fig = hv.render(hv.Points(adata, Y), backend="matplotlib") -# # points = hv.Points(adata, Y, colors).opts(color=colors) -# # points.opts(color = colors) - -# # if colors is a column in obs --> hv.Points(adata, Y, colors).opts(color=colors) -# # if colors is just one color --> hv.Points(adata, Y).opts(color=colors) --> but is this needed? -# # if colors is a var_name --> hv.Points(adata, Y, colors).opts(color=colors) - -# if title is None: -# title = f"Scatter plot of {Y[0]} and {Y[1]}" - -# # merge opts dicts -# opts = {**asdict(legend_opts), -# **asdict(color_opts), -# **asdict(size_opts), -# "title": title, -# } - -# if color_by is None: -# return (hv.Points(adata, Y).opts(**opts)) - -# return (hv.Points(adata, Y, color_by).opts(**opts)) - # copied from scanpy, pl._anndata line +- 235 +# adapted to work with 1 color, that can only be in names or be a column def _check_if_annotations( adata: AnnData, axis_name: Literal["obs", "var"], @@ -113,10 +61,9 @@ def scatter( color_by: str | None = None, title: str | None = None, color_opts: ColorOpts | dict | None = None, - cmap: str | None = None, - palette: Sequence[ColorLike] | None = None, legend_opts: LegendOpts | dict | None = None, - legend_loc: _LegendLoc | None = "right margin", + size_opts: SizeOpts | dict | None = None, + interactive: bool = False, ): # determine which dims to use @@ -145,26 +92,19 @@ def scatter( title_opts = {"title": title} - # check if color_opts is a dict - if isinstance(color_opts, dict): - color_opts = get_color_opts(kdims, vdims, color_by, **color_opts) - elif color_opts is None: - color_opts = get_color_opts(kdims, vdims, color_by, cmap, palette) + if interactive: + allopts = get_interactive_opts(vdims[0]) + else: + allopts = get_static_opts(vdims[0]) all_opts = { **title_opts, - **asdict(color_opts), + **allopts } - print(locals()) - - print(kdims) - print(vdims) - print(all_opts) - return hv.Points(adata, kdims, vdims).opts(**all_opts) -def get_color_opts(kdims, vdims, color_by = None, cmap = None, palette = None): +def get_color_opts(kdims, vdims, interactive, color_by = None, cmap = None, palette = None, **kwargs): args = {} if len(vdims) != 0: @@ -174,3 +114,25 @@ def get_color_opts(kdims, vdims, color_by = None, cmap = None, palette = None): if cmap is not None: args["cmap"] = cmap return ColorOpts(**args) + +def get_legend_opts(legend_loc = None, interactive = False, **kwargs): + if legend_loc is not None: + return LegendOpts(legend_position = legend_loc, **kwargs) + return LegendOpts(**kwargs) + + +def get_interactive_opts(color_by): + return { + "cmap": "viridis", + "color": color_by, + "width": 550, + "height": 550, + "legend_position": "bottom_left" + } + +def get_static_opts(color_by): + return { + "cmap": "viridis", + "color": color_by, + "fig_size": 250, + }