diff --git a/docs/examples/manipulation.ipynb b/docs/examples/manipulation.ipynb index 6b7f160..cb4176a 100644 --- a/docs/examples/manipulation.ipynb +++ b/docs/examples/manipulation.ipynb @@ -7,7 +7,8 @@ "source": [ "# Figure Manipulation\n", "\n", - "What's easy, what's annoying, and how to work around it." + "After creating a figure with `xpx()`, you can manipulate it using standard Plotly methods.\n", + "This notebook shows what works out of the box, and where `update_traces` from xarray-plotly helps." ] }, { @@ -17,185 +18,192 @@ "metadata": {}, "outputs": [], "source": [ + "import numpy as np\n", "import plotly.express as px\n", - "import plotly.graph_objects as go\n", - "import plotly.io as pio\n", + "import xarray as xr\n", "\n", - "from xarray_plotly import overlay, update_traces\n", + "from xarray_plotly import config, update_traces, xpx\n", "\n", - "pio.renderers.default = \"notebook_connected\"\n", + "config.notebook()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "# 4D DataArray: scenario x metric x year x country\n", + "df_gap = px.data.gapminder()\n", + "countries = [\"United States\", \"China\", \"Germany\", \"Brazil\"]\n", + "metrics = [\"lifeExp\", \"gdpPercap\"]\n", "\n", - "# Sample data\n", - "df = px.data.gapminder()\n", - "df_2007 = df.query(\"year == 2007\")\n", - "df_countries = df.query(\"country in ['United States', 'China', 'Germany', 'Brazil']\")" + "# Build base 3D array (metric x year x country)\n", + "arrays = []\n", + "for metric in metrics:\n", + " df_pivot = df_gap[df_gap[\"country\"].isin(countries)].pivot(\n", + " index=\"year\", columns=\"country\", values=metric\n", + " )\n", + " arrays.append(df_pivot.values)\n", + "\n", + "base_3d = np.stack(arrays)\n", + "\n", + "# Add scenario dimension (4D): original + 10% higher\n", + "scenarios = [\"baseline\", \"optimistic\"]\n", + "data_4d = np.stack([base_3d, base_3d * 1.1])\n", + "\n", + "da = xr.DataArray(\n", + " data_4d,\n", + " dims=[\"scenario\", \"metric\", \"year\", \"country\"],\n", + " coords={\n", + " \"scenario\": scenarios,\n", + " \"metric\": metrics,\n", + " \"year\": df_pivot.index.tolist(),\n", + " \"country\": df_pivot.columns.tolist(),\n", + " },\n", + " name=\"value\",\n", + ")\n", + "da" ] }, { "cell_type": "markdown", - "id": "2", + "id": "3", "metadata": {}, "source": [ "---\n", - "# Easy: Single Plots\n", + "## Standard Plotly Methods\n", "\n", - "All standard manipulation methods work as expected." + "All standard Plotly manipulation methods work on figures created with `xpx()`." ] }, { "cell_type": "code", "execution_count": null, - "id": "3", + "id": "4", "metadata": {}, "outputs": [], "source": [ - "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", color=\"continent\", size=\"pop\")\n", + "# Simple 2D slice\n", + "fig = xpx(da.sel(scenario=\"baseline\", metric=\"lifeExp\")).line()\n", "fig" ] }, { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "5", "metadata": {}, "outputs": [], "source": [ "# Layout\n", - "fig.update_layout(title=\"GDP vs Life Expectancy\", template=\"plotly_white\")\n", + "fig.update_layout(title=\"Life Expectancy Over Time\", template=\"plotly_white\")\n", "\n", "# All traces\n", - "fig.update_traces(marker_opacity=0.7)\n", + "fig.update_traces(line_width=3)\n", "\n", - "# Specific traces\n", - "fig.update_traces(marker_line_width=2, selector={\"name\": \"Europe\"})\n", + "# Specific trace by name\n", + "fig.update_traces(line_dash=\"dot\", selector={\"name\": \"Germany\"})\n", "\n", "# Axes\n", - "fig.update_xaxes(type=\"log\", title=\"GDP per Capita\")\n", - "fig.update_yaxes(range=[40, 90])\n", + "fig.update_xaxes(title=\"Year\", showgrid=False)\n", + "fig.update_yaxes(title=\"Life Expectancy (years)\", range=[40, 85])\n", "\n", - "# Annotations and shapes\n", - "fig.add_hline(y=df_2007[\"lifeExp\"].mean(), line_dash=\"dash\", line_color=\"gray\")\n", + "# Reference line\n", + "fig.add_hline(y=70, line_dash=\"dash\", line_color=\"gray\", annotation_text=\"Target\")\n", "\n", "fig" ] }, { "cell_type": "markdown", - "id": "5", + "id": "6", "metadata": {}, "source": [ "---\n", - "# Easy: Faceted Plots\n", + "## Faceted Plots\n", "\n", - "`update_traces`, `update_xaxes`, `update_yaxes` all work across facets." + "`update_traces`, `update_xaxes`, `update_yaxes` work across all facets." ] }, { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "7", "metadata": {}, "outputs": [], "source": [ - "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", facet_col=\"country\")\n", + "# Facet by metric, color by country\n", + "fig = xpx(da.sel(scenario=\"baseline\")).line(facet_col=\"metric\")\n", "fig" ] }, { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "8", "metadata": {}, "outputs": [], "source": [ "# Update ALL traces across all facets\n", - "fig.update_traces(line_width=3)\n", + "fig.update_traces(line_width=2)\n", "\n", - "# Update ALL x-axes\n", + "# Update ALL axes\n", "fig.update_xaxes(showgrid=False)\n", "\n", - "# Update ALL y-axes\n", - "fig.update_yaxes(showgrid=False, type=\"log\")\n", - "\n", - "# Clean up facet labels\n", - "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "# Target specific facet (1-indexed)\n", + "fig.update_yaxes(type=\"log\", col=2) # log scale only for gdpPercap\n", "\n", "fig" ] }, { "cell_type": "markdown", - "id": "8", + "id": "9", "metadata": {}, "source": [ - "### Targeting specific facets\n", - "\n", - "Use `row=` and `col=` (1-indexed) to target specific facets." + "### Grid layout with facet_row" ] }, { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "10", "metadata": {}, "outputs": [], "source": [ - "fig = px.histogram(px.data.tips(), x=\"total_bill\", facet_row=\"sex\", facet_col=\"time\")\n", + "# 2x2 grid: scenario x metric\n", + "fig = xpx(da).line(facet_col=\"metric\", facet_row=\"scenario\")\n", "\n", - "# Target specific cell\n", - "fig.update_yaxes(title_text=\"Frequency\", row=1, col=1)\n", - "\n", - "# Target entire column\n", - "fig.update_xaxes(title_text=\"Bill ($)\", col=2)\n", - "\n", - "# Target entire row\n", - "fig.update_traces(marker_color=\"orange\", row=2)\n", - "\n", - "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "fig.update_traces(line_width=2)\n", + "fig.update_yaxes(type=\"log\", col=2) # log scale for gdpPercap column\n", "fig" ] }, { "cell_type": "markdown", - "id": "10", + "id": "11", "metadata": {}, "source": [ - "### Reference lines on facets\n", + "---\n", + "## Animation: The Pain Point\n", "\n", - "`add_hline`/`add_vline` apply to all facets by default. Use `row=`/`col=` to target." + "Plotly's `fig.update_traces()` does **not** update animation frames. This is the main gotcha." ] }, { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "12", "metadata": {}, "outputs": [], "source": [ - "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n", - "fig.update_xaxes(type=\"log\")\n", - "\n", - "# Applies to ALL facets\n", - "fig.add_hline(y=70, line_dash=\"dash\", line_color=\"red\")\n", - "\n", - "# Specific facet only\n", - "fig.add_hline(y=50, line_dash=\"dot\", line_color=\"blue\", row=2, col=1)\n", - "\n", - "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "# Animated bar chart\n", + "fig = xpx(da.sel(scenario=\"baseline\", metric=\"gdpPercap\")).bar(animation_frame=\"year\")\n", "fig" ] }, - { - "cell_type": "markdown", - "id": "12", - "metadata": {}, - "source": [ - "---\n", - "# Easy: Adding traces to faceted/animated figures\n", - "\n", - "Use `overlay` to add traces. It handles facets and animation frames automatically." - ] - }, { "cell_type": "code", "execution_count": null, @@ -203,31 +211,11 @@ "metadata": {}, "outputs": [], "source": [ - "# Animated scatter\n", - "fig = px.scatter(\n", - " df_countries,\n", - " x=\"gdpPercap\",\n", - " y=\"lifeExp\",\n", - " color=\"country\",\n", - " animation_frame=\"year\",\n", - " log_x=True,\n", - " range_y=[40, 85],\n", - ")\n", - "\n", - "# Create a figure with reference marker\n", - "ref = go.Figure(\n", - " go.Scatter(\n", - " x=[10000],\n", - " y=[75],\n", - " mode=\"markers\",\n", - " marker={\"size\": 20, \"symbol\": \"star\", \"color\": \"gold\"},\n", - " name=\"Target\",\n", - " )\n", - ")\n", + "# This only affects the INITIAL view!\n", + "fig.update_traces(marker_color=\"red\")\n", "\n", - "# Overlay - trace appears in all animation frames\n", - "combined = overlay(fig, ref)\n", - "combined" + "print(f\"Base trace color: {fig.data[0].marker.color}\")\n", + "print(f\"Frame 0 trace color: {fig.frames[0].data[0].marker.color}\")" ] }, { @@ -237,37 +225,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Faceted plot\n", - "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n", - "fig.update_xaxes(type=\"log\")\n", - "\n", - "# Add reference to first facet (default axes x, y)\n", - "ref1 = go.Figure(\n", - " go.Scatter(\n", - " x=[5000],\n", - " y=[70],\n", - " mode=\"markers\",\n", - " marker={\"size\": 15, \"symbol\": \"star\", \"color\": \"gold\"},\n", - " name=\"Target 1\",\n", - " )\n", - ")\n", - "\n", - "# Add reference to second facet (axes x2, y2)\n", - "ref2 = go.Figure(\n", - " go.Scatter(\n", - " x=[20000],\n", - " y=[80],\n", - " mode=\"markers\",\n", - " marker={\"size\": 15, \"symbol\": \"star\", \"color\": \"red\"},\n", - " name=\"Target 2\",\n", - " xaxis=\"x2\",\n", - " yaxis=\"y2\", # specify target facet\n", - " )\n", - ")\n", - "\n", - "combined = overlay(fig, ref1, ref2)\n", - "combined.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", - "combined" + "# Play the animation - it reverts to original colors\n", + "fig" ] }, { @@ -275,10 +234,9 @@ "id": "15", "metadata": {}, "source": [ - "---\n", - "# Annoying: Facet axis names\n", + "### Solution: `update_traces` from xarray-plotly\n", "\n", - "To target a specific facet with `add_shape`, `add_annotation`, or when adding traces via `overlay`, you need to know the axis name (`x2`, `y3`, etc.)." + "This helper updates both base traces and all animation frames." ] }, { @@ -288,407 +246,124 @@ "metadata": {}, "outputs": [], "source": [ - "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n", - "\n", - "# Inspect axis names\n", - "layout_dict = fig.layout.to_plotly_json()\n", - "print(\"X axes:\", sorted([k for k in layout_dict if k.startswith(\"xaxis\")]))\n", - "print(\"Y axes:\", sorted([k for k in layout_dict if k.startswith(\"yaxis\")]))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17", - "metadata": {}, - "outputs": [], - "source": [ - "# Check which trace uses which axis\n", - "for i, trace in enumerate(fig.data):\n", - " print(f\"Trace {i} ({trace.name}): xaxis={trace.xaxis or 'x'}, yaxis={trace.yaxis or 'y'}\")" - ] - }, - { - "cell_type": "markdown", - "id": "18", - "metadata": {}, - "source": [ - "**Tip:** For simple cases, use `add_hline`/`add_vline` with `row=`/`col=` instead of `add_shape` - it handles axis mapping internally." - ] - }, - { - "cell_type": "markdown", - "id": "19", - "metadata": {}, - "source": [ - "---\n", - "# Annoying: Animation trace updates\n", - "\n", - "**This is the main pain point.** `update_traces()` does NOT update animation frames." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20", - "metadata": {}, - "outputs": [], - "source": [ - "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"country\")\n", - "fig" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "21", - "metadata": {}, - "outputs": [], - "source": [ - "# This only affects the INITIAL view, not the animation frames!\n", - "fig.update_traces(line_width=5, line_dash=\"dot\")\n", - "\n", - "print(f\"Base trace line_width: {fig.data[0].line.width}\")\n", - "print(f\"Frame 0 trace line_width: {fig.frames[0].data[0].line.width}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "22", - "metadata": {}, - "outputs": [], - "source": [ - "# When you play the animation, it reverts to the frame's original style\n", - "fig" - ] - }, - { - "cell_type": "markdown", - "id": "23", - "metadata": {}, - "source": [ - "### Solution: `update_traces`\n", - "\n", - "xarray-plotly provides this helper to update both base traces and animation frames:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24", - "metadata": {}, - "outputs": [], - "source": [ - "# update_traces is imported from xarray_plotly\n", - "# It updates traces in both base figure and all animation frames" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25", - "metadata": {}, - "outputs": [], - "source": [ - "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"country\")\n", - "\n", - "update_traces(fig, line_width=4, line_dash=\"dot\")\n", + "fig = xpx(da.sel(scenario=\"baseline\", metric=\"gdpPercap\")).bar(animation_frame=\"year\")\n", "\n", - "print(f\"Base trace line_width: {fig.data[0].line.width}\")\n", - "print(f\"Frame 0 trace line_width: {fig.frames[0].data[0].line.width}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "26", - "metadata": {}, - "outputs": [], - "source": [ - "fig" - ] - }, - { - "cell_type": "markdown", - "id": "27", - "metadata": {}, - "source": [ - "### Selective updates with selector\n", + "update_traces(fig, marker_color=\"red\", marker_opacity=0.8)\n", "\n", - "Use `selector` to target specific traces by name:" + "print(f\"Base trace color: {fig.data[0].marker.color}\")\n", + "print(f\"Frame 0 trace color: {fig.frames[0].data[0].marker.color}\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "28", + "id": "17", "metadata": {}, "outputs": [], "source": [ - "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"year\")\n", - "\n", - "# Update only one trace by name\n", - "update_traces(fig, selector={\"name\": \"Germany\"}, line_width=5, line_dash=\"dot\")\n", - "\n", - "# Update multiple traces\n", - "update_traces(fig, selector={\"name\": \"China\"}, line_color=\"red\", line_width=3)\n", - "\n", + "# Now the style persists through animation\n", "fig" ] }, { "cell_type": "markdown", - "id": "29", + "id": "18", "metadata": {}, "source": [ - "### Works with facets + animation" + "### Selective updates with selector" ] }, { "cell_type": "code", "execution_count": null, - "id": "30", + "id": "19", "metadata": {}, "outputs": [], "source": [ - "df_subset = df.query(\n", - " \"continent in ['Europe', 'Asia'] and country in ['Germany', 'France', 'China', 'Japan']\"\n", - ")\n", + "fig = xpx(da.sel(scenario=\"baseline\", metric=\"lifeExp\")).line(x=\"year\")\n", "\n", - "fig = px.line(\n", - " df_subset,\n", - " x=\"year\",\n", - " y=\"gdpPercap\",\n", - " color=\"country\",\n", - " facet_col=\"continent\",\n", - " animation_frame=\"year\",\n", - ")\n", + "# Highlight specific countries\n", + "update_traces(fig, selector={\"name\": \"China\"}, line_color=\"red\", line_width=4)\n", + "update_traces(fig, selector={\"name\": \"United States\"}, line_color=\"blue\", line_width=4)\n", "\n", - "update_traces(fig, line_width=3)\n", - "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", "fig" ] }, { "cell_type": "markdown", - "id": "31", - "metadata": {}, - "source": [ - "### What's affected\n", - "\n", - "Anything on **traces** needs the helper for animations:\n", - "\n", - "| Property | Facets | Animation |\n", - "|----------|--------|-----------|\n", - "| `line_width` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", - "| `line_dash` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", - "| `line_color` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", - "| `marker_size` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", - "| `marker_symbol` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", - "| `opacity` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", - "\n", - "**Layout properties** (`update_layout`, `update_xaxes`, `update_yaxes`) work fine for animations." - ] - }, - { - "cell_type": "markdown", - "id": "32", - "metadata": {}, - "source": [ - "---\n", - "# Annoying: Animation speed\n", - "\n", - "The API to change animation speed is deeply nested." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33", + "id": "20", "metadata": {}, - "outputs": [], "source": [ - "fig = px.scatter(\n", - " df,\n", - " x=\"gdpPercap\",\n", - " y=\"lifeExp\",\n", - " color=\"continent\",\n", - " size=\"pop\",\n", - " animation_frame=\"year\",\n", - " log_x=True,\n", - " range_y=[25, 90],\n", - ")\n", + "### Unified hover with animation\n", "\n", - "# This is... not intuitive\n", - "fig.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = 100 # faster\n", - "fig.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = 50\n", + "A common pattern: unified hover mode with custom formatting.\n", "\n", - "fig" - ] - }, - { - "cell_type": "markdown", - "id": "34", - "metadata": {}, - "source": [ - "### Workaround: Helper function" + "- **Layout** (`hovermode`, spikes): Standard Plotly works fine\n", + "- **Traces** (`hovertemplate`): Use `update_traces()` for animation support" ] }, { "cell_type": "code", "execution_count": null, - "id": "35", - "metadata": {}, - "outputs": [], - "source": [ - "def set_animation_speed(fig, frame_duration=500, transition_duration=300):\n", - " \"\"\"Set animation speed in milliseconds.\"\"\"\n", - " if fig.layout.updatemenus:\n", - " fig.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = frame_duration\n", - " fig.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = transition_duration\n", - " return fig" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "36", + "id": "21", "metadata": {}, "outputs": [], "source": [ - "fig = px.scatter(\n", - " df,\n", - " x=\"gdpPercap\",\n", - " y=\"lifeExp\",\n", - " color=\"continent\",\n", - " animation_frame=\"year\",\n", - " log_x=True,\n", - " range_y=[25, 90],\n", - ")\n", + "fig = xpx(da.sel(metric=\"gdpPercap\")).line(x=\"year\", animation_frame=\"scenario\")\n", "\n", - "set_animation_speed(fig, frame_duration=200, transition_duration=100)\n", - "fig" - ] - }, - { - "cell_type": "markdown", - "id": "37", - "metadata": {}, - "source": [ - "---\n", - "# Annoying: Slider styling\n", + "# Layout settings - standard Plotly\n", + "fig.update_layout(hovermode=\"x unified\")\n", + "fig.update_xaxes(showspikes=True, spikecolor=\"gray\", spikethickness=1)\n", "\n", - "Verbose but straightforward." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "38", - "metadata": {}, - "outputs": [], - "source": [ - "fig = px.scatter(\n", - " df,\n", - " x=\"gdpPercap\",\n", - " y=\"lifeExp\",\n", - " color=\"continent\",\n", - " animation_frame=\"year\",\n", - " log_x=True,\n", - " range_y=[25, 90],\n", - ")\n", - "\n", - "fig.layout.sliders[0].currentvalue.prefix = \"Year: \"\n", - "fig.layout.sliders[0].currentvalue.font.size = 16\n", - "fig.layout.sliders[0].pad.t = 50 # padding from top\n", + "# Trace settings - use update_traces for animation support\n", + "update_traces(fig, hovertemplate=\"%{fullData.name}: $%{y:,.0f}\")\n", "\n", "fig" ] }, { "cell_type": "markdown", - "id": "39", + "id": "22", "metadata": {}, "source": [ - "### Hide slider or play button" + "### Facets + Animation" ] }, { "cell_type": "code", "execution_count": null, - "id": "40", + "id": "23", "metadata": {}, "outputs": [], "source": [ - "fig = px.scatter(\n", - " df,\n", - " x=\"gdpPercap\",\n", - " y=\"lifeExp\",\n", - " color=\"continent\",\n", - " animation_frame=\"year\",\n", - " log_x=True,\n", - " range_y=[25, 90],\n", - ")\n", + "# Facet by metric, animate by scenario\n", + "fig = xpx(da).line(facet_col=\"metric\", animation_frame=\"scenario\")\n", "\n", - "# Hide slider (keep play button)\n", - "fig.layout.sliders = []\n", + "# Standard Plotly for layout\n", + "fig.update_yaxes(type=\"log\", col=2)\n", "\n", - "# Or hide play button (keep slider):\n", - "# fig.layout.updatemenus = []\n", + "# update_traces for trace properties with animation\n", + "update_traces(fig, line_width=3)\n", + "update_traces(fig, selector={\"name\": \"China\"}, line_dash=\"dot\")\n", "\n", "fig" ] }, { "cell_type": "markdown", - "id": "41", + "id": "24", "metadata": {}, "source": [ "---\n", - "# Summary\n", - "\n", - "### Provided by xarray-plotly\n", + "## Summary\n", "\n", - "```python\n", - "from xarray_plotly import overlay, update_traces\n", - "```\n", - "\n", - "### Local helper for animation speed" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "42", - "metadata": {}, - "outputs": [], - "source": [ - "def set_animation_speed(fig, frame_duration=500, transition_duration=300):\n", - " \"\"\"Set animation speed in milliseconds.\"\"\"\n", - " if fig.layout.updatemenus:\n", - " fig.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = frame_duration\n", - " fig.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = transition_duration\n", - " return fig" - ] - }, - { - "cell_type": "markdown", - "id": "43", - "metadata": {}, - "source": [ - "### Quick reference\n", - "\n", - "| Task | Facets | Animation | Solution |\n", - "|------|--------|-----------|----------|\n", - "| Update trace style | `fig.update_traces()` | `update_traces()` | xarray-plotly helper |\n", - "| Update axes | `update_xaxes()`/`update_yaxes()` | Same | ✅ Works |\n", - "| Update layout | `update_layout()` | Same | ✅ Works |\n", - "| Add reference line | `add_hline(row=, col=)` | `add_hline()` | ✅ Works |\n", - "| Add trace | `overlay()` | `overlay()` | ✅ Works |\n", - "| Add shape to specific facet | `add_shape(xref=\"x2\")` | Same | Need axis name |\n", - "| Change animation speed | N/A | `set_animation_speed()` | Local helper |\n", - "| Facet labels | `for_each_annotation()` | Same | ✅ Works |" + "| Method | Static/Faceted | Animated |\n", + "|--------|----------------|----------|\n", + "| `fig.update_layout()` | ✅ | ✅ |\n", + "| `fig.update_xaxes()` / `fig.update_yaxes()` | ✅ | ✅ |\n", + "| `fig.add_hline()` / `fig.add_vline()` | ✅ | ✅ |\n", + "| `fig.update_traces()` | ✅ | ❌ base only |\n", + "| `update_traces(fig, ...)` | ✅ | ✅ all frames |" ] } ], @@ -700,7 +375,7 @@ }, "language_info": { "name": "python", - "version": "3.11.0" + "version": "3.12.0" } }, "nbformat": 4, diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index 2e1f658..939d10d 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -8,9 +8,28 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import Iterator + import plotly.graph_objects as go +def _iter_all_traces(fig: go.Figure) -> Iterator: + """Iterate over all traces in a figure, including animation frames. + + Yields traces from fig.data first, then from each frame in fig.frames. + Useful for applying styling to all traces including those in animations. + + Args: + fig: Plotly Figure. + + Yields: + Each trace object from the figure. + """ + yield from fig.data + for frame in fig.frames or []: + yield from frame.data + + def _get_subplot_axes(fig: go.Figure) -> set[tuple[str, str]]: """Extract (xaxis, yaxis) pairs from figure traces. @@ -418,15 +437,12 @@ def update_traces(fig: go.Figure, selector: dict | None = None, **kwargs) -> go. >>> # Update specific trace by name >>> update_traces(fig, selector={"name": "Germany"}, line_width=5, line_dash="dot") """ - fig.update_traces(selector=selector, **kwargs) - - for frame in fig.frames: - for trace in frame.data: - if selector is None: + for trace in _iter_all_traces(fig): + if selector is None: + trace.update(**kwargs) + else: + # Check if trace matches all selector criteria + if all(getattr(trace, k, None) == v for k, v in selector.items()): trace.update(**kwargs) - else: - # Check if trace matches all selector criteria - if all(getattr(trace, k, None) == v for k, v in selector.items()): - trace.update(**kwargs) return fig diff --git a/xarray_plotly/plotting.py b/xarray_plotly/plotting.py index 69acf6c..b1e5085 100644 --- a/xarray_plotly/plotting.py +++ b/xarray_plotly/plotting.py @@ -19,6 +19,9 @@ get_value_col, to_dataframe, ) +from xarray_plotly.figures import ( + _iter_all_traces, +) if TYPE_CHECKING: import plotly.graph_objects as go @@ -192,9 +195,7 @@ def _style_traces_as_bars(fig: go.Figure) -> None: then assigns stackgroups: positive traces stack upward, negative stack downward. """ # Collect all traces (main + animation frames) - all_traces = list(fig.data) - for frame in fig.frames: - all_traces.extend(frame.data) + all_traces = list(_iter_all_traces(fig)) # Classify each trace name by aggregating sign info across all occurrences sign_flags: dict[str, dict[str, bool]] = {}