Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ jobs:
- name: Format check
run: uv run ruff format --check .

- name: Type check
run: uv run mypy xarray_plotly

- name: Test
run: uv run pytest --cov=xarray_plotly --cov-report=xml

Expand Down
4 changes: 2 additions & 2 deletions xarray_plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,5 @@ def xpx(data: DataArray | Dataset) -> DataArrayPlotlyAccessor | DatasetPlotlyAcc
__version__ = version("xarray_plotly")

# Register the accessors
register_dataarray_accessor("plotly")(DataArrayPlotlyAccessor)
register_dataset_accessor("plotly")(DatasetPlotlyAccessor)
register_dataarray_accessor("plotly")(DataArrayPlotlyAccessor) # type: ignore[no-untyped-call]
register_dataset_accessor("plotly")(DatasetPlotlyAccessor) # type: ignore[no-untyped-call]
2 changes: 1 addition & 1 deletion xarray_plotly/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def to_dataframe(darray: DataArray) -> pd.DataFrame:
return df


def _get_label_from_attrs(attrs: dict, fallback: str) -> str:
def _get_label_from_attrs(attrs: dict[str, object], fallback: str) -> str:
"""Extract a label from xarray attributes based on current config.

Args:
Expand Down
14 changes: 7 additions & 7 deletions xarray_plotly/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down Expand Up @@ -166,12 +166,12 @@ def set_options(
yield
finally:
# Restore old values (modify in place)
_options.label_use_long_name = old_values["label_use_long_name"]
_options.label_use_standard_name = old_values["label_use_standard_name"]
_options.label_include_units = old_values["label_include_units"]
_options.label_unit_format = old_values["label_unit_format"]
_options.slot_orders = old_values["slot_orders"]
_options.dataset_variable_position = old_values["dataset_variable_position"]
_options.label_use_long_name = cast("bool", old_values["label_use_long_name"])
_options.label_use_standard_name = cast("bool", old_values["label_use_standard_name"])
_options.label_include_units = cast("bool", old_values["label_include_units"])
_options.label_unit_format = cast("str", old_values["label_unit_format"])
_options.slot_orders = cast("dict[str, tuple[str, ...]]", old_values["slot_orders"])
_options.dataset_variable_position = cast("int", old_values["dataset_variable_position"])


def notebook(renderer: str = "notebook") -> None:
Expand Down
12 changes: 7 additions & 5 deletions xarray_plotly/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from __future__ import annotations

import copy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Iterator

import plotly.graph_objects as go


def _iter_all_traces(fig: go.Figure) -> Iterator:
def _iter_all_traces(fig: go.Figure) -> Iterator[Any]:
"""Iterate over all traces in a figure, including animation frames.

Yields traces from fig.data first, then from each frame in fig.frames.
Expand Down Expand Up @@ -107,7 +107,7 @@ def _merge_frames(
overlays: list[go.Figure],
base_trace_count: int,
overlay_trace_counts: list[int],
) -> list:
) -> list[go.Frame]:
"""Merge animation frames from base and overlay figures.

Args:
Expand Down Expand Up @@ -360,7 +360,7 @@ def _merge_secondary_y_frames(
base: go.Figure,
secondary: go.Figure,
y_mapping: dict[str, str],
) -> list:
) -> list[go.Frame]:
"""Merge animation frames for secondary y-axis combination.

Args:
Expand Down Expand Up @@ -411,7 +411,9 @@ def _merge_secondary_y_frames(
return merged_frames


def update_traces(fig: go.Figure, selector: dict | None = None, **kwargs) -> go.Figure:
def update_traces(
fig: go.Figure, selector: dict[str, Any] | None = None, **kwargs: Any
) -> go.Figure:
"""Update traces in both base figure and all animation frames.

Plotly's `update_traces()` only updates the base figure, not animation frames.
Expand Down
3 changes: 2 additions & 1 deletion xarray_plotly/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, Any

import numpy as np
import numpy.typing as npt
import plotly.express as px

from xarray_plotly.common import (
Expand Down Expand Up @@ -171,7 +172,7 @@ def bar(
)


def _classify_trace_sign(y_values: np.ndarray) -> str:
def _classify_trace_sign(y_values: npt.ArrayLike) -> str:
"""Classify a trace as 'positive', 'negative', or 'mixed' based on its values."""
y_arr = np.asarray(y_values)
y_clean = y_arr[np.isfinite(y_arr) & (np.abs(y_arr) > 1e-9)]
Expand Down