diff --git a/docs/examples/manipulation.ipynb b/docs/examples/manipulation.ipynb
index 6b7f160..cb4176a 100644
--- a/docs/examples/manipulation.ipynb
+++ b/docs/examples/manipulation.ipynb
@@ -7,7 +7,8 @@
"source": [
"# Figure Manipulation\n",
"\n",
- "What's easy, what's annoying, and how to work around it."
+ "After creating a figure with `xpx()`, you can manipulate it using standard Plotly methods.\n",
+ "This notebook shows what works out of the box, and where `update_traces` from xarray-plotly helps."
]
},
{
@@ -17,185 +18,192 @@
"metadata": {},
"outputs": [],
"source": [
+ "import numpy as np\n",
"import plotly.express as px\n",
- "import plotly.graph_objects as go\n",
- "import plotly.io as pio\n",
+ "import xarray as xr\n",
"\n",
- "from xarray_plotly import overlay, update_traces\n",
+ "from xarray_plotly import config, update_traces, xpx\n",
"\n",
- "pio.renderers.default = \"notebook_connected\"\n",
+ "config.notebook()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 4D DataArray: scenario x metric x year x country\n",
+ "df_gap = px.data.gapminder()\n",
+ "countries = [\"United States\", \"China\", \"Germany\", \"Brazil\"]\n",
+ "metrics = [\"lifeExp\", \"gdpPercap\"]\n",
"\n",
- "# Sample data\n",
- "df = px.data.gapminder()\n",
- "df_2007 = df.query(\"year == 2007\")\n",
- "df_countries = df.query(\"country in ['United States', 'China', 'Germany', 'Brazil']\")"
+ "# Build base 3D array (metric x year x country)\n",
+ "arrays = []\n",
+ "for metric in metrics:\n",
+ " df_pivot = df_gap[df_gap[\"country\"].isin(countries)].pivot(\n",
+ " index=\"year\", columns=\"country\", values=metric\n",
+ " )\n",
+ " arrays.append(df_pivot.values)\n",
+ "\n",
+ "base_3d = np.stack(arrays)\n",
+ "\n",
+ "# Add scenario dimension (4D): original + 10% higher\n",
+ "scenarios = [\"baseline\", \"optimistic\"]\n",
+ "data_4d = np.stack([base_3d, base_3d * 1.1])\n",
+ "\n",
+ "da = xr.DataArray(\n",
+ " data_4d,\n",
+ " dims=[\"scenario\", \"metric\", \"year\", \"country\"],\n",
+ " coords={\n",
+ " \"scenario\": scenarios,\n",
+ " \"metric\": metrics,\n",
+ " \"year\": df_pivot.index.tolist(),\n",
+ " \"country\": df_pivot.columns.tolist(),\n",
+ " },\n",
+ " name=\"value\",\n",
+ ")\n",
+ "da"
]
},
{
"cell_type": "markdown",
- "id": "2",
+ "id": "3",
"metadata": {},
"source": [
"---\n",
- "# Easy: Single Plots\n",
+ "## Standard Plotly Methods\n",
"\n",
- "All standard manipulation methods work as expected."
+ "All standard Plotly manipulation methods work on figures created with `xpx()`."
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "3",
+ "id": "4",
"metadata": {},
"outputs": [],
"source": [
- "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", color=\"continent\", size=\"pop\")\n",
+ "# Simple 2D slice\n",
+ "fig = xpx(da.sel(scenario=\"baseline\", metric=\"lifeExp\")).line()\n",
"fig"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "4",
+ "id": "5",
"metadata": {},
"outputs": [],
"source": [
"# Layout\n",
- "fig.update_layout(title=\"GDP vs Life Expectancy\", template=\"plotly_white\")\n",
+ "fig.update_layout(title=\"Life Expectancy Over Time\", template=\"plotly_white\")\n",
"\n",
"# All traces\n",
- "fig.update_traces(marker_opacity=0.7)\n",
+ "fig.update_traces(line_width=3)\n",
"\n",
- "# Specific traces\n",
- "fig.update_traces(marker_line_width=2, selector={\"name\": \"Europe\"})\n",
+ "# Specific trace by name\n",
+ "fig.update_traces(line_dash=\"dot\", selector={\"name\": \"Germany\"})\n",
"\n",
"# Axes\n",
- "fig.update_xaxes(type=\"log\", title=\"GDP per Capita\")\n",
- "fig.update_yaxes(range=[40, 90])\n",
+ "fig.update_xaxes(title=\"Year\", showgrid=False)\n",
+ "fig.update_yaxes(title=\"Life Expectancy (years)\", range=[40, 85])\n",
"\n",
- "# Annotations and shapes\n",
- "fig.add_hline(y=df_2007[\"lifeExp\"].mean(), line_dash=\"dash\", line_color=\"gray\")\n",
+ "# Reference line\n",
+ "fig.add_hline(y=70, line_dash=\"dash\", line_color=\"gray\", annotation_text=\"Target\")\n",
"\n",
"fig"
]
},
{
"cell_type": "markdown",
- "id": "5",
+ "id": "6",
"metadata": {},
"source": [
"---\n",
- "# Easy: Faceted Plots\n",
+ "## Faceted Plots\n",
"\n",
- "`update_traces`, `update_xaxes`, `update_yaxes` all work across facets."
+ "`update_traces`, `update_xaxes`, `update_yaxes` work across all facets."
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "6",
+ "id": "7",
"metadata": {},
"outputs": [],
"source": [
- "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", facet_col=\"country\")\n",
+ "# Facet by metric, color by country\n",
+ "fig = xpx(da.sel(scenario=\"baseline\")).line(facet_col=\"metric\")\n",
"fig"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "7",
+ "id": "8",
"metadata": {},
"outputs": [],
"source": [
"# Update ALL traces across all facets\n",
- "fig.update_traces(line_width=3)\n",
+ "fig.update_traces(line_width=2)\n",
"\n",
- "# Update ALL x-axes\n",
+ "# Update ALL axes\n",
"fig.update_xaxes(showgrid=False)\n",
"\n",
- "# Update ALL y-axes\n",
- "fig.update_yaxes(showgrid=False, type=\"log\")\n",
- "\n",
- "# Clean up facet labels\n",
- "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n",
+ "# Target specific facet (1-indexed)\n",
+ "fig.update_yaxes(type=\"log\", col=2) # log scale only for gdpPercap\n",
"\n",
"fig"
]
},
{
"cell_type": "markdown",
- "id": "8",
+ "id": "9",
"metadata": {},
"source": [
- "### Targeting specific facets\n",
- "\n",
- "Use `row=` and `col=` (1-indexed) to target specific facets."
+ "### Grid layout with facet_row"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "9",
+ "id": "10",
"metadata": {},
"outputs": [],
"source": [
- "fig = px.histogram(px.data.tips(), x=\"total_bill\", facet_row=\"sex\", facet_col=\"time\")\n",
+ "# 2x2 grid: scenario x metric\n",
+ "fig = xpx(da).line(facet_col=\"metric\", facet_row=\"scenario\")\n",
"\n",
- "# Target specific cell\n",
- "fig.update_yaxes(title_text=\"Frequency\", row=1, col=1)\n",
- "\n",
- "# Target entire column\n",
- "fig.update_xaxes(title_text=\"Bill ($)\", col=2)\n",
- "\n",
- "# Target entire row\n",
- "fig.update_traces(marker_color=\"orange\", row=2)\n",
- "\n",
- "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n",
+ "fig.update_traces(line_width=2)\n",
+ "fig.update_yaxes(type=\"log\", col=2) # log scale for gdpPercap column\n",
"fig"
]
},
{
"cell_type": "markdown",
- "id": "10",
+ "id": "11",
"metadata": {},
"source": [
- "### Reference lines on facets\n",
+ "---\n",
+ "## Animation: The Pain Point\n",
"\n",
- "`add_hline`/`add_vline` apply to all facets by default. Use `row=`/`col=` to target."
+ "Plotly's `fig.update_traces()` does **not** update animation frames. This is the main gotcha."
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "11",
+ "id": "12",
"metadata": {},
"outputs": [],
"source": [
- "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n",
- "fig.update_xaxes(type=\"log\")\n",
- "\n",
- "# Applies to ALL facets\n",
- "fig.add_hline(y=70, line_dash=\"dash\", line_color=\"red\")\n",
- "\n",
- "# Specific facet only\n",
- "fig.add_hline(y=50, line_dash=\"dot\", line_color=\"blue\", row=2, col=1)\n",
- "\n",
- "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n",
+ "# Animated bar chart\n",
+ "fig = xpx(da.sel(scenario=\"baseline\", metric=\"gdpPercap\")).bar(animation_frame=\"year\")\n",
"fig"
]
},
- {
- "cell_type": "markdown",
- "id": "12",
- "metadata": {},
- "source": [
- "---\n",
- "# Easy: Adding traces to faceted/animated figures\n",
- "\n",
- "Use `overlay` to add traces. It handles facets and animation frames automatically."
- ]
- },
{
"cell_type": "code",
"execution_count": null,
@@ -203,31 +211,11 @@
"metadata": {},
"outputs": [],
"source": [
- "# Animated scatter\n",
- "fig = px.scatter(\n",
- " df_countries,\n",
- " x=\"gdpPercap\",\n",
- " y=\"lifeExp\",\n",
- " color=\"country\",\n",
- " animation_frame=\"year\",\n",
- " log_x=True,\n",
- " range_y=[40, 85],\n",
- ")\n",
- "\n",
- "# Create a figure with reference marker\n",
- "ref = go.Figure(\n",
- " go.Scatter(\n",
- " x=[10000],\n",
- " y=[75],\n",
- " mode=\"markers\",\n",
- " marker={\"size\": 20, \"symbol\": \"star\", \"color\": \"gold\"},\n",
- " name=\"Target\",\n",
- " )\n",
- ")\n",
+ "# This only affects the INITIAL view!\n",
+ "fig.update_traces(marker_color=\"red\")\n",
"\n",
- "# Overlay - trace appears in all animation frames\n",
- "combined = overlay(fig, ref)\n",
- "combined"
+ "print(f\"Base trace color: {fig.data[0].marker.color}\")\n",
+ "print(f\"Frame 0 trace color: {fig.frames[0].data[0].marker.color}\")"
]
},
{
@@ -237,37 +225,8 @@
"metadata": {},
"outputs": [],
"source": [
- "# Faceted plot\n",
- "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n",
- "fig.update_xaxes(type=\"log\")\n",
- "\n",
- "# Add reference to first facet (default axes x, y)\n",
- "ref1 = go.Figure(\n",
- " go.Scatter(\n",
- " x=[5000],\n",
- " y=[70],\n",
- " mode=\"markers\",\n",
- " marker={\"size\": 15, \"symbol\": \"star\", \"color\": \"gold\"},\n",
- " name=\"Target 1\",\n",
- " )\n",
- ")\n",
- "\n",
- "# Add reference to second facet (axes x2, y2)\n",
- "ref2 = go.Figure(\n",
- " go.Scatter(\n",
- " x=[20000],\n",
- " y=[80],\n",
- " mode=\"markers\",\n",
- " marker={\"size\": 15, \"symbol\": \"star\", \"color\": \"red\"},\n",
- " name=\"Target 2\",\n",
- " xaxis=\"x2\",\n",
- " yaxis=\"y2\", # specify target facet\n",
- " )\n",
- ")\n",
- "\n",
- "combined = overlay(fig, ref1, ref2)\n",
- "combined.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n",
- "combined"
+ "# Play the animation - it reverts to original colors\n",
+ "fig"
]
},
{
@@ -275,10 +234,9 @@
"id": "15",
"metadata": {},
"source": [
- "---\n",
- "# Annoying: Facet axis names\n",
+ "### Solution: `update_traces` from xarray-plotly\n",
"\n",
- "To target a specific facet with `add_shape`, `add_annotation`, or when adding traces via `overlay`, you need to know the axis name (`x2`, `y3`, etc.)."
+ "This helper updates both base traces and all animation frames."
]
},
{
@@ -288,407 +246,124 @@
"metadata": {},
"outputs": [],
"source": [
- "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n",
- "\n",
- "# Inspect axis names\n",
- "layout_dict = fig.layout.to_plotly_json()\n",
- "print(\"X axes:\", sorted([k for k in layout_dict if k.startswith(\"xaxis\")]))\n",
- "print(\"Y axes:\", sorted([k for k in layout_dict if k.startswith(\"yaxis\")]))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "17",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Check which trace uses which axis\n",
- "for i, trace in enumerate(fig.data):\n",
- " print(f\"Trace {i} ({trace.name}): xaxis={trace.xaxis or 'x'}, yaxis={trace.yaxis or 'y'}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "18",
- "metadata": {},
- "source": [
- "**Tip:** For simple cases, use `add_hline`/`add_vline` with `row=`/`col=` instead of `add_shape` - it handles axis mapping internally."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "19",
- "metadata": {},
- "source": [
- "---\n",
- "# Annoying: Animation trace updates\n",
- "\n",
- "**This is the main pain point.** `update_traces()` does NOT update animation frames."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "20",
- "metadata": {},
- "outputs": [],
- "source": [
- "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"country\")\n",
- "fig"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "21",
- "metadata": {},
- "outputs": [],
- "source": [
- "# This only affects the INITIAL view, not the animation frames!\n",
- "fig.update_traces(line_width=5, line_dash=\"dot\")\n",
- "\n",
- "print(f\"Base trace line_width: {fig.data[0].line.width}\")\n",
- "print(f\"Frame 0 trace line_width: {fig.frames[0].data[0].line.width}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "22",
- "metadata": {},
- "outputs": [],
- "source": [
- "# When you play the animation, it reverts to the frame's original style\n",
- "fig"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "23",
- "metadata": {},
- "source": [
- "### Solution: `update_traces`\n",
- "\n",
- "xarray-plotly provides this helper to update both base traces and animation frames:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "24",
- "metadata": {},
- "outputs": [],
- "source": [
- "# update_traces is imported from xarray_plotly\n",
- "# It updates traces in both base figure and all animation frames"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "25",
- "metadata": {},
- "outputs": [],
- "source": [
- "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"country\")\n",
- "\n",
- "update_traces(fig, line_width=4, line_dash=\"dot\")\n",
+ "fig = xpx(da.sel(scenario=\"baseline\", metric=\"gdpPercap\")).bar(animation_frame=\"year\")\n",
"\n",
- "print(f\"Base trace line_width: {fig.data[0].line.width}\")\n",
- "print(f\"Frame 0 trace line_width: {fig.frames[0].data[0].line.width}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "26",
- "metadata": {},
- "outputs": [],
- "source": [
- "fig"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "27",
- "metadata": {},
- "source": [
- "### Selective updates with selector\n",
+ "update_traces(fig, marker_color=\"red\", marker_opacity=0.8)\n",
"\n",
- "Use `selector` to target specific traces by name:"
+ "print(f\"Base trace color: {fig.data[0].marker.color}\")\n",
+ "print(f\"Frame 0 trace color: {fig.frames[0].data[0].marker.color}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "28",
+ "id": "17",
"metadata": {},
"outputs": [],
"source": [
- "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"year\")\n",
- "\n",
- "# Update only one trace by name\n",
- "update_traces(fig, selector={\"name\": \"Germany\"}, line_width=5, line_dash=\"dot\")\n",
- "\n",
- "# Update multiple traces\n",
- "update_traces(fig, selector={\"name\": \"China\"}, line_color=\"red\", line_width=3)\n",
- "\n",
+ "# Now the style persists through animation\n",
"fig"
]
},
{
"cell_type": "markdown",
- "id": "29",
+ "id": "18",
"metadata": {},
"source": [
- "### Works with facets + animation"
+ "### Selective updates with selector"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "30",
+ "id": "19",
"metadata": {},
"outputs": [],
"source": [
- "df_subset = df.query(\n",
- " \"continent in ['Europe', 'Asia'] and country in ['Germany', 'France', 'China', 'Japan']\"\n",
- ")\n",
+ "fig = xpx(da.sel(scenario=\"baseline\", metric=\"lifeExp\")).line(x=\"year\")\n",
"\n",
- "fig = px.line(\n",
- " df_subset,\n",
- " x=\"year\",\n",
- " y=\"gdpPercap\",\n",
- " color=\"country\",\n",
- " facet_col=\"continent\",\n",
- " animation_frame=\"year\",\n",
- ")\n",
+ "# Highlight specific countries\n",
+ "update_traces(fig, selector={\"name\": \"China\"}, line_color=\"red\", line_width=4)\n",
+ "update_traces(fig, selector={\"name\": \"United States\"}, line_color=\"blue\", line_width=4)\n",
"\n",
- "update_traces(fig, line_width=3)\n",
- "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n",
"fig"
]
},
{
"cell_type": "markdown",
- "id": "31",
- "metadata": {},
- "source": [
- "### What's affected\n",
- "\n",
- "Anything on **traces** needs the helper for animations:\n",
- "\n",
- "| Property | Facets | Animation |\n",
- "|----------|--------|-----------|\n",
- "| `line_width` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n",
- "| `line_dash` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n",
- "| `line_color` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n",
- "| `marker_size` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n",
- "| `marker_symbol` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n",
- "| `opacity` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n",
- "\n",
- "**Layout properties** (`update_layout`, `update_xaxes`, `update_yaxes`) work fine for animations."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "32",
- "metadata": {},
- "source": [
- "---\n",
- "# Annoying: Animation speed\n",
- "\n",
- "The API to change animation speed is deeply nested."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "33",
+ "id": "20",
"metadata": {},
- "outputs": [],
"source": [
- "fig = px.scatter(\n",
- " df,\n",
- " x=\"gdpPercap\",\n",
- " y=\"lifeExp\",\n",
- " color=\"continent\",\n",
- " size=\"pop\",\n",
- " animation_frame=\"year\",\n",
- " log_x=True,\n",
- " range_y=[25, 90],\n",
- ")\n",
+ "### Unified hover with animation\n",
"\n",
- "# This is... not intuitive\n",
- "fig.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = 100 # faster\n",
- "fig.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = 50\n",
+ "A common pattern: unified hover mode with custom formatting.\n",
"\n",
- "fig"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "34",
- "metadata": {},
- "source": [
- "### Workaround: Helper function"
+ "- **Layout** (`hovermode`, spikes): Standard Plotly works fine\n",
+ "- **Traces** (`hovertemplate`): Use `update_traces()` for animation support"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "35",
- "metadata": {},
- "outputs": [],
- "source": [
- "def set_animation_speed(fig, frame_duration=500, transition_duration=300):\n",
- " \"\"\"Set animation speed in milliseconds.\"\"\"\n",
- " if fig.layout.updatemenus:\n",
- " fig.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = frame_duration\n",
- " fig.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = transition_duration\n",
- " return fig"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "36",
+ "id": "21",
"metadata": {},
"outputs": [],
"source": [
- "fig = px.scatter(\n",
- " df,\n",
- " x=\"gdpPercap\",\n",
- " y=\"lifeExp\",\n",
- " color=\"continent\",\n",
- " animation_frame=\"year\",\n",
- " log_x=True,\n",
- " range_y=[25, 90],\n",
- ")\n",
+ "fig = xpx(da.sel(metric=\"gdpPercap\")).line(x=\"year\", animation_frame=\"scenario\")\n",
"\n",
- "set_animation_speed(fig, frame_duration=200, transition_duration=100)\n",
- "fig"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "37",
- "metadata": {},
- "source": [
- "---\n",
- "# Annoying: Slider styling\n",
+ "# Layout settings - standard Plotly\n",
+ "fig.update_layout(hovermode=\"x unified\")\n",
+ "fig.update_xaxes(showspikes=True, spikecolor=\"gray\", spikethickness=1)\n",
"\n",
- "Verbose but straightforward."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "38",
- "metadata": {},
- "outputs": [],
- "source": [
- "fig = px.scatter(\n",
- " df,\n",
- " x=\"gdpPercap\",\n",
- " y=\"lifeExp\",\n",
- " color=\"continent\",\n",
- " animation_frame=\"year\",\n",
- " log_x=True,\n",
- " range_y=[25, 90],\n",
- ")\n",
- "\n",
- "fig.layout.sliders[0].currentvalue.prefix = \"Year: \"\n",
- "fig.layout.sliders[0].currentvalue.font.size = 16\n",
- "fig.layout.sliders[0].pad.t = 50 # padding from top\n",
+ "# Trace settings - use update_traces for animation support\n",
+ "update_traces(fig, hovertemplate=\"%{fullData.name}: $%{y:,.0f}\")\n",
"\n",
"fig"
]
},
{
"cell_type": "markdown",
- "id": "39",
+ "id": "22",
"metadata": {},
"source": [
- "### Hide slider or play button"
+ "### Facets + Animation"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "40",
+ "id": "23",
"metadata": {},
"outputs": [],
"source": [
- "fig = px.scatter(\n",
- " df,\n",
- " x=\"gdpPercap\",\n",
- " y=\"lifeExp\",\n",
- " color=\"continent\",\n",
- " animation_frame=\"year\",\n",
- " log_x=True,\n",
- " range_y=[25, 90],\n",
- ")\n",
+ "# Facet by metric, animate by scenario\n",
+ "fig = xpx(da).line(facet_col=\"metric\", animation_frame=\"scenario\")\n",
"\n",
- "# Hide slider (keep play button)\n",
- "fig.layout.sliders = []\n",
+ "# Standard Plotly for layout\n",
+ "fig.update_yaxes(type=\"log\", col=2)\n",
"\n",
- "# Or hide play button (keep slider):\n",
- "# fig.layout.updatemenus = []\n",
+ "# update_traces for trace properties with animation\n",
+ "update_traces(fig, line_width=3)\n",
+ "update_traces(fig, selector={\"name\": \"China\"}, line_dash=\"dot\")\n",
"\n",
"fig"
]
},
{
"cell_type": "markdown",
- "id": "41",
+ "id": "24",
"metadata": {},
"source": [
"---\n",
- "# Summary\n",
- "\n",
- "### Provided by xarray-plotly\n",
+ "## Summary\n",
"\n",
- "```python\n",
- "from xarray_plotly import overlay, update_traces\n",
- "```\n",
- "\n",
- "### Local helper for animation speed"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "42",
- "metadata": {},
- "outputs": [],
- "source": [
- "def set_animation_speed(fig, frame_duration=500, transition_duration=300):\n",
- " \"\"\"Set animation speed in milliseconds.\"\"\"\n",
- " if fig.layout.updatemenus:\n",
- " fig.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = frame_duration\n",
- " fig.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = transition_duration\n",
- " return fig"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "43",
- "metadata": {},
- "source": [
- "### Quick reference\n",
- "\n",
- "| Task | Facets | Animation | Solution |\n",
- "|------|--------|-----------|----------|\n",
- "| Update trace style | `fig.update_traces()` | `update_traces()` | xarray-plotly helper |\n",
- "| Update axes | `update_xaxes()`/`update_yaxes()` | Same | ✅ Works |\n",
- "| Update layout | `update_layout()` | Same | ✅ Works |\n",
- "| Add reference line | `add_hline(row=, col=)` | `add_hline()` | ✅ Works |\n",
- "| Add trace | `overlay()` | `overlay()` | ✅ Works |\n",
- "| Add shape to specific facet | `add_shape(xref=\"x2\")` | Same | Need axis name |\n",
- "| Change animation speed | N/A | `set_animation_speed()` | Local helper |\n",
- "| Facet labels | `for_each_annotation()` | Same | ✅ Works |"
+ "| Method | Static/Faceted | Animated |\n",
+ "|--------|----------------|----------|\n",
+ "| `fig.update_layout()` | ✅ | ✅ |\n",
+ "| `fig.update_xaxes()` / `fig.update_yaxes()` | ✅ | ✅ |\n",
+ "| `fig.add_hline()` / `fig.add_vline()` | ✅ | ✅ |\n",
+ "| `fig.update_traces()` | ✅ | ❌ base only |\n",
+ "| `update_traces(fig, ...)` | ✅ | ✅ all frames |"
]
}
],
@@ -700,7 +375,7 @@
},
"language_info": {
"name": "python",
- "version": "3.11.0"
+ "version": "3.12.0"
}
},
"nbformat": 4,
diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py
index 2e1f658..939d10d 100644
--- a/xarray_plotly/figures.py
+++ b/xarray_plotly/figures.py
@@ -8,9 +8,28 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
+ from collections.abc import Iterator
+
import plotly.graph_objects as go
+def _iter_all_traces(fig: go.Figure) -> Iterator:
+ """Iterate over all traces in a figure, including animation frames.
+
+ Yields traces from fig.data first, then from each frame in fig.frames.
+ Useful for applying styling to all traces including those in animations.
+
+ Args:
+ fig: Plotly Figure.
+
+ Yields:
+ Each trace object from the figure.
+ """
+ yield from fig.data
+ for frame in fig.frames or []:
+ yield from frame.data
+
+
def _get_subplot_axes(fig: go.Figure) -> set[tuple[str, str]]:
"""Extract (xaxis, yaxis) pairs from figure traces.
@@ -418,15 +437,12 @@ def update_traces(fig: go.Figure, selector: dict | None = None, **kwargs) -> go.
>>> # Update specific trace by name
>>> update_traces(fig, selector={"name": "Germany"}, line_width=5, line_dash="dot")
"""
- fig.update_traces(selector=selector, **kwargs)
-
- for frame in fig.frames:
- for trace in frame.data:
- if selector is None:
+ for trace in _iter_all_traces(fig):
+ if selector is None:
+ trace.update(**kwargs)
+ else:
+ # Check if trace matches all selector criteria
+ if all(getattr(trace, k, None) == v for k, v in selector.items()):
trace.update(**kwargs)
- else:
- # Check if trace matches all selector criteria
- if all(getattr(trace, k, None) == v for k, v in selector.items()):
- trace.update(**kwargs)
return fig
diff --git a/xarray_plotly/plotting.py b/xarray_plotly/plotting.py
index 69acf6c..b1e5085 100644
--- a/xarray_plotly/plotting.py
+++ b/xarray_plotly/plotting.py
@@ -19,6 +19,9 @@
get_value_col,
to_dataframe,
)
+from xarray_plotly.figures import (
+ _iter_all_traces,
+)
if TYPE_CHECKING:
import plotly.graph_objects as go
@@ -192,9 +195,7 @@ def _style_traces_as_bars(fig: go.Figure) -> None:
then assigns stackgroups: positive traces stack upward, negative stack downward.
"""
# Collect all traces (main + animation frames)
- all_traces = list(fig.data)
- for frame in fig.frames:
- all_traces.extend(frame.data)
+ all_traces = list(_iter_all_traces(fig))
# Classify each trace name by aggregating sign info across all occurrences
sign_flags: dict[str, dict[str, bool]] = {}