From 11273de6178e93772ddcb4748b54732c2d9e1034 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:46:00 +0100 Subject: [PATCH 1/3] Add mypy to ci and update types in code --- .github/workflows/ci.yml | 3 +++ xarray_plotly/__init__.py | 4 ++-- xarray_plotly/common.py | 2 +- xarray_plotly/config.py | 14 +++++++------- xarray_plotly/figures.py | 12 +++++++----- 5 files changed, 20 insertions(+), 15 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e53bdc6..194f981 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,6 +31,9 @@ jobs: - name: Format check run: uv run ruff format --check . + - name: Type check + run: uv run mypy xarray_plotly + - name: Test run: uv run pytest --cov=xarray_plotly --cov-report=xml diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py index d377b4c..ad89d62 100644 --- a/xarray_plotly/__init__.py +++ b/xarray_plotly/__init__.py @@ -110,5 +110,5 @@ def xpx(data: DataArray | Dataset) -> DataArrayPlotlyAccessor | DatasetPlotlyAcc __version__ = version("xarray_plotly") # Register the accessors -register_dataarray_accessor("plotly")(DataArrayPlotlyAccessor) -register_dataset_accessor("plotly")(DatasetPlotlyAccessor) +register_dataarray_accessor("plotly")(DataArrayPlotlyAccessor) # type: ignore[no-untyped-call] +register_dataset_accessor("plotly")(DatasetPlotlyAccessor) # type: ignore[no-untyped-call] diff --git a/xarray_plotly/common.py b/xarray_plotly/common.py index 898c980..c0e2344 100644 --- a/xarray_plotly/common.py +++ b/xarray_plotly/common.py @@ -147,7 +147,7 @@ def to_dataframe(darray: DataArray) -> pd.DataFrame: return df -def _get_label_from_attrs(attrs: dict, fallback: str) -> str: +def _get_label_from_attrs(attrs: dict[str, object], fallback: str) -> str: """Extract a label from xarray attributes based on current config. Args: diff --git a/xarray_plotly/config.py b/xarray_plotly/config.py index 0f704a3..99d6e6c 100644 --- a/xarray_plotly/config.py +++ b/xarray_plotly/config.py @@ -8,7 +8,7 @@ from contextlib import contextmanager from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: from collections.abc import Generator @@ -166,12 +166,12 @@ def set_options( yield finally: # Restore old values (modify in place) - _options.label_use_long_name = old_values["label_use_long_name"] - _options.label_use_standard_name = old_values["label_use_standard_name"] - _options.label_include_units = old_values["label_include_units"] - _options.label_unit_format = old_values["label_unit_format"] - _options.slot_orders = old_values["slot_orders"] - _options.dataset_variable_position = old_values["dataset_variable_position"] + _options.label_use_long_name = cast(bool, old_values["label_use_long_name"]) + _options.label_use_standard_name = cast(bool, old_values["label_use_standard_name"]) + _options.label_include_units = cast(bool, old_values["label_include_units"]) + _options.label_unit_format = cast(str, old_values["label_unit_format"]) + _options.slot_orders = cast(dict[str, tuple[str, ...]], old_values["slot_orders"]) + _options.dataset_variable_position = cast(int, old_values["dataset_variable_position"]) def notebook(renderer: str = "notebook") -> None: diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index 70bf660..2bb4031 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -5,7 +5,7 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from collections.abc import Iterator @@ -13,7 +13,7 @@ import plotly.graph_objects as go -def _iter_all_traces(fig: go.Figure) -> Iterator: +def _iter_all_traces(fig: go.Figure) -> Iterator[Any]: """Iterate over all traces in a figure, including animation frames. Yields traces from fig.data first, then from each frame in fig.frames. @@ -107,7 +107,7 @@ def _merge_frames( overlays: list[go.Figure], base_trace_count: int, overlay_trace_counts: list[int], -) -> list: +) -> list[go.Frame]: """Merge animation frames from base and overlay figures. Args: @@ -360,7 +360,7 @@ def _merge_secondary_y_frames( base: go.Figure, secondary: go.Figure, y_mapping: dict[str, str], -) -> list: +) -> list[go.Frame]: """Merge animation frames for secondary y-axis combination. Args: @@ -411,7 +411,9 @@ def _merge_secondary_y_frames( return merged_frames -def update_traces(fig: go.Figure, selector: dict | None = None, **kwargs) -> go.Figure: +def update_traces( + fig: go.Figure, selector: dict[str, Any] | None = None, **kwargs: Any +) -> go.Figure: """Update traces in both base figure and all animation frames. Plotly's `update_traces()` only updates the base figure, not animation frames. From 6d5af29739132912dd6ab97b87e6b9527fb0de62 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:58:47 +0100 Subject: [PATCH 2/3] Use quoted cast types --- xarray_plotly/config.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/xarray_plotly/config.py b/xarray_plotly/config.py index 99d6e6c..d9d8655 100644 --- a/xarray_plotly/config.py +++ b/xarray_plotly/config.py @@ -166,12 +166,12 @@ def set_options( yield finally: # Restore old values (modify in place) - _options.label_use_long_name = cast(bool, old_values["label_use_long_name"]) - _options.label_use_standard_name = cast(bool, old_values["label_use_standard_name"]) - _options.label_include_units = cast(bool, old_values["label_include_units"]) - _options.label_unit_format = cast(str, old_values["label_unit_format"]) - _options.slot_orders = cast(dict[str, tuple[str, ...]], old_values["slot_orders"]) - _options.dataset_variable_position = cast(int, old_values["dataset_variable_position"]) + _options.label_use_long_name = cast("bool", old_values["label_use_long_name"]) + _options.label_use_standard_name = cast("bool", old_values["label_use_standard_name"]) + _options.label_include_units = cast("bool", old_values["label_include_units"]) + _options.label_unit_format = cast("str", old_values["label_unit_format"]) + _options.slot_orders = cast("dict[str, tuple[str, ...]]", old_values["slot_orders"]) + _options.dataset_variable_position = cast("int", old_values["dataset_variable_position"]) def notebook(renderer: str = "notebook") -> None: From 6c91ab0a7a5fd9755128826cdbe52571bc99d6b3 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 26 Jan 2026 18:00:16 +0100 Subject: [PATCH 3/3] fix typing --- xarray_plotly/plotting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray_plotly/plotting.py b/xarray_plotly/plotting.py index b1e5085..acaea20 100644 --- a/xarray_plotly/plotting.py +++ b/xarray_plotly/plotting.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any import numpy as np +import numpy.typing as npt import plotly.express as px from xarray_plotly.common import ( @@ -171,7 +172,7 @@ def bar( ) -def _classify_trace_sign(y_values: np.ndarray) -> str: +def _classify_trace_sign(y_values: npt.ArrayLike) -> str: """Classify a trace as 'positive', 'negative', or 'mixed' based on its values.""" y_arr = np.asarray(y_values) y_clean = y_arr[np.isfinite(y_arr) & (np.abs(y_arr) > 1e-9)]