From 7b186f7e020b058420de288c60b5c7ba313d7a52 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 13:47:44 +0200 Subject: [PATCH 01/20] Move tests and rename --- tests/test_fieldset.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 8cfdd38d5..f08c2785b 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -208,6 +208,18 @@ def test_fieldset_from_sgrid_conventions(ds_name): assert len(fieldset.fields) > 0 +def test_fieldset_add_error_on_duplicate_fields(): + """Test that adding FieldSets with overlapping field names raises a ValueError.""" + ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + ds2 = ds1.copy() + + fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat") + fset2 = FieldSet.from_sgrid_conventions(ds2, mesh="flat") + + with pytest.raises(ValueError, match="field names in common.*'U'"): + fset1 + fset2 + + def test_fieldset_add(): """Test that two FieldSets can be combined with + (fset1 + fset2).""" ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "grid"]].rename({"U_A_grid": "U1"}) @@ -223,19 +235,7 @@ def test_fieldset_add(): assert "V2" in fset.fields -def test_fieldset_add_overlapping_fields(): - """Test that adding FieldSets with overlapping field names raises a ValueError.""" - ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "grid"]].rename({"U_A_grid": "U"}) - ds2 = datasets_structured["ds_2d_left"][["V_A_grid", "grid"]].rename({"V_A_grid": "U"}) - - fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat") - fset2 = FieldSet.from_sgrid_conventions(ds2, mesh="flat") - - with pytest.raises(ValueError, match="field names in common.*'U'"): - fset1 + fset2 - - -def test_fieldset_add_overlapping_context_values(): +def test_fieldset_add_error_on_duplicate_context_values(): """Test that adding FieldSets with overlapping context value names raises a ValueError.""" ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "grid"]].rename({"U_A_grid": "U1"}) ds2 = datasets_structured["ds_2d_left"][["V_A_grid", "grid"]].rename({"V_A_grid": "V2"}) From 9ec09f3c32e830bb3a044b9cba817cfb30b30696 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 13:56:15 +0200 Subject: [PATCH 02/20] Update field check --- tests/test_fieldset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index f08c2785b..d8b3bf937 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -231,8 +231,10 @@ def test_fieldset_add(): fset = fset1 + fset2 assert len(fset.models) == len(fset1.models) + len(fset2.models) - assert "U1" in fset.fields - assert "V2" in fset.fields + + fields_before = list(fset1.fields.keys()) + list(fset2.fields.keys()) + assert len(fields_before) == len(fset.fields) + assert set(fields_before) == set(fset.fields.keys()) def test_fieldset_add_error_on_duplicate_context_values(): From 9e3c37fca650015c5de5b62b8f8aa6ea50794a7d Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 14:00:14 +0200 Subject: [PATCH 03/20] Remove test stub No longer relevant --- tests/test_fieldset.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index d8b3bf937..bff1ca08f 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -96,9 +96,6 @@ def test_fieldset_from_structured_generic_datasets(ds): assert len(fieldset.gridset) == 1 -def test_fieldset_gridset_multiple_grids(): ... - - # TODO restructure: use adding of fieldset notation to test this @pytest.mark.skip("Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646") def test_fieldset_time_interval(): From 2fd533a2e4aa04522f5f1d497c2ee15977e8059c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 14:01:16 +0200 Subject: [PATCH 04/20] Add tests for custom vectorfields --- tests/test_fieldset.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index bff1ca08f..2bed7a385 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -96,6 +96,39 @@ def test_fieldset_from_structured_generic_datasets(ds): assert len(fieldset.gridset) == 1 +def test_fieldset_vectorfield_default(): + ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + + fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat") + + assert "U" in fset1.fields + assert "V" in fset1.fields + assert "UV" in fset1.fields + + +def test_fieldset_vectorfield_custom(): + ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + + fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat", vector_fields={"UV_wind": ("U_wind", "V_wind")}) + + assert "U_wind" in fset1.fields + assert "V_wind" in fset1.fields + assert "UV_wind" in fset1.fields + + +def test_fieldset_vectorfield_none(): + ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + + fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat", vector_fields=None) + + assert "U" in fset1.fields + assert "V" in fset1.fields + assert "UV" not in fset1.fields + + +def test_resolve_vector_field_components(): ... + + # TODO restructure: use adding of fieldset notation to test this @pytest.mark.skip("Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646") def test_fieldset_time_interval(): From 925e017ae35d706548c4e9443adffbbd5c068d41 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 14:46:57 +0200 Subject: [PATCH 05/20] Remove default mesh on ModelData from_{s,u}grid_conventions These methods arent public API and these defaults are set on the FieldSet class --- src/parcels/_core/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/parcels/_core/model.py b/src/parcels/_core/model.py index 2040ca14a..b8fdae29c 100644 --- a/src/parcels/_core/model.py +++ b/src/parcels/_core/model.py @@ -133,7 +133,7 @@ def construct_fields(self) -> list[Field | VectorField]: return list(fields.values()) @classmethod - def from_sgrid_conventions(cls, ds: xr.Dataset, mesh: Mesh | None = None) -> Self: + def from_sgrid_conventions(cls, ds: xr.Dataset, mesh: Mesh | None) -> Self: ds = ds.copy() if mesh is None: mesh = _get_mesh_type_from_sgrid_dataset(ds) @@ -239,7 +239,7 @@ def scalar_field_names(self) -> list[str]: return list(self.data.data_vars) @classmethod - def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"): + def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: Mesh): ds_dims = list(ds.dims) if not all(dim in ds_dims for dim in ["time", "zf", "zc"]): raise ValueError( From 8002c419dfdc0eb0c390fb42eb3dff834989e768 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 14:01:16 +0200 Subject: [PATCH 06/20] Add MISSING sentinel value --- src/parcels/_python.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/parcels/_python.py b/src/parcels/_python.py index 81db6ade4..552dcd614 100644 --- a/src/parcels/_python.py +++ b/src/parcels/_python.py @@ -1,4 +1,5 @@ # Generic Python helpers +import enum import inspect from collections.abc import Callable, Mapping from typing import TypeVar @@ -6,6 +7,9 @@ K = TypeVar("K") V = TypeVar("V") +_MissingType = enum.Enum("_MissingType", "VALUE") +_MISSING = _MissingType.VALUE + def isinstance_noimport(obj, class_or_tuple): """A version of isinstance that does not require importing the class. From 4cf7c28bca9d089511c3f94a5b16ab4ecc9bf101 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 14:01:16 +0200 Subject: [PATCH 07/20] Update API to take `vector_field Also add vector_field_components to the private API --- src/parcels/_core/fieldset.py | 25 ++++++++++++++++++++----- src/parcels/_core/model.py | 33 ++++++++++++++++++++++++++++----- tests/test_fieldset.py | 13 ++++++++++++- 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 43aefd555..be7481ca4 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -10,10 +10,17 @@ import xarray as xr from parcels._core.field import Field, VectorField -from parcels._core.model import CONSTANT_FIELD_MODELS, ModelData, StructuredModelData, UnstructuredModelData +from parcels._core.model import ( + CONSTANT_FIELD_MODELS, + ModelData, + StructuredModelData, + TVectorFieldMapping, + UnstructuredModelData, +) from parcels._core.utils.string import _assert_str_and_python_varname from parcels._core.utils.time import get_datetime_type_calendar from parcels._core.utils.time import is_compatible as datetime_is_compatible +from parcels._python import _MISSING, _MissingType from parcels._typing import Mesh from parcels.interpolators import ( XConstantField, @@ -201,7 +208,12 @@ def gridset(self) -> list[BaseGrid]: return grids @classmethod - def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"): + def from_ugrid_conventions( + cls, + ds: ux.UxDataset, + mesh: str = "spherical", + vector_fields: TVectorFieldMapping | None | _MissingType = _MISSING, + ): """Create a FieldSet from a Parcels compliant uxarray.UxDataset. This is the primary ingestion method in Parcels for structured grid datasets. @@ -225,12 +237,15 @@ def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"): ----- See https://ugrid-conventions.github.io/ugrid-conventions/ for more information on the UGRID conventions. """ - model = UnstructuredModelData.from_ugrid_conventions(ds, mesh) + model = UnstructuredModelData.from_ugrid_conventions(ds, mesh, vector_fields) return cls([model]) @classmethod def from_sgrid_conventions( - cls, ds: xr.Dataset, mesh: Mesh | None = None + cls, + ds: xr.Dataset, + mesh: Mesh | None = None, + vector_fields: TVectorFieldMapping | None | _MissingType = _MISSING, ): # TODO: Update mesh to be discovered from the dataset metadata """Create a FieldSet from a dataset using SGRID convention metadata. @@ -259,7 +274,7 @@ def from_sgrid_conventions( See https://sgrid.github.io/sgrid/ for more information on the SGRID conventions. """ - model = StructuredModelData.from_sgrid_conventions(ds, mesh) + model = StructuredModelData.from_sgrid_conventions(ds, mesh, vector_fields) return cls([model]) diff --git a/src/parcels/_core/model.py b/src/parcels/_core/model.py index b8fdae29c..aaec5873b 100644 --- a/src/parcels/_core/model.py +++ b/src/parcels/_core/model.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence from typing import Any, Self import cf_xarray # noqa: F401 @@ -18,6 +19,7 @@ assert_all_field_dims_have_axis, # noqa: F401, leave import for now until decision is made # TODO v4: Make decision ) from parcels._logger import logger +from parcels._python import _MissingType from parcels._typing import Mesh from parcels.convert import _ds_rename_using_standard_names from parcels.interpolators import ( @@ -32,11 +34,14 @@ ) from parcels.interpolators._base import ScalarInterpolator, VectorInterpolator +TVectorFieldMapping = Mapping[str, tuple[str, str] | tuple[str, str, str]] + class ModelData(ABC): data: Any grid: BaseGrid field_to_interpolator: dict[str, ScalarInterpolator | VectorInterpolator] + vector_field_components: TVectorFieldMapping @abstractmethod def construct_fields(self) -> list[Field | VectorField]: ... @@ -79,7 +84,7 @@ def preprocess_sgrid_model_data(ds: xr.Dataset) -> xr.Dataset: class StructuredModelData(ModelData): - def __init__(self, data: xr.Dataset, mesh: Mesh): + def __init__(self, data: xr.Dataset, mesh: Mesh, vector_field_components: TVectorFieldMapping): if not isinstance(data, xr.Dataset): raise ValueError(f"Expected `data` to be an xarray.Dataset . Got {type(data)}") @@ -88,6 +93,7 @@ def __init__(self, data: xr.Dataset, mesh: Mesh): self.data = data self.grid = grid + self.vector_field_components = vector_field_components self.field_to_interpolator = {} self._fields: list[Field | VectorField] | None = None self.assert_valid_model_data() @@ -133,7 +139,9 @@ def construct_fields(self) -> list[Field | VectorField]: return list(fields.values()) @classmethod - def from_sgrid_conventions(cls, ds: xr.Dataset, mesh: Mesh | None) -> Self: + def from_sgrid_conventions( + cls, ds: xr.Dataset, mesh: Mesh | None, vector_fields: TVectorFieldMapping | None | _MissingType + ) -> Self: ds = ds.copy() if mesh is None: mesh = _get_mesh_type_from_sgrid_dataset(ds) @@ -160,7 +168,7 @@ def from_sgrid_conventions(cls, ds: xr.Dataset, mesh: Mesh | None) -> Self: # ds["lon"] = ds[node_dimensions[0]] # ds["lat"] = ds[node_dimensions[1]] - model = cls(ds, mesh=mesh) + model = cls(ds, mesh=mesh, vector_field_components=vector_fields) model._fields = model.construct_fields() for f in model._fields: if isinstance(f, Field): @@ -191,13 +199,14 @@ def from_sgrid_conventions(cls, ds: xr.Dataset, mesh: Mesh | None) -> Self: ), ), mesh=mesh, # type:ignore + vector_fields=None, ) for mesh in ["flat", "spherical"] } class UnstructuredModelData(ModelData): - def __init__(self, data: ux.UxDataset, grid: UxGrid): + def __init__(self, data: ux.UxDataset, grid: UxGrid, vector_field_components: TVectorFieldMapping): if not isinstance(data, ux.UxDataset): raise ValueError(f"Expected `data` to be an uxarray.UxDataset . Got {type(data)}") @@ -206,6 +215,7 @@ def __init__(self, data: ux.UxDataset, grid: UxGrid): self.data = data self.grid = grid + self.vector_field_components = vector_field_components self.field_to_interpolator = {} self._fields: list[Field | VectorField] | None = None @@ -239,7 +249,9 @@ def scalar_field_names(self) -> list[str]: return list(self.data.data_vars) @classmethod - def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: Mesh): + def from_ugrid_conventions( + cls, ds: ux.UxDataset, mesh: Mesh, vector_fields: TVectorFieldMapping | None | _MissingType + ): ds_dims = list(ds.dims) if not all(dim in ds_dims for dim in ["time", "zf", "zc"]): raise ValueError( @@ -276,6 +288,17 @@ def _get_mesh_type_from_sgrid_dataset(ds_sgrid: xr.Dataset) -> Mesh: return "spherical" if _is_coordinate_in_degrees(ds_sgrid[fpoint_x]) else "flat" +def _default_vector_field_components(data_vars: Sequence[str]) -> TVectorFieldMapping: + vars = set(data_vars) + ret = {} + + if {"U", "V"}.issubset(vars): + ret["UV"] = ("U", "V") + if {"U", "V", "W"}.issubset(vars): + ret["UVW"] = ("U", "V", "W") + return ret + + def _is_coordinate_in_degrees(da: xr.DataArray) -> bool: units = da.attrs.get("units") if units is None: diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 2bed7a385..9d11a42e6 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -8,6 +8,7 @@ from parcels import Field, ParticleFile, ParticleSet, XGrid, convert from parcels._core.fieldset import FieldSet, _datetime_to_msg +from parcels._core.model import _default_vector_field_components from parcels._datasets.structured.generic import datasets as datasets_structured from parcels._datasets.structured.generic import datasets_sgrid from parcels._datasets.unstructured.generic import datasets as datasets_unstructured @@ -126,7 +127,17 @@ def test_fieldset_vectorfield_none(): assert "UV" not in fset1.fields -def test_resolve_vector_field_components(): ... +@pytest.mark.parametrize( + "data_vars,expected", + [ + (["U", "V", "land_mask"], {"UV": ("U", "V")}), + (["U", "V", "W", "land_mask"], {"UV": ("U", "V"), "UVW": ("U", "V", "W")}), + (["field1", "field2", "field3"], {}), + ], +) +def test_default_vector_field_components(data_vars, expected): + got = _default_vector_field_components(data_vars) + assert got == expected # TODO restructure: use adding of fieldset notation to test this From cfaaba2cb6af380fc912d6ae2da6cf36f9e9e6dc Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 15:29:41 +0200 Subject: [PATCH 08/20] Remove duplicate function --- src/parcels/_core/fieldset.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index be7481ca4..5891ecee4 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -371,9 +371,3 @@ def _format_calendar_error_message(field: Field | VectorField, reference_datetim ], "W": ["upward_sea_water_velocity", "vertical_sea_water_velocity"], } - - -def _is_agrid(ds: xr.Dataset) -> bool: - # check if U and V are defined on the same dimensions - # if yes, interpret as A grid - return set(ds["U"].dims) == set(ds["V"].dims) From 284dcdbd2b8b26158e78ba4b32ec52f89cf760b2 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 15:29:41 +0200 Subject: [PATCH 09/20] Update StructuredModelData.construct_fields() --- src/parcels/_core/model.py | 56 ++++++++++++++++++++++++-------------- tests/test_fieldset.py | 1 + 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/src/parcels/_core/model.py b/src/parcels/_core/model.py index aaec5873b..1e747831b 100644 --- a/src/parcels/_core/model.py +++ b/src/parcels/_core/model.py @@ -19,7 +19,7 @@ assert_all_field_dims_have_axis, # noqa: F401, leave import for now until decision is made # TODO v4: Make decision ) from parcels._logger import logger -from parcels._python import _MissingType +from parcels._python import _MISSING, _MissingType from parcels._typing import Mesh from parcels.convert import _ds_rename_using_standard_names from parcels.interpolators import ( @@ -116,26 +116,19 @@ def construct_fields(self) -> list[Field | VectorField]: single_fields: dict[str, Field] = {} vector_fields: dict[str, VectorField] = {} scalar_field_names = self.scalar_field_names - if "U" in scalar_field_names and "V" in scalar_field_names: - interp_method = XLinear_Velocity() if _is_agrid(self.data) else CGrid_Velocity() - single_fields["U"] = Field("U", self) - single_fields["V"] = Field("V", self) - vector_fields["UV"] = VectorField("UV", single_fields["U"], single_fields["V"], interp_method=interp_method) - if "W" in scalar_field_names: - single_fields["W"] = Field("W", self) - vector_fields["UVW"] = VectorField( - "UVW", - single_fields["U"], - single_fields["V"], - single_fields["W"], - interp_method=interp_method, - ) + for varname in set(scalar_field_names): + single_fields[varname] = Field(str(varname), self) - fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields} - for varname in set(scalar_field_names) - set(fields.keys()): - fields[varname] = Field(str(varname), self) + for vfield_name, components in self.vector_field_components.items(): + interp_method = ( + XLinear_Velocity() if _is_agrid(self.data, u=components[0], v=components[1]) else CGrid_Velocity() + ) + component_fields = [single_fields[name] for name in components] + vector_fields[vfield_name] = VectorField(vfield_name, *component_fields, interp_method=interp_method) + + fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields} return list(fields.values()) @classmethod @@ -168,6 +161,9 @@ def from_sgrid_conventions( # ds["lon"] = ds[node_dimensions[0]] # ds["lat"] = ds[node_dimensions[1]] + vector_fields = resolve_vector_fields(ds, vector_fields) + assert_vector_field_components_in_dataset(ds, vector_fields) + model = cls(ds, mesh=mesh, vector_field_components=vector_fields) model._fields = model.construct_fields() for f in model._fields: @@ -176,6 +172,26 @@ def from_sgrid_conventions( return model +def resolve_vector_fields( + ds: xr.Dataset, vector_fields: TVectorFieldMapping | None | _MissingType +) -> TVectorFieldMapping: + if vector_fields is None: + return {} + if vector_fields is _MISSING: # i.e., the default vectorfield discovery behaviour + return _default_vector_field_components(ds.data_vars) + return vector_fields + + +def assert_vector_field_components_in_dataset(ds: xr.Dataset, vector_fields: TVectorFieldMapping) -> None: + for components in vector_fields.values(): + for c in components: + if c not in ds.data_vars: + raise ValueError( + f"Field component '{c}' not present in the source dataset, but is listed in {vector_fields=!r}. This component cannot be used in this mapping." + ) + return + + CONSTANT_FIELD_MODELS = { mesh: StructuredModelData.from_sgrid_conventions( xr.Dataset( @@ -389,10 +405,10 @@ def _select_uxinterpolator(da: ux.UxDataArray): return None -def _is_agrid(ds: xr.Dataset) -> bool: +def _is_agrid(ds: xr.Dataset, u: str, v: str) -> bool: # check if U and V are defined on the same dimensions # if yes, interpret as A grid - return set(ds["U"].dims) == set(ds["V"].dims) + return set(ds[u].dims) == set(ds[v].dims) def _get_time_interval(data: xr.DataArray | ux.UxDataArray) -> TimeInterval | None: diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 9d11a42e6..b7a6efe54 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -109,6 +109,7 @@ def test_fieldset_vectorfield_default(): def test_fieldset_vectorfield_custom(): ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + ds1 = ds1.rename({"U": "U_wind", "V": "V_wind"}) fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat", vector_fields={"UV_wind": ("U_wind", "V_wind")}) From be87edec152fb087ed4778aa1a77f88865824904 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 16:00:57 +0200 Subject: [PATCH 10/20] Update naming --- tests/test_fieldset.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index b7a6efe54..79db2ccdd 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -97,35 +97,35 @@ def test_fieldset_from_structured_generic_datasets(ds): assert len(fieldset.gridset) == 1 -def test_fieldset_vectorfield_default(): - ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) +def test_fieldset_structured_vectorfield_default(): + ds = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) - fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat") + fset = FieldSet.from_sgrid_conventions(ds, mesh="flat") - assert "U" in fset1.fields - assert "V" in fset1.fields - assert "UV" in fset1.fields + assert "U" in fset.fields + assert "V" in fset.fields + assert "UV" in fset.fields -def test_fieldset_vectorfield_custom(): - ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) - ds1 = ds1.rename({"U": "U_wind", "V": "V_wind"}) +def test_fieldset_structured_vectorfield_custom(): + ds = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + ds = ds.rename({"U": "U_wind", "V": "V_wind"}) - fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat", vector_fields={"UV_wind": ("U_wind", "V_wind")}) + fset = FieldSet.from_sgrid_conventions(ds, mesh="flat", vector_fields={"UV_wind": ("U_wind", "V_wind")}) - assert "U_wind" in fset1.fields - assert "V_wind" in fset1.fields - assert "UV_wind" in fset1.fields + assert "U_wind" in fset.fields + assert "V_wind" in fset.fields + assert "UV_wind" in fset.fields -def test_fieldset_vectorfield_none(): - ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) +def test_fieldset_structured_vectorfield_none(): + ds = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) - fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat", vector_fields=None) + fset = FieldSet.from_sgrid_conventions(ds, mesh="flat", vector_fields=None) - assert "U" in fset1.fields - assert "V" in fset1.fields - assert "UV" not in fset1.fields + assert "U" in fset.fields + assert "V" in fset.fields + assert "UV" not in fset.fields @pytest.mark.parametrize( From 80451c81c87d3fb4abc7922a92fe1ce652269a9b Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 16:00:57 +0200 Subject: [PATCH 11/20] Update unstructured code to work with custom vectorfields --- src/parcels/_core/model.py | 28 ++++++++++++++-------------- tests/test_fieldset.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/src/parcels/_core/model.py b/src/parcels/_core/model.py index 1e747831b..2fef8189f 100644 --- a/src/parcels/_core/model.py +++ b/src/parcels/_core/model.py @@ -239,21 +239,17 @@ def construct_fields(self) -> list[Field | VectorField]: single_fields: dict[str, Field] = {} vector_fields: dict[str, VectorField] = {} scalar_field_names = self.scalar_field_names - if "U" in scalar_field_names and "V" in scalar_field_names: - single_fields["U"] = Field("U", self) - single_fields["V"] = Field("V", self) - vector_fields["UV"] = VectorField("UV", single_fields["U"], single_fields["V"], interp_method=Ux_Velocity()) - - if "W" in scalar_field_names: - single_fields["W"] = Field("W", self) - vector_fields["UVW"] = VectorField( - "UVW", single_fields["U"], single_fields["V"], single_fields["W"], interp_method=Ux_Velocity() - ) - fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields} - for varname in set(scalar_field_names) - set(single_fields.keys()): - fields[varname] = Field(str(varname), self) + for varname in set(scalar_field_names): + single_fields[varname] = Field(str(varname), self) + + for vfield_name, components in self.vector_field_components.items(): + interp_method = Ux_Velocity() + + component_fields = [single_fields[name] for name in components] + vector_fields[vfield_name] = VectorField(vfield_name, *component_fields, interp_method=interp_method) + fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields} return list(fields.values()) def assert_valid_field_data(self, field_data: ux.UxDataArray) -> None: @@ -276,7 +272,11 @@ def from_ugrid_conventions( grid = UxGrid(ds.uxgrid, z=ds.coords["zf"], mesh=mesh) ds = _discover_ux_U_and_V(ds) - model = cls(ds, grid) + + vector_fields = resolve_vector_fields(ds, vector_fields) + assert_vector_field_components_in_dataset(ds, vector_fields) + + model = cls(ds, grid, vector_fields) model._fields = model.construct_fields() for f in model._fields: if isinstance(f, Field): diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 79db2ccdd..60099640d 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -128,6 +128,36 @@ def test_fieldset_structured_vectorfield_none(): assert "UV" not in fset.fields +def test_fieldset_unstructured_vectorfield_default(): + ds = datasets_unstructured["stommel_gyre_delaunay"] + fset = FieldSet.from_ugrid_conventions(ds, mesh="spherical") + + assert "U" in fset.fields + assert "V" in fset.fields + assert "UV" in fset.fields + + +def test_fieldset_unstructured_vectorfield_custom(): + ds = datasets_unstructured["stommel_gyre_delaunay"] + ds = ds.rename({"U": "U_wind", "V": "V_wind"}) + + fset = FieldSet.from_ugrid_conventions(ds, mesh="spherical", vector_fields={"UV_wind": ("U_wind", "V_wind")}) + + assert "U_wind" in fset.fields + assert "V_wind" in fset.fields + assert "UV_wind" in fset.fields + + +def test_fieldset_unstructured_vectorfield_none(): + ds = datasets_unstructured["stommel_gyre_delaunay"] + + fset = FieldSet.from_ugrid_conventions(ds, mesh="spherical", vector_fields=None) + + assert "U" in fset.fields + assert "V" in fset.fields + assert "UV" not in fset.fields + + @pytest.mark.parametrize( "data_vars,expected", [ From 872fe66aa1f1c9996f1b646bed2769849e98476c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 16:27:47 +0200 Subject: [PATCH 12/20] Update docstring --- src/parcels/_core/fieldset.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 5891ecee4..565d71ab5 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -227,6 +227,11 @@ def from_ugrid_conventions( ---------- ds : uxarray.UxDataset uxarray.UxDataset as obtained from the uxarray package but with appropriate named vertical dimensions + vector_fields : Mapping[str, tuple[str, ...]] or None, optional + Mapping of vector field names to tuples of component variable names in the dataset. + For example, ``{"UV": ("U", "V"), "UVW": ("U", "V", "W")}``. + If ``None``, no vector fields are constructed. If omitted (default), vector fields + are auto-discovered from standard variable names (``U``/``V``/``W``). Returns ------- @@ -260,6 +265,11 @@ def from_sgrid_conventions( mesh : str String indicating the type of mesh coordinates used during velocity interpolation. Options are "spherical" or "flat". + vector_fields : Mapping[str, tuple[str, ...]] or None, optional + Mapping of vector field names to tuples of component variable names in the dataset. + For example, ``{"UV": ("U", "V"), "UVW": ("U", "V", "W")}``. + If ``None``, no vector fields are constructed. If omitted (default), vector fields + are auto-discovered from standard variable names (``U``/``V``/``W``). Returns ------- From a40ab20bb4ca3e7aa97422eaac818b4649300690 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 16:46:17 +0200 Subject: [PATCH 13/20] Rename sentinel value --- src/parcels/_core/fieldset.py | 6 +++--- src/parcels/_core/model.py | 10 +++++----- src/parcels/_python.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 565d71ab5..390765ae2 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -20,7 +20,7 @@ from parcels._core.utils.string import _assert_str_and_python_varname from parcels._core.utils.time import get_datetime_type_calendar from parcels._core.utils.time import is_compatible as datetime_is_compatible -from parcels._python import _MISSING, _MissingType +from parcels._python import NOTSET, NotSetType from parcels._typing import Mesh from parcels.interpolators import ( XConstantField, @@ -212,7 +212,7 @@ def from_ugrid_conventions( cls, ds: ux.UxDataset, mesh: str = "spherical", - vector_fields: TVectorFieldMapping | None | _MissingType = _MISSING, + vector_fields: TVectorFieldMapping | None | NotSetType = NOTSET, ): """Create a FieldSet from a Parcels compliant uxarray.UxDataset. @@ -250,7 +250,7 @@ def from_sgrid_conventions( cls, ds: xr.Dataset, mesh: Mesh | None = None, - vector_fields: TVectorFieldMapping | None | _MissingType = _MISSING, + vector_fields: TVectorFieldMapping | None | NotSetType = NOTSET, ): # TODO: Update mesh to be discovered from the dataset metadata """Create a FieldSet from a dataset using SGRID convention metadata. diff --git a/src/parcels/_core/model.py b/src/parcels/_core/model.py index 2fef8189f..ac96cf5f9 100644 --- a/src/parcels/_core/model.py +++ b/src/parcels/_core/model.py @@ -19,7 +19,7 @@ assert_all_field_dims_have_axis, # noqa: F401, leave import for now until decision is made # TODO v4: Make decision ) from parcels._logger import logger -from parcels._python import _MISSING, _MissingType +from parcels._python import NOTSET, NotSetType from parcels._typing import Mesh from parcels.convert import _ds_rename_using_standard_names from parcels.interpolators import ( @@ -133,7 +133,7 @@ def construct_fields(self) -> list[Field | VectorField]: @classmethod def from_sgrid_conventions( - cls, ds: xr.Dataset, mesh: Mesh | None, vector_fields: TVectorFieldMapping | None | _MissingType + cls, ds: xr.Dataset, mesh: Mesh | None, vector_fields: TVectorFieldMapping | None | NotSetType ) -> Self: ds = ds.copy() if mesh is None: @@ -173,11 +173,11 @@ def from_sgrid_conventions( def resolve_vector_fields( - ds: xr.Dataset, vector_fields: TVectorFieldMapping | None | _MissingType + ds: xr.Dataset, vector_fields: TVectorFieldMapping | None | NotSetType ) -> TVectorFieldMapping: if vector_fields is None: return {} - if vector_fields is _MISSING: # i.e., the default vectorfield discovery behaviour + if vector_fields is NOTSET: # i.e., the default vectorfield discovery behaviour return _default_vector_field_components(ds.data_vars) return vector_fields @@ -262,7 +262,7 @@ def scalar_field_names(self) -> list[str]: @classmethod def from_ugrid_conventions( - cls, ds: ux.UxDataset, mesh: Mesh, vector_fields: TVectorFieldMapping | None | _MissingType + cls, ds: ux.UxDataset, mesh: Mesh, vector_fields: TVectorFieldMapping | None | NotSetType ): ds_dims = list(ds.dims) if not all(dim in ds_dims for dim in ["time", "zf", "zc"]): diff --git a/src/parcels/_python.py b/src/parcels/_python.py index 552dcd614..4f8bf106e 100644 --- a/src/parcels/_python.py +++ b/src/parcels/_python.py @@ -7,8 +7,8 @@ K = TypeVar("K") V = TypeVar("V") -_MissingType = enum.Enum("_MissingType", "VALUE") -_MISSING = _MissingType.VALUE +NotSetType = enum.Enum("NotSetType", "VALUE") +NOTSET = NotSetType.VALUE def isinstance_noimport(obj, class_or_tuple): From 8043eeca0bffeb2bfaa72f28925402f5f59e10e8 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 16:47:36 +0200 Subject: [PATCH 14/20] Fix tests --- tests/test_field.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_field.py b/tests/test_field.py index f3893fbcc..5ae9ed794 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -12,6 +12,7 @@ from parcels._datasets.structured.generic import datasets as datasets_structured from parcels._datasets.unstructured.generic import _ux_constant_flow_face_centered_2D from parcels._datasets.unstructured.generic import datasets as datasets_unstructured +from parcels._python import NOTSET from parcels.interpolators import ( UxConstantFaceConstantZC, ) @@ -19,7 +20,7 @@ def test_field_init_param_types(): data = datasets_structured["ds_2d_left"] - model = StructuredModelData.from_sgrid_conventions(data, mesh="flat") + model = StructuredModelData.from_sgrid_conventions(data, mesh="flat", vector_fields=NOTSET) with pytest.raises(TypeError, match="Expected a string for variable name, got int instead."): Field(name=123, model=model) @@ -52,7 +53,7 @@ def test_field_init_fail_on_float_time_dim(): ds["time"].attrs, ) - model = StructuredModelData.from_sgrid_conventions(ds, mesh="flat") + model = StructuredModelData.from_sgrid_conventions(ds, mesh="flat", vector_fields=NOTSET) with pytest.raises( ValueError, match=r"Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects\?", @@ -64,7 +65,7 @@ def test_field_init_fail_on_float_time_dim(): def test_field_time_interval(): """Test that field.time_interval delegates correctly to model.time_interval.""" data = datasets_structured["ds_2d_left"] - model = StructuredModelData.from_sgrid_conventions(data, mesh="flat") + model = StructuredModelData.from_sgrid_conventions(data, mesh="flat", vector_fields=NOTSET) field = Field(name="data_g", model=model) assert field.time_interval.left == np.datetime64("2000-01-01") assert field.time_interval.right == np.datetime64("2001-01-01") @@ -77,7 +78,7 @@ def test_vectorfield_init_different_time_intervals(): def test_field_invalid_interpolator(): ds = datasets_structured["ds_2d_left"] - model = StructuredModelData.from_sgrid_conventions(ds, mesh="flat") + model = StructuredModelData.from_sgrid_conventions(ds, mesh="flat", vector_fields=NOTSET) field = Field(name="data_g", model=model) def not_a_scalar_interpolator(particle_positions, grid_positions, field): @@ -90,7 +91,7 @@ def not_a_scalar_interpolator(particle_positions, grid_positions, field): def test_vectorfield_invalid_interpolator(): ds = datasets_structured["ds_2d_left"] - model = StructuredModelData.from_sgrid_conventions(ds, mesh="flat") + model = StructuredModelData.from_sgrid_conventions(ds, mesh="flat", vector_fields=NOTSET) fields = {f.name: f for f in model.construct_fields()} U = fields["U_A_grid"] V = fields["V_A_grid"] From 3dbd950c2a1a59e54366a675f5fe8bb44ea73bd8 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 1 Jul 2026 16:52:26 +0200 Subject: [PATCH 15/20] Fix mypy issues --- src/parcels/_core/fieldset.py | 6 +++--- src/parcels/_core/model.py | 30 +++++++++++++++--------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 390765ae2..badbfa328 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -14,7 +14,7 @@ CONSTANT_FIELD_MODELS, ModelData, StructuredModelData, - TVectorFieldMapping, + TVectorField, UnstructuredModelData, ) from parcels._core.utils.string import _assert_str_and_python_varname @@ -212,7 +212,7 @@ def from_ugrid_conventions( cls, ds: ux.UxDataset, mesh: str = "spherical", - vector_fields: TVectorFieldMapping | None | NotSetType = NOTSET, + vector_fields: TVectorField | None | NotSetType = NOTSET, ): """Create a FieldSet from a Parcels compliant uxarray.UxDataset. @@ -250,7 +250,7 @@ def from_sgrid_conventions( cls, ds: xr.Dataset, mesh: Mesh | None = None, - vector_fields: TVectorFieldMapping | None | NotSetType = NOTSET, + vector_fields: TVectorField | None | NotSetType = NOTSET, ): # TODO: Update mesh to be discovered from the dataset metadata """Create a FieldSet from a dataset using SGRID convention metadata. diff --git a/src/parcels/_core/model.py b/src/parcels/_core/model.py index ac96cf5f9..ab6697e84 100644 --- a/src/parcels/_core/model.py +++ b/src/parcels/_core/model.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence +from collections.abc import Hashable, Sequence from typing import Any, Self import cf_xarray # noqa: F401 @@ -34,14 +34,14 @@ ) from parcels.interpolators._base import ScalarInterpolator, VectorInterpolator -TVectorFieldMapping = Mapping[str, tuple[str, str] | tuple[str, str, str]] +TVectorField = dict[str, tuple[str, str] | tuple[str, str, str]] class ModelData(ABC): data: Any grid: BaseGrid field_to_interpolator: dict[str, ScalarInterpolator | VectorInterpolator] - vector_field_components: TVectorFieldMapping + vector_field_components: TVectorField @abstractmethod def construct_fields(self) -> list[Field | VectorField]: ... @@ -84,7 +84,7 @@ def preprocess_sgrid_model_data(ds: xr.Dataset) -> xr.Dataset: class StructuredModelData(ModelData): - def __init__(self, data: xr.Dataset, mesh: Mesh, vector_field_components: TVectorFieldMapping): + def __init__(self, data: xr.Dataset, mesh: Mesh, vector_field_components: TVectorField): if not isinstance(data, xr.Dataset): raise ValueError(f"Expected `data` to be an xarray.Dataset . Got {type(data)}") @@ -126,14 +126,14 @@ def construct_fields(self) -> list[Field | VectorField]: ) component_fields = [single_fields[name] for name in components] - vector_fields[vfield_name] = VectorField(vfield_name, *component_fields, interp_method=interp_method) + vector_fields[vfield_name] = VectorField(vfield_name, *component_fields, interp_method=interp_method) # type:ignore[misc,arg-type] fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields} return list(fields.values()) @classmethod def from_sgrid_conventions( - cls, ds: xr.Dataset, mesh: Mesh | None, vector_fields: TVectorFieldMapping | None | NotSetType + cls, ds: xr.Dataset, mesh: Mesh | None, vector_fields: TVectorField | None | NotSetType ) -> Self: ds = ds.copy() if mesh is None: @@ -173,16 +173,16 @@ def from_sgrid_conventions( def resolve_vector_fields( - ds: xr.Dataset, vector_fields: TVectorFieldMapping | None | NotSetType -) -> TVectorFieldMapping: + ds: xr.Dataset, vector_fields: TVectorField | None | NotSetType +) -> TVectorField: if vector_fields is None: return {} if vector_fields is NOTSET: # i.e., the default vectorfield discovery behaviour - return _default_vector_field_components(ds.data_vars) + return _default_vector_field_components(list(ds.data_vars)) return vector_fields -def assert_vector_field_components_in_dataset(ds: xr.Dataset, vector_fields: TVectorFieldMapping) -> None: +def assert_vector_field_components_in_dataset(ds: xr.Dataset, vector_fields: TVectorField) -> None: for components in vector_fields.values(): for c in components: if c not in ds.data_vars: @@ -222,7 +222,7 @@ def assert_vector_field_components_in_dataset(ds: xr.Dataset, vector_fields: TVe class UnstructuredModelData(ModelData): - def __init__(self, data: ux.UxDataset, grid: UxGrid, vector_field_components: TVectorFieldMapping): + def __init__(self, data: ux.UxDataset, grid: UxGrid, vector_field_components: TVectorField): if not isinstance(data, ux.UxDataset): raise ValueError(f"Expected `data` to be an uxarray.UxDataset . Got {type(data)}") @@ -247,7 +247,7 @@ def construct_fields(self) -> list[Field | VectorField]: interp_method = Ux_Velocity() component_fields = [single_fields[name] for name in components] - vector_fields[vfield_name] = VectorField(vfield_name, *component_fields, interp_method=interp_method) + vector_fields[vfield_name] = VectorField(vfield_name, *component_fields, interp_method=interp_method) # type:ignore[misc, arg-type] fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields} return list(fields.values()) @@ -262,7 +262,7 @@ def scalar_field_names(self) -> list[str]: @classmethod def from_ugrid_conventions( - cls, ds: ux.UxDataset, mesh: Mesh, vector_fields: TVectorFieldMapping | None | NotSetType + cls, ds: ux.UxDataset, mesh: Mesh, vector_fields: TVectorField | None | NotSetType ): ds_dims = list(ds.dims) if not all(dim in ds_dims for dim in ["time", "zf", "zc"]): @@ -304,9 +304,9 @@ def _get_mesh_type_from_sgrid_dataset(ds_sgrid: xr.Dataset) -> Mesh: return "spherical" if _is_coordinate_in_degrees(ds_sgrid[fpoint_x]) else "flat" -def _default_vector_field_components(data_vars: Sequence[str]) -> TVectorFieldMapping: +def _default_vector_field_components(data_vars: Sequence[Hashable]) -> TVectorField: vars = set(data_vars) - ret = {} + ret: TVectorField = {} if {"U", "V"}.issubset(vars): ret["UV"] = ("U", "V") From a6de0cadffd10684c32307f5c6148b77eb58f95f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Jul 2026 14:57:56 +0000 Subject: [PATCH 16/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/parcels/_core/model.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/parcels/_core/model.py b/src/parcels/_core/model.py index ab6697e84..94b333a74 100644 --- a/src/parcels/_core/model.py +++ b/src/parcels/_core/model.py @@ -172,9 +172,7 @@ def from_sgrid_conventions( return model -def resolve_vector_fields( - ds: xr.Dataset, vector_fields: TVectorField | None | NotSetType -) -> TVectorField: +def resolve_vector_fields(ds: xr.Dataset, vector_fields: TVectorField | None | NotSetType) -> TVectorField: if vector_fields is None: return {} if vector_fields is NOTSET: # i.e., the default vectorfield discovery behaviour @@ -261,9 +259,7 @@ def scalar_field_names(self) -> list[str]: return list(self.data.data_vars) @classmethod - def from_ugrid_conventions( - cls, ds: ux.UxDataset, mesh: Mesh, vector_fields: TVectorField | None | NotSetType - ): + def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: Mesh, vector_fields: TVectorField | None | NotSetType): ds_dims = list(ds.dims) if not all(dim in ds_dims for dim in ["time", "zf", "zc"]): raise ValueError( From 3fba0580417c7d615db2805caf1424281f05c47b Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 2 Jul 2026 10:14:31 +0200 Subject: [PATCH 17/20] Move TVectorField to _typing module --- src/parcels/_core/fieldset.py | 11 +++++------ src/parcels/_core/model.py | 25 ++++++++++++++----------- src/parcels/_typing.py | 1 + 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index badbfa328..6e3ce28dd 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -9,19 +9,18 @@ import uxarray as ux import xarray as xr +import parcels._typing as ptyping from parcels._core.field import Field, VectorField from parcels._core.model import ( CONSTANT_FIELD_MODELS, ModelData, StructuredModelData, - TVectorField, UnstructuredModelData, ) from parcels._core.utils.string import _assert_str_and_python_varname from parcels._core.utils.time import get_datetime_type_calendar from parcels._core.utils.time import is_compatible as datetime_is_compatible from parcels._python import NOTSET, NotSetType -from parcels._typing import Mesh from parcels.interpolators import ( XConstantField, ) @@ -151,7 +150,7 @@ def add_field(self, field: Field, name: str | None = None): self.fields[name] = field - def add_constant_field(self, name: str, value, mesh: Mesh = "spherical"): + def add_constant_field(self, name: str, value, mesh: ptyping.Mesh = "spherical"): """Wrapper function to add a Field that is constant in space, useful e.g. when using constant horizontal diffusivity @@ -212,7 +211,7 @@ def from_ugrid_conventions( cls, ds: ux.UxDataset, mesh: str = "spherical", - vector_fields: TVectorField | None | NotSetType = NOTSET, + vector_fields: ptyping.VectorFields | None | NotSetType = NOTSET, ): """Create a FieldSet from a Parcels compliant uxarray.UxDataset. @@ -249,8 +248,8 @@ def from_ugrid_conventions( def from_sgrid_conventions( cls, ds: xr.Dataset, - mesh: Mesh | None = None, - vector_fields: TVectorField | None | NotSetType = NOTSET, + mesh: ptyping.Mesh | None = None, + vector_fields: ptyping.VectorFields | None | NotSetType = NOTSET, ): # TODO: Update mesh to be discovered from the dataset metadata """Create a FieldSet from a dataset using SGRID convention metadata. diff --git a/src/parcels/_core/model.py b/src/parcels/_core/model.py index 94b333a74..9a18a6f2b 100644 --- a/src/parcels/_core/model.py +++ b/src/parcels/_core/model.py @@ -9,6 +9,7 @@ import xarray as xr import parcels._sgrid as sgrid +import parcels._typing as ptyping from parcels._core.basegrid import BaseGrid from parcels._core.field import Field, VectorField from parcels._core.utils.time import TimeInterval @@ -34,14 +35,12 @@ ) from parcels.interpolators._base import ScalarInterpolator, VectorInterpolator -TVectorField = dict[str, tuple[str, str] | tuple[str, str, str]] - class ModelData(ABC): data: Any grid: BaseGrid field_to_interpolator: dict[str, ScalarInterpolator | VectorInterpolator] - vector_field_components: TVectorField + vector_field_components: ptyping.VectorFields @abstractmethod def construct_fields(self) -> list[Field | VectorField]: ... @@ -84,7 +83,7 @@ def preprocess_sgrid_model_data(ds: xr.Dataset) -> xr.Dataset: class StructuredModelData(ModelData): - def __init__(self, data: xr.Dataset, mesh: Mesh, vector_field_components: TVectorField): + def __init__(self, data: xr.Dataset, mesh: Mesh, vector_field_components: ptyping.VectorFields): if not isinstance(data, xr.Dataset): raise ValueError(f"Expected `data` to be an xarray.Dataset . Got {type(data)}") @@ -133,7 +132,7 @@ def construct_fields(self) -> list[Field | VectorField]: @classmethod def from_sgrid_conventions( - cls, ds: xr.Dataset, mesh: Mesh | None, vector_fields: TVectorField | None | NotSetType + cls, ds: xr.Dataset, mesh: Mesh | None, vector_fields: ptyping.VectorFields | None | NotSetType ) -> Self: ds = ds.copy() if mesh is None: @@ -172,7 +171,9 @@ def from_sgrid_conventions( return model -def resolve_vector_fields(ds: xr.Dataset, vector_fields: TVectorField | None | NotSetType) -> TVectorField: +def resolve_vector_fields( + ds: xr.Dataset, vector_fields: ptyping.VectorFields | None | NotSetType +) -> ptyping.VectorFields: if vector_fields is None: return {} if vector_fields is NOTSET: # i.e., the default vectorfield discovery behaviour @@ -180,7 +181,7 @@ def resolve_vector_fields(ds: xr.Dataset, vector_fields: TVectorField | None | N return vector_fields -def assert_vector_field_components_in_dataset(ds: xr.Dataset, vector_fields: TVectorField) -> None: +def assert_vector_field_components_in_dataset(ds: xr.Dataset, vector_fields: ptyping.VectorFields) -> None: for components in vector_fields.values(): for c in components: if c not in ds.data_vars: @@ -220,7 +221,7 @@ def assert_vector_field_components_in_dataset(ds: xr.Dataset, vector_fields: TVe class UnstructuredModelData(ModelData): - def __init__(self, data: ux.UxDataset, grid: UxGrid, vector_field_components: TVectorField): + def __init__(self, data: ux.UxDataset, grid: UxGrid, vector_field_components: ptyping.VectorFields): if not isinstance(data, ux.UxDataset): raise ValueError(f"Expected `data` to be an uxarray.UxDataset . Got {type(data)}") @@ -259,7 +260,9 @@ def scalar_field_names(self) -> list[str]: return list(self.data.data_vars) @classmethod - def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: Mesh, vector_fields: TVectorField | None | NotSetType): + def from_ugrid_conventions( + cls, ds: ux.UxDataset, mesh: Mesh, vector_fields: ptyping.VectorFields | None | NotSetType + ): ds_dims = list(ds.dims) if not all(dim in ds_dims for dim in ["time", "zf", "zc"]): raise ValueError( @@ -300,9 +303,9 @@ def _get_mesh_type_from_sgrid_dataset(ds_sgrid: xr.Dataset) -> Mesh: return "spherical" if _is_coordinate_in_degrees(ds_sgrid[fpoint_x]) else "flat" -def _default_vector_field_components(data_vars: Sequence[Hashable]) -> TVectorField: +def _default_vector_field_components(data_vars: Sequence[Hashable]) -> ptyping.VectorFields: vars = set(data_vars) - ret: TVectorField = {} + ret: ptyping.VectorFields = {} if {"U", "V"}.issubset(vars): ret["UV"] = ("U", "V") diff --git a/src/parcels/_typing.py b/src/parcels/_typing.py index 18e8aa55f..e8993cb05 100644 --- a/src/parcels/_typing.py +++ b/src/parcels/_typing.py @@ -47,6 +47,7 @@ CfAxis = XgcmAxisDirection XgcmAxisPosition = Literal["center", "left", "right", "inner", "outer"] XgcmAxes = Mapping[XgcmAxisDirection, "xgcm.Axis"] +VectorFields = dict[str, tuple[str, str] | tuple[str, str, str]] def _is_xarray_object(obj): # with no imports From dfb9044280ee3ba94c2ec771366cb60f66bdf59c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 2 Jul 2026 10:44:55 +0200 Subject: [PATCH 18/20] Improve validation of vector_fields --- src/parcels/_core/model.py | 27 +++++++++++++++++++++++++-- tests/test_fieldset.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/parcels/_core/model.py b/src/parcels/_core/model.py index 9a18a6f2b..e8858ac9c 100644 --- a/src/parcels/_core/model.py +++ b/src/parcels/_core/model.py @@ -161,7 +161,7 @@ def from_sgrid_conventions( # ds["lat"] = ds[node_dimensions[1]] vector_fields = resolve_vector_fields(ds, vector_fields) - assert_vector_field_components_in_dataset(ds, vector_fields) + assert_valid_vector_fields(ds, vector_fields) model = cls(ds, mesh=mesh, vector_field_components=vector_fields) model._fields = model.construct_fields() @@ -181,6 +181,29 @@ def resolve_vector_fields( return vector_fields +def assert_valid_vector_fields(ds: xr.Dataset, vector_fields: ptyping.VectorFields) -> None: + # if not isinstance(vector_fields, dict): + # raise ValueError(f"vector_fields must be a dictionary. Got {type(vector_fields)=!r}.") + + for vfield_name, components in vector_fields.items(): + if not isinstance(vfield_name, str): + raise ValueError( + f"Invalid `vector_fields` argument. Vector field name in `vector_fields` should be a string. Got field name {vfield_name!r}." + ) + if not (2 <= len(components) <= 3): + raise ValueError( + f"Invalid `vector_fields` argument. Vector fields must have either 2 or 3 components. Vector field {vfield_name} has {len(components)} components." + ) + for c in components: + if not isinstance(c, str): + raise ValueError( + f"Invalid `vector_fields` argument. Component names must be strings. Got component name of value {c!r}." + ) + + assert_vector_field_components_in_dataset(ds, vector_fields) + return + + def assert_vector_field_components_in_dataset(ds: xr.Dataset, vector_fields: ptyping.VectorFields) -> None: for components in vector_fields.values(): for c in components: @@ -273,7 +296,7 @@ def from_ugrid_conventions( ds = _discover_ux_U_and_V(ds) vector_fields = resolve_vector_fields(ds, vector_fields) - assert_vector_field_components_in_dataset(ds, vector_fields) + assert_valid_vector_fields(ds, vector_fields) model = cls(ds, grid, vector_fields) model._fields = model.construct_fields() diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 60099640d..1a65e5829 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from datetime import timedelta import cf_xarray # noqa: F401 @@ -97,6 +98,34 @@ def test_fieldset_from_structured_generic_datasets(ds): assert len(fieldset.gridset) == 1 +@pytest.mark.parametrize( + "vector_fields,ctx", + [ + pytest.param( + {"UV": ("U",)}, + pytest.raises(ValueError, match="must have either 2 or 3 components"), + id="single-component", + ), + pytest.param( + {"UV": ("U", "missing")}, + pytest.raises(ValueError, match="not present in the source dataset"), + id="component-not-in-dataset", + ), + pytest.param( + {"UV": ("U", "U", "U", "U")}, + pytest.raises(ValueError, match="must have either 2 or 3 components"), + id="too-many-components", + ), + pytest.param(None, nullcontext(), id="None"), + ], +) +def test_fieldset_invalid_vector_fields(vector_fields, ctx): + ds = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + + with ctx: + FieldSet.from_sgrid_conventions(ds, mesh="flat", vector_fields=vector_fields) + + def test_fieldset_structured_vectorfield_default(): ds = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) From 3f160eaf7e87b919edfaa28e9a2a9b610604d3ee Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 2 Jul 2026 10:44:55 +0200 Subject: [PATCH 19/20] Remove None as option for vector_fields --- src/parcels/_core/fieldset.py | 10 ++++------ src/parcels/_core/model.py | 18 ++++++------------ tests/test_fieldset.py | 15 +++++++++------ 3 files changed, 19 insertions(+), 24 deletions(-) diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 6e3ce28dd..bdea570ad 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -211,7 +211,7 @@ def from_ugrid_conventions( cls, ds: ux.UxDataset, mesh: str = "spherical", - vector_fields: ptyping.VectorFields | None | NotSetType = NOTSET, + vector_fields: ptyping.VectorFields | NotSetType = NOTSET, ): """Create a FieldSet from a Parcels compliant uxarray.UxDataset. @@ -229,8 +229,7 @@ def from_ugrid_conventions( vector_fields : Mapping[str, tuple[str, ...]] or None, optional Mapping of vector field names to tuples of component variable names in the dataset. For example, ``{"UV": ("U", "V"), "UVW": ("U", "V", "W")}``. - If ``None``, no vector fields are constructed. If omitted (default), vector fields - are auto-discovered from standard variable names (``U``/``V``/``W``). + If omitted (default), vector fields are auto-discovered from standard variable names (``U``/``V``/``W``). Returns ------- @@ -249,7 +248,7 @@ def from_sgrid_conventions( cls, ds: xr.Dataset, mesh: ptyping.Mesh | None = None, - vector_fields: ptyping.VectorFields | None | NotSetType = NOTSET, + vector_fields: ptyping.VectorFields | NotSetType = NOTSET, ): # TODO: Update mesh to be discovered from the dataset metadata """Create a FieldSet from a dataset using SGRID convention metadata. @@ -267,8 +266,7 @@ def from_sgrid_conventions( vector_fields : Mapping[str, tuple[str, ...]] or None, optional Mapping of vector field names to tuples of component variable names in the dataset. For example, ``{"UV": ("U", "V"), "UVW": ("U", "V", "W")}``. - If ``None``, no vector fields are constructed. If omitted (default), vector fields - are auto-discovered from standard variable names (``U``/``V``/``W``). + If omitted (default), vector fields are auto-discovered from standard variable names (``U``/``V``/``W``). Returns ------- diff --git a/src/parcels/_core/model.py b/src/parcels/_core/model.py index e8858ac9c..c44026370 100644 --- a/src/parcels/_core/model.py +++ b/src/parcels/_core/model.py @@ -132,7 +132,7 @@ def construct_fields(self) -> list[Field | VectorField]: @classmethod def from_sgrid_conventions( - cls, ds: xr.Dataset, mesh: Mesh | None, vector_fields: ptyping.VectorFields | None | NotSetType + cls, ds: xr.Dataset, mesh: Mesh | None, vector_fields: ptyping.VectorFields | NotSetType ) -> Self: ds = ds.copy() if mesh is None: @@ -171,19 +171,15 @@ def from_sgrid_conventions( return model -def resolve_vector_fields( - ds: xr.Dataset, vector_fields: ptyping.VectorFields | None | NotSetType -) -> ptyping.VectorFields: - if vector_fields is None: - return {} +def resolve_vector_fields(ds: xr.Dataset, vector_fields: ptyping.VectorFields | NotSetType) -> ptyping.VectorFields: if vector_fields is NOTSET: # i.e., the default vectorfield discovery behaviour return _default_vector_field_components(list(ds.data_vars)) return vector_fields def assert_valid_vector_fields(ds: xr.Dataset, vector_fields: ptyping.VectorFields) -> None: - # if not isinstance(vector_fields, dict): - # raise ValueError(f"vector_fields must be a dictionary. Got {type(vector_fields)=!r}.") + if not isinstance(vector_fields, dict): + raise ValueError(f"vector_fields must be a dictionary. Got {type(vector_fields)=!r}.") for vfield_name, components in vector_fields.items(): if not isinstance(vfield_name, str): @@ -237,7 +233,7 @@ def assert_vector_field_components_in_dataset(ds: xr.Dataset, vector_fields: pty ), ), mesh=mesh, # type:ignore - vector_fields=None, + vector_fields={}, ) for mesh in ["flat", "spherical"] } @@ -283,9 +279,7 @@ def scalar_field_names(self) -> list[str]: return list(self.data.data_vars) @classmethod - def from_ugrid_conventions( - cls, ds: ux.UxDataset, mesh: Mesh, vector_fields: ptyping.VectorFields | None | NotSetType - ): + def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: Mesh, vector_fields: ptyping.VectorFields | NotSetType): ds_dims = list(ds.dims) if not all(dim in ds_dims for dim in ["time", "zf", "zc"]): raise ValueError( diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 1a65e5829..9eb2de1d5 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -1,4 +1,3 @@ -from contextlib import nullcontext from datetime import timedelta import cf_xarray # noqa: F401 @@ -116,7 +115,11 @@ def test_fieldset_from_structured_generic_datasets(ds): pytest.raises(ValueError, match="must have either 2 or 3 components"), id="too-many-components", ), - pytest.param(None, nullcontext(), id="None"), + pytest.param( + None, + pytest.raises(ValueError, match="vector_fields must be a dictionary"), + id="None", + ), ], ) def test_fieldset_invalid_vector_fields(vector_fields, ctx): @@ -147,10 +150,10 @@ def test_fieldset_structured_vectorfield_custom(): assert "UV_wind" in fset.fields -def test_fieldset_structured_vectorfield_none(): +def test_fieldset_structured_vectorfield_empty(): ds = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) - fset = FieldSet.from_sgrid_conventions(ds, mesh="flat", vector_fields=None) + fset = FieldSet.from_sgrid_conventions(ds, mesh="flat", vector_fields={}) assert "U" in fset.fields assert "V" in fset.fields @@ -177,10 +180,10 @@ def test_fieldset_unstructured_vectorfield_custom(): assert "UV_wind" in fset.fields -def test_fieldset_unstructured_vectorfield_none(): +def test_fieldset_unstructured_vectorfield_empty(): ds = datasets_unstructured["stommel_gyre_delaunay"] - fset = FieldSet.from_ugrid_conventions(ds, mesh="spherical", vector_fields=None) + fset = FieldSet.from_ugrid_conventions(ds, mesh="spherical", vector_fields={}) assert "U" in fset.fields assert "V" in fset.fields From 61c8c6ad5548791c3b805089c8f878a35b154930 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 2 Jul 2026 11:42:12 +0200 Subject: [PATCH 20/20] Improve test_fieldset_add Now also has vectorfields --- tests/test_fieldset.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 9eb2de1d5..8e2f65b6f 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -326,11 +326,13 @@ def test_fieldset_add_error_on_duplicate_fields(): def test_fieldset_add(): """Test that two FieldSets can be combined with + (fset1 + fset2).""" - ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "grid"]].rename({"U_A_grid": "U1"}) - ds2 = datasets_structured["ds_2d_left"][["V_A_grid", "grid"]].rename({"V_A_grid": "V2"}) + ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + ds2 = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename( + {"U_A_grid": "U_wind", "V_A_grid": "V_wind"} + ) fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat") - fset2 = FieldSet.from_sgrid_conventions(ds2, mesh="flat") + fset2 = FieldSet.from_sgrid_conventions(ds2, mesh="flat", vector_fields={"UV_wind": ("U_wind", "V_wind")}) fset = fset1 + fset2