From be3d74f0cd3f1f3a70a09062b3b96178118cd7bd Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Wed, 1 Jul 2026 17:07:50 -0400 Subject: [PATCH 1/2] Carry over windowed_array implementation to new main branch with "model" concept --- src/parcels/_core/_windowed_array.py | 101 ++++++++++++++++++++++++++ src/parcels/_core/field.py | 2 +- src/parcels/_core/fieldset.py | 26 +++++++ src/parcels/_core/model.py | 38 ++++++++++ tests/test_windowed_array.py | 102 +++++++++++++++++++++++++++ 5 files changed, 268 insertions(+), 1 deletion(-) create mode 100644 src/parcels/_core/_windowed_array.py create mode 100644 tests/test_windowed_array.py diff --git a/src/parcels/_core/_windowed_array.py b/src/parcels/_core/_windowed_array.py new file mode 100644 index 000000000..9fbfd35b3 --- /dev/null +++ b/src/parcels/_core/_windowed_array.py @@ -0,0 +1,101 @@ +"""Transparent rolling time-window cache for lazy (dask-backed) field data. + +Assumptions / current limits: + * ``time`` is the leading dimension of the field (true for both the SGRID and + UGRID ingestion paths; the structured path transposes to ``(time, ...)``). + * Valid while the requested time indices stay within the resident window + (i.e. all particles share the clock). A sample that requests time indices + spanning more than the retained levels would force reloads. +""" + +from __future__ import annotations + +import numpy as np +import xarray as xr +from dask import is_dask_collection + +# xarray / uxarray ``isel`` keyword arguments that are NOT dimension indexers. +_NON_INDEXER_KWARGS = frozenset({"drop", "missing_dims", "ignore_grid"}) + + +class WindowedArray: + """Wrap a lazy DataArray so ``isel`` loads/caches/evicts time levels as NumPy.""" + + def __init__(self, data: xr.DataArray, time_dim: str = "time", max_levels: int | None = None): + if data.dims[0] != time_dim: + raise ValueError(f"WindowedArray expects {time_dim!r} as the leading dimension, got {data.dims}") + self._data = data + self._tdim = time_dim + self._cache: dict[int, np.ndarray] = {} # time index -> NumPy slab (remaining dims) + self._max = max_levels + # diagnostics + self.loads = 0 + self.bytes_read = 0 + self._slab_bytes = int(np.prod(data.isel({time_dim: 0}).shape)) * data.dtype.itemsize + + # -- transparency: forward everything we don't override ------------------- + def __getattr__(self, name): + # __getattr__ only fires for misses; reach _data without recursing. + return getattr(object.__getattribute__(self, "_data"), name) + + def __repr__(self): + return ( + f"WindowedArray(time_dim={self._tdim!r}, cached_levels={sorted(self._cache)}, " + f"loads={self.loads})\n{self._data!r}" + ) + + # -- window management ---------------------------------------------------- + def _read_level(self, lvl: int) -> np.ndarray: + """Bulk, sequential read of one time level into NumPy (the dask->NumPy step).""" + return np.asarray(self._data.isel({self._tdim: int(lvl)}).values) + + def _ensure(self, levels: np.ndarray) -> None: + for lvl in levels: + lvl = int(lvl) + if lvl not in self._cache: + self._cache[lvl] = self._read_level(lvl) + self.loads += 1 + self.bytes_read += self._slab_bytes + # retire stale levels (the clock only moves forward across the window) + lo = int(np.min(levels)) + for old in [k for k in self._cache if k < lo]: + del self._cache[old] + if self._max is not None and len(self._cache) > self._max: + for old in sorted(self._cache)[: len(self._cache) - self._max]: + del self._cache[old] + + # -- intercepted indexing ------------------------------------------------- + def isel(self, indexers: dict | None = None, **kwargs): + sel = dict(indexers) if indexers is not None else {} + sel.update({k: v for k, v in kwargs.items() if k not in _NON_INDEXER_KWARGS}) + + + # no time selection -> nothing to window; preserve control kwargs + if self._tdim not in sel: + return self._data.isel(indexers, **kwargs) + + t_ind = sel[self._tdim] + t_vals = np.asarray(t_ind.values if isinstance(t_ind, xr.DataArray) else t_ind) + levels = np.unique(t_vals) + self._ensure(levels) + + # stack the resident levels into one small NumPy block; remap to local indices + block = np.stack([self._cache[int(lvl)] for lvl in levels]) # (nlevels, *rest) + nda = xr.DataArray(block, dims=self._data.dims) # NumPy-backed, original dim order + local = np.searchsorted(levels, t_vals) + sel[self._tdim] = xr.DataArray(local, dims=getattr(t_ind, "dims", ())) + return nda.isel(sel) # plain vectorised gather in NumPy (no ignore_grid needed) + + +def maybe_windowed(data: xr.DataArray, max_levels: int | None = None): + """Wrap dask-backed, field data in a ``WindowedArray``; else pass through. + + NumPy-backed fields (already resident) and fields without a leading ``time`` + dimension are returned unchanged, so existing eager workflows are unaffected. + Already-wrapped data is returned unchanged. + """ + if isinstance(data, WindowedArray): + return data + if data.dims and data.dims[0] == "time" and is_dask_collection(data.data): + return WindowedArray(data, max_levels=max_levels) + return data diff --git a/src/parcels/_core/field.py b/src/parcels/_core/field.py index bff605af8..f0b8008fd 100644 --- a/src/parcels/_core/field.py +++ b/src/parcels/_core/field.py @@ -100,7 +100,7 @@ def __init__( @property def data(self): - return self.model.data[self.name] + return self.model.field_data(self.name) @property def grid(self): # TODO PR: Remove in favour of referencing model grid directly diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 43aefd555..262092a2a 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -144,6 +144,32 @@ def add_field(self, field: Field, name: str | None = None): self.fields[name] = field + def to_windowed_arrays(self, *, max_levels: int | None = None): + """Wrap dask-backed field data in rolling time-window caches. + + Opt-in optimization for forward-marching simulations where all particles + share a single clock. Delegates to each underlying model; dask-backed, + time-leading fields are served through a resident NumPy window (each time + level loaded once and evicted as the clock advances) instead of re-reading + chunks on every kernel step. NumPy-backed (eager) and non-time-leading + fields are left unchanged, and re-invoking is idempotent, so this is safe + to call more than once. + + Parameters + ---------- + max_levels : int, optional + Cap on the number of time levels kept resident per field. ``None`` + (default) retains every level the advancing clock still brackets. + + Returns + ------- + FieldSet + ``self``, to allow chaining. + """ + for model in self.models: + model.to_windowed_arrays(max_levels=max_levels) + return self + def add_constant_field(self, name: str, value, mesh: Mesh = "spherical"): """Wrapper function to add a Field that is constant in space, useful e.g. when using constant horizontal diffusivity diff --git a/src/parcels/_core/model.py b/src/parcels/_core/model.py index 2040ca14a..b687ced92 100644 --- a/src/parcels/_core/model.py +++ b/src/parcels/_core/model.py @@ -8,6 +8,7 @@ import xarray as xr import parcels._sgrid as sgrid +from parcels._core._windowed_array import maybe_windowed from parcels._core.basegrid import BaseGrid from parcels._core.field import Field, VectorField from parcels._core.utils.time import TimeInterval @@ -58,6 +59,43 @@ def assert_valid_model_data(self) -> None: raise e return + def field_data(self, name: str) -> Any: + """Return the array backing field ``name``. + + Normally this is the ``xr.DataArray`` held in the dataset. After + :meth:`to_windowed_arrays`, dask-backed fields are served through a + cached :class:`~parcels._core._windowed_array.WindowedArray` instead. + """ + windowed = self.__dict__.get("_windowed") + if windowed is not None and name in windowed: + return windowed[name] + return self.data[name] + + def to_windowed_arrays(self, *, max_levels: int | None = None) -> Self: + """Wrap dask-backed field data in rolling time-window caches. + + Opt-in optimization for forward-marching simulations where all particles + share a single clock. For each dask-backed, time-leading field, ``isel`` + then samples a resident NumPy window (each time level loaded once and + evicted as the clock advances) instead of re-reading chunks and paying the + dask scheduling overhead on every kernel step. NumPy-backed (eager) fields + and non-time-leading fields are left unchanged. + + Idempotent: re-invoking reuses the existing wrapper (and its warm cache) + rather than rebuilding it. + + Parameters + ---------- + max_levels : int, optional + Cap on the number of time levels kept resident per field. ``None`` + (default) retains every level the advancing clock still brackets. + """ + windowed = self.__dict__.setdefault("_windowed", {}) + for name in self.scalar_field_names: + current = windowed.get(name, self.data[name]) + windowed[name] = maybe_windowed(current, max_levels=max_levels) + return self + @property def time_interval(self) -> TimeInterval | None: try: diff --git a/tests/test_windowed_array.py b/tests/test_windowed_array.py new file mode 100644 index 000000000..a0d755c44 --- /dev/null +++ b/tests/test_windowed_array.py @@ -0,0 +1,102 @@ +"""Tests for the transparent rolling time-window cache (WindowedArray).""" + +import dask.array as da +import numpy as np +import pytest +import xarray as xr + +from parcels import FieldSet, ParticleSet +from parcels._core._windowed_array import WindowedArray, maybe_windowed +from parcels._datasets.structured.generated import simple_UV_dataset +from parcels.kernels import AdvectionRK2 + + +def test_windowed_isel_matches_dask_loads_once_and_evicts(): + """WindowedArray.isel must equal dask isel, load each level once, keep <=2 resident.""" + ntime, n, npart = 20, 64, 200 + rng = np.random.default_rng(0) + base = rng.standard_normal((ntime, 3, n, n)) + lazy = xr.DataArray(da.from_array(base, chunks=(1, 3, n, n)), dims=("time", "depth", "lat", "lon")) + win = WindowedArray(lazy) + + worst, max_cache = 0.0, 0 + for step in range(40): + ti = min(step // 2, ntime - 2) # advancing clock, 2 sub-steps per level + yi, xi = rng.integers(0, n, npart), rng.integers(0, n, npart) + zi = np.zeros(npart, dtype=int) + sel = dict( + time=xr.DataArray(np.r_[np.full(npart, ti), np.full(npart, ti + 1)], dims="p"), + depth=xr.DataArray(np.r_[zi, zi], dims="p"), + lat=xr.DataArray(np.r_[yi, yi], dims="p"), + lon=xr.DataArray(np.r_[xi, xi], dims="p"), + ) + got = win.isel(sel).data + ref = lazy.isel(sel).data.compute() + worst = max(worst, float(np.abs(got - ref).max())) + max_cache = max(max_cache, len(win._cache)) + + assert worst == 0.0 # byte-identical to dask + assert win.loads == ntime # each time level read exactly once + assert max_cache <= 2 # only the bracketing levels resident + + +def test_to_windowed_arrays_wraps_dask_but_not_numpy(): + ds = simple_UV_dataset(mesh="flat") + fs_np = FieldSet.from_sgrid_conventions(ds, mesh="flat") + fs_dk = FieldSet.from_sgrid_conventions(ds.chunk({"time": 1}), mesh="flat") + + # construction is never windowing -- it is opt-in via the fieldset method + assert not isinstance(fs_np.U.data, WindowedArray) + assert not isinstance(fs_dk.U.data, WindowedArray) + + assert fs_np.to_windowed_arrays() is fs_np # chainable + fs_dk.to_windowed_arrays() + + # numpy-backed field is left eager; dask-backed field gets wrapped + assert not isinstance(fs_np.U.data, WindowedArray) + assert isinstance(fs_dk.U.data, WindowedArray) + # transparency: forwarded attributes still behave like the DataArray + assert fs_dk.U.data.dims == fs_np.U.data.dims + assert fs_dk.U.data.shape == fs_np.U.data.shape + + +def test_to_windowed_arrays_is_idempotent_and_forwards_max_levels(): + ds = simple_UV_dataset(mesh="flat") + fs = FieldSet.from_sgrid_conventions(ds.chunk({"time": 1}), mesh="flat") + + fs.to_windowed_arrays(max_levels=3) + first = fs.U.data + assert isinstance(first, WindowedArray) + assert first._max == 3 + + # re-wrapping returns the same object (idempotent, warm cache preserved) + fs.to_windowed_arrays(max_levels=3) + assert fs.U.data is first + + +def test_maybe_windowed_passthrough_for_non_time_leading(): + da_no_time = xr.DataArray(da.zeros((3, 4), chunks=(3, 4)), dims=("lat", "lon")) + assert maybe_windowed(da_no_time) is da_no_time # not wrapped (no leading time dim) + + +@pytest.mark.parametrize("mesh", ["flat", "spherical"]) +def test_dask_advection_matches_numpy(mesh): + """An identical advection must give identical trajectories whether the field + is numpy-backed or dask-backed (windowed). + """ + ds = simple_UV_dataset(mesh=mesh) + ds["U"].data[:] = 1.0 # steady zonal flow -> in-bounds, deterministic + + def run(chunked): + d = ds.chunk({"time": 1}) if chunked else ds + fs = FieldSet.from_sgrid_conventions(d, mesh=mesh) + if chunked: + fs.to_windowed_arrays() + pset = ParticleSet(fs, lon=np.zeros(10), lat=np.linspace(-10, 10, 10)) + pset.execute(AdvectionRK2, runtime=7200, dt=np.timedelta64(15, "m")) + return np.array(pset.lon), np.array(pset.lat) + + lon_np, lat_np = run(False) + lon_dk, lat_dk = run(True) + np.testing.assert_allclose(lon_dk, lon_np, atol=1e-9) + np.testing.assert_allclose(lat_dk, lat_np, atol=1e-9) From c4052b974cfba2e9e43ccd4ceebd435759d641cb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Jul 2026 14:09:47 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/parcels/_core/_windowed_array.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/parcels/_core/_windowed_array.py b/src/parcels/_core/_windowed_array.py index 9fbfd35b3..df325e31b 100644 --- a/src/parcels/_core/_windowed_array.py +++ b/src/parcels/_core/_windowed_array.py @@ -69,7 +69,6 @@ def isel(self, indexers: dict | None = None, **kwargs): sel = dict(indexers) if indexers is not None else {} sel.update({k: v for k, v in kwargs.items() if k not in _NON_INDEXER_KWARGS}) - # no time selection -> nothing to window; preserve control kwargs if self._tdim not in sel: return self._data.isel(indexers, **kwargs)