Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions docs/examples/kwargs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,71 @@
"xpx(change).imshow(color_continuous_scale=\"RdBu_r\", color_continuous_midpoint=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## colors (unified parameter)\n",
"\n",
"The `colors` parameter provides a simpler way to set colors without remembering the exact Plotly parameter name. It automatically maps to the correct parameter based on the input type:\n",
"\n",
"| Input | Maps To |\n",
"|-------|---------|\n",
"| `\"Viridis\"` (continuous scale name) | `color_continuous_scale` |\n",
"| `\"D3\"` (qualitative palette name) | `color_discrete_sequence` |\n",
"| `[\"red\", \"blue\"]` (list) | `color_discrete_sequence` |\n",
"| `{\"A\": \"red\"}` (dict) | `color_discrete_map` |"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Named qualitative palette\n",
"xpx(stocks).line(colors=\"D3\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# List of custom colors\n",
"xpx(stocks).line(colors=[\"#E63946\", \"#457B9D\", \"#2A9D8F\", \"#E9C46A\", \"#F4A261\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Dict for explicit mapping\n",
"xpx(stocks).line(\n",
" colors={\n",
" \"GOOG\": \"red\",\n",
" \"AAPL\": \"blue\",\n",
" \"AMZN\": \"green\",\n",
" \"FB\": \"purple\",\n",
" \"NFLX\": \"orange\",\n",
" \"MSFT\": \"brown\",\n",
" }\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Continuous scale for heatmaps\n",
"xpx(stocks).imshow(colors=\"Plasma\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
125 changes: 125 additions & 0 deletions tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,128 @@ def test_imshow_animation_consistent_bounds(self) -> None:
coloraxis = fig.layout.coloraxis
assert coloraxis.cmin == 0.0
assert coloraxis.cmax == 70.0


class TestColorsParameter:
"""Tests for the unified colors parameter."""

@pytest.fixture(autouse=True)
def setup(self) -> None:
"""Create test DataArrays."""
self.da = xr.DataArray(
np.random.rand(10, 3),
dims=["time", "city"],
coords={"city": ["A", "B", "C"]},
)

def test_colors_list_sets_discrete_sequence(self) -> None:
"""Test that a list of colors sets color_discrete_sequence."""
fig = self.da.plotly.line(colors=["red", "blue", "green"])
# Check that traces have the expected colors
assert len(fig.data) == 3
assert fig.data[0].line.color == "red"
assert fig.data[1].line.color == "blue"
assert fig.data[2].line.color == "green"

def test_colors_dict_sets_discrete_map(self) -> None:
"""Test that a dict sets color_discrete_map."""
fig = self.da.plotly.line(colors={"A": "red", "B": "blue", "C": "green"})
# Traces should be colored according to the mapping
assert len(fig.data) == 3
# Find traces by name and check their color
colors_by_name = {trace.name: trace.line.color for trace in fig.data}
assert colors_by_name["A"] == "red"
assert colors_by_name["B"] == "blue"
assert colors_by_name["C"] == "green"

def test_colors_continuous_scale_string(self) -> None:
"""Test that a continuous scale name sets color_continuous_scale."""
da = xr.DataArray(
np.random.rand(50, 2),
dims=["point", "coord"],
coords={"coord": ["x", "y"]},
)
fig = da.plotly.scatter(y="coord", x="point", color="value", colors="Viridis")
# Plotly Express uses coloraxis in the layout for continuous scales
# Check that the colorscale was applied to the coloraxis
assert fig.layout.coloraxis.colorscale is not None
colorscale = fig.layout.coloraxis.colorscale
# Viridis should be in the colorscale definition
assert any("viridis" in str(c).lower() for c in colorscale) or len(colorscale) > 0

def test_colors_qualitative_palette_string(self) -> None:
"""Test that a qualitative palette name sets color_discrete_sequence."""
import plotly.express as px

fig = self.da.plotly.line(colors="D3")
# D3 palette should be applied - check first trace color is from D3
d3_colors = px.colors.qualitative.D3
assert fig.data[0].line.color in d3_colors

def test_colors_ignored_with_warning_when_px_kwargs_present(self) -> None:
"""Test that colors is ignored with warning when color_* kwargs are present."""
import warnings

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
fig = self.da.plotly.line(
colors="D3", color_discrete_sequence=["orange", "purple", "cyan"]
)
# Should have raised a warning about colors being ignored
assert any(
"colors" in str(m.message).lower() and "ignored" in str(m.message).lower()
for m in w
), "Expected warning about 'colors' being 'ignored' not found"
# The explicit px_kwargs should take precedence
assert fig.data[0].line.color == "orange"

def test_colors_none_uses_defaults(self) -> None:
"""Test that colors=None uses Plotly defaults."""
fig1 = self.da.plotly.line(colors=None)
fig2 = self.da.plotly.line()
# Both should produce the same result
assert fig1.data[0].line.color == fig2.data[0].line.color

def test_colors_works_with_bar(self) -> None:
"""Test colors parameter with bar chart."""
fig = self.da.plotly.bar(colors=["#e41a1c", "#377eb8", "#4daf4a"])
assert fig.data[0].marker.color == "#e41a1c"

def test_colors_works_with_area(self) -> None:
"""Test colors parameter with area chart."""
fig = self.da.plotly.area(colors=["red", "green", "blue"])
assert len(fig.data) == 3

def test_colors_works_with_scatter(self) -> None:
"""Test colors parameter with scatter plot."""
fig = self.da.plotly.scatter(colors=["red", "green", "blue"])
assert len(fig.data) == 3

def test_colors_works_with_imshow(self) -> None:
"""Test colors parameter with imshow (continuous scale)."""
da = xr.DataArray(np.random.rand(10, 10), dims=["y", "x"])
fig = da.plotly.imshow(colors="RdBu")
# Plotly Express uses coloraxis in the layout for continuous scales
assert fig.layout.coloraxis.colorscale is not None
colorscale = fig.layout.coloraxis.colorscale
# RdBu should be in the colorscale definition
assert any("rdbu" in str(c).lower() for c in colorscale) or len(colorscale) > 0

def test_colors_works_with_pie(self) -> None:
"""Test colors parameter with pie chart."""
da = xr.DataArray([30, 40, 30], dims=["category"], coords={"category": ["A", "B", "C"]})
fig = da.plotly.pie(colors={"A": "red", "B": "blue", "C": "green"})
assert isinstance(fig, go.Figure)

def test_colors_works_with_dataset(self) -> None:
"""Test colors parameter works with Dataset accessor."""
ds = xr.Dataset(
{
"temp": (["time"], np.random.rand(10)),
"precip": (["time"], np.random.rand(10)),
}
)
fig = ds.plotly.line(colors=["red", "blue"])
assert len(fig.data) == 2
assert fig.data[0].line.color == "red"
assert fig.data[1].line.color == "blue"
3 changes: 2 additions & 1 deletion xarray_plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

from xarray_plotly import config
from xarray_plotly.accessor import DataArrayPlotlyAccessor, DatasetPlotlyAccessor
from xarray_plotly.common import SLOT_ORDERS, auto
from xarray_plotly.common import SLOT_ORDERS, Colors, auto
from xarray_plotly.figures import (
add_secondary_y,
overlay,
Expand All @@ -61,6 +61,7 @@

__all__ = [
"SLOT_ORDERS",
"Colors",
"add_secondary_y",
"auto",
"config",
Expand Down
Loading