Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions src/parcels/_core/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@
import uxarray as ux
import xarray as xr

import parcels._typing as ptyping
from parcels._core.field import Field, VectorField
from parcels._core.model import CONSTANT_FIELD_MODELS, ModelData, StructuredModelData, UnstructuredModelData
from parcels._core.model import (
CONSTANT_FIELD_MODELS,
ModelData,
StructuredModelData,
UnstructuredModelData,
)
from parcels._core.utils.string import _assert_str_and_python_varname
from parcels._core.utils.time import get_datetime_type_calendar
from parcels._core.utils.time import is_compatible as datetime_is_compatible
from parcels._typing import Mesh
from parcels._python import NOTSET, NotSetType
from parcels.interpolators import (
XConstantField,
)
Expand Down Expand Up @@ -144,7 +150,7 @@ def add_field(self, field: Field, name: str | None = None):

self.fields[name] = field

def add_constant_field(self, name: str, value, mesh: Mesh = "spherical"):
def add_constant_field(self, name: str, value, mesh: ptyping.Mesh = "spherical"):
"""Wrapper function to add a Field that is constant in space,
useful e.g. when using constant horizontal diffusivity

Expand Down Expand Up @@ -201,7 +207,12 @@ def gridset(self) -> list[BaseGrid]:
return grids

@classmethod
def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"):
def from_ugrid_conventions(
cls,
ds: ux.UxDataset,
mesh: str = "spherical",
vector_fields: ptyping.VectorFields | NotSetType = NOTSET,
):
"""Create a FieldSet from a Parcels compliant uxarray.UxDataset.

This is the primary ingestion method in Parcels for structured grid datasets.
Expand All @@ -215,6 +226,10 @@ def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"):
----------
ds : uxarray.UxDataset
uxarray.UxDataset as obtained from the uxarray package but with appropriate named vertical dimensions
vector_fields : Mapping[str, tuple[str, ...]] or None, optional
Mapping of vector field names to tuples of component variable names in the dataset.
For example, ``{"UV": ("U", "V"), "UVW": ("U", "V", "W")}``.
If omitted (default), vector fields are auto-discovered from standard variable names (``U``/``V``/``W``).

Returns
-------
Expand All @@ -225,12 +240,15 @@ def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"):
-----
See https://ugrid-conventions.github.io/ugrid-conventions/ for more information on the UGRID conventions.
"""
model = UnstructuredModelData.from_ugrid_conventions(ds, mesh)
model = UnstructuredModelData.from_ugrid_conventions(ds, mesh, vector_fields)
return cls([model])

@classmethod
def from_sgrid_conventions(
cls, ds: xr.Dataset, mesh: Mesh | None = None
cls,
ds: xr.Dataset,
mesh: ptyping.Mesh | None = None,
vector_fields: ptyping.VectorFields | NotSetType = NOTSET,
): # TODO: Update mesh to be discovered from the dataset metadata
"""Create a FieldSet from a dataset using SGRID convention metadata.

Expand All @@ -245,6 +263,10 @@ def from_sgrid_conventions(
mesh : str
String indicating the type of mesh coordinates used during
velocity interpolation. Options are "spherical" or "flat".
vector_fields : Mapping[str, tuple[str, ...]] or None, optional
Mapping of vector field names to tuples of component variable names in the dataset.
For example, ``{"UV": ("U", "V"), "UVW": ("U", "V", "W")}``.
If omitted (default), vector fields are auto-discovered from standard variable names (``U``/``V``/``W``).

Returns
-------
Expand All @@ -259,7 +281,7 @@ def from_sgrid_conventions(

See https://sgrid.github.io/sgrid/ for more information on the SGRID conventions.
"""
model = StructuredModelData.from_sgrid_conventions(ds, mesh)
model = StructuredModelData.from_sgrid_conventions(ds, mesh, vector_fields)
return cls([model])


Expand Down Expand Up @@ -356,9 +378,3 @@ def _format_calendar_error_message(field: Field | VectorField, reference_datetim
],
"W": ["upward_sea_water_velocity", "vertical_sea_water_velocity"],
}


def _is_agrid(ds: xr.Dataset) -> bool:
# check if U and V are defined on the same dimensions
# if yes, interpret as A grid
return set(ds["U"].dims) == set(ds["V"].dims)
133 changes: 94 additions & 39 deletions src/parcels/_core/model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Hashable, Sequence
from typing import Any, Self

import cf_xarray # noqa: F401
import uxarray as ux
import xarray as xr

import parcels._sgrid as sgrid
import parcels._typing as ptyping
from parcels._core.basegrid import BaseGrid
from parcels._core.field import Field, VectorField
from parcels._core.utils.time import TimeInterval
Expand All @@ -18,6 +20,7 @@
assert_all_field_dims_have_axis, # noqa: F401, leave import for now until decision is made # TODO v4: Make decision
)
from parcels._logger import logger
from parcels._python import NOTSET, NotSetType
from parcels._typing import Mesh
from parcels.convert import _ds_rename_using_standard_names
from parcels.interpolators import (
Expand All @@ -37,6 +40,7 @@ class ModelData(ABC):
data: Any
grid: BaseGrid
field_to_interpolator: dict[str, ScalarInterpolator | VectorInterpolator]
vector_field_components: ptyping.VectorFields

@abstractmethod
def construct_fields(self) -> list[Field | VectorField]: ...
Expand Down Expand Up @@ -79,7 +83,7 @@ def preprocess_sgrid_model_data(ds: xr.Dataset) -> xr.Dataset:


class StructuredModelData(ModelData):
def __init__(self, data: xr.Dataset, mesh: Mesh):
def __init__(self, data: xr.Dataset, mesh: Mesh, vector_field_components: ptyping.VectorFields):
if not isinstance(data, xr.Dataset):
raise ValueError(f"Expected `data` to be an xarray.Dataset . Got {type(data)}")

Expand All @@ -88,6 +92,7 @@ def __init__(self, data: xr.Dataset, mesh: Mesh):

self.data = data
self.grid = grid
self.vector_field_components = vector_field_components
self.field_to_interpolator = {}
self._fields: list[Field | VectorField] | None = None
self.assert_valid_model_data()
Expand All @@ -110,30 +115,25 @@ def construct_fields(self) -> list[Field | VectorField]:
single_fields: dict[str, Field] = {}
vector_fields: dict[str, VectorField] = {}
scalar_field_names = self.scalar_field_names
if "U" in scalar_field_names and "V" in scalar_field_names:
interp_method = XLinear_Velocity() if _is_agrid(self.data) else CGrid_Velocity()
single_fields["U"] = Field("U", self)
single_fields["V"] = Field("V", self)
vector_fields["UV"] = VectorField("UV", single_fields["U"], single_fields["V"], interp_method=interp_method)

if "W" in scalar_field_names:
single_fields["W"] = Field("W", self)
vector_fields["UVW"] = VectorField(
"UVW",
single_fields["U"],
single_fields["V"],
single_fields["W"],
interp_method=interp_method,
)

fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields}
for varname in set(scalar_field_names) - set(fields.keys()):
fields[varname] = Field(str(varname), self)
for varname in set(scalar_field_names):
single_fields[varname] = Field(str(varname), self)

for vfield_name, components in self.vector_field_components.items():
interp_method = (
XLinear_Velocity() if _is_agrid(self.data, u=components[0], v=components[1]) else CGrid_Velocity()
)
Comment thread
erikvansebille marked this conversation as resolved.

component_fields = [single_fields[name] for name in components]
vector_fields[vfield_name] = VectorField(vfield_name, *component_fields, interp_method=interp_method) # type:ignore[misc,arg-type]

fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields}
return list(fields.values())

@classmethod
def from_sgrid_conventions(cls, ds: xr.Dataset, mesh: Mesh | None = None) -> Self:
def from_sgrid_conventions(
cls, ds: xr.Dataset, mesh: Mesh | None, vector_fields: ptyping.VectorFields | NotSetType
) -> Self:
ds = ds.copy()
if mesh is None:
mesh = _get_mesh_type_from_sgrid_dataset(ds)
Expand All @@ -160,14 +160,56 @@ def from_sgrid_conventions(cls, ds: xr.Dataset, mesh: Mesh | None = None) -> Sel
# ds["lon"] = ds[node_dimensions[0]]
# ds["lat"] = ds[node_dimensions[1]]

model = cls(ds, mesh=mesh)
vector_fields = resolve_vector_fields(ds, vector_fields)
assert_valid_vector_fields(ds, vector_fields)

model = cls(ds, mesh=mesh, vector_field_components=vector_fields)
model._fields = model.construct_fields()
for f in model._fields:
if isinstance(f, Field):
f.interp_method = XLinear()
return model


def resolve_vector_fields(ds: xr.Dataset, vector_fields: ptyping.VectorFields | NotSetType) -> ptyping.VectorFields:
if vector_fields is NOTSET: # i.e., the default vectorfield discovery behaviour
return _default_vector_field_components(list(ds.data_vars))
return vector_fields


def assert_valid_vector_fields(ds: xr.Dataset, vector_fields: ptyping.VectorFields) -> None:
if not isinstance(vector_fields, dict):
raise ValueError(f"vector_fields must be a dictionary. Got {type(vector_fields)=!r}.")

for vfield_name, components in vector_fields.items():
if not isinstance(vfield_name, str):
raise ValueError(
f"Invalid `vector_fields` argument. Vector field name in `vector_fields` should be a string. Got field name {vfield_name!r}."
)
if not (2 <= len(components) <= 3):
raise ValueError(
f"Invalid `vector_fields` argument. Vector fields must have either 2 or 3 components. Vector field {vfield_name} has {len(components)} components."
)
for c in components:
if not isinstance(c, str):
raise ValueError(
f"Invalid `vector_fields` argument. Component names must be strings. Got component name of value {c!r}."
)

assert_vector_field_components_in_dataset(ds, vector_fields)
return


def assert_vector_field_components_in_dataset(ds: xr.Dataset, vector_fields: ptyping.VectorFields) -> None:
for components in vector_fields.values():
for c in components:
if c not in ds.data_vars:
raise ValueError(
f"Field component '{c}' not present in the source dataset, but is listed in {vector_fields=!r}. This component cannot be used in this mapping."
)
return


CONSTANT_FIELD_MODELS = {
mesh: StructuredModelData.from_sgrid_conventions(
xr.Dataset(
Expand All @@ -191,13 +233,14 @@ def from_sgrid_conventions(cls, ds: xr.Dataset, mesh: Mesh | None = None) -> Sel
),
),
mesh=mesh, # type:ignore
vector_fields={},
)
for mesh in ["flat", "spherical"]
}


class UnstructuredModelData(ModelData):
def __init__(self, data: ux.UxDataset, grid: UxGrid):
def __init__(self, data: ux.UxDataset, grid: UxGrid, vector_field_components: ptyping.VectorFields):
if not isinstance(data, ux.UxDataset):
raise ValueError(f"Expected `data` to be an uxarray.UxDataset . Got {type(data)}")

Expand All @@ -206,28 +249,25 @@ def __init__(self, data: ux.UxDataset, grid: UxGrid):

self.data = data
self.grid = grid
self.vector_field_components = vector_field_components
self.field_to_interpolator = {}
self._fields: list[Field | VectorField] | None = None

def construct_fields(self) -> list[Field | VectorField]:
single_fields: dict[str, Field] = {}
vector_fields: dict[str, VectorField] = {}
scalar_field_names = self.scalar_field_names
if "U" in scalar_field_names and "V" in scalar_field_names:
single_fields["U"] = Field("U", self)
single_fields["V"] = Field("V", self)
vector_fields["UV"] = VectorField("UV", single_fields["U"], single_fields["V"], interp_method=Ux_Velocity())

if "W" in scalar_field_names:
single_fields["W"] = Field("W", self)
vector_fields["UVW"] = VectorField(
"UVW", single_fields["U"], single_fields["V"], single_fields["W"], interp_method=Ux_Velocity()
)

fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields}
for varname in set(scalar_field_names) - set(single_fields.keys()):
fields[varname] = Field(str(varname), self)
for varname in set(scalar_field_names):
single_fields[varname] = Field(str(varname), self)

for vfield_name, components in self.vector_field_components.items():
interp_method = Ux_Velocity()

component_fields = [single_fields[name] for name in components]
vector_fields[vfield_name] = VectorField(vfield_name, *component_fields, interp_method=interp_method) # type:ignore[misc, arg-type]

fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields}
return list(fields.values())

def assert_valid_field_data(self, field_data: ux.UxDataArray) -> None:
Expand All @@ -239,7 +279,7 @@ def scalar_field_names(self) -> list[str]:
return list(self.data.data_vars)

@classmethod
def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"):
def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: Mesh, vector_fields: ptyping.VectorFields | NotSetType):
ds_dims = list(ds.dims)
if not all(dim in ds_dims for dim in ["time", "zf", "zc"]):
raise ValueError(
Expand All @@ -248,7 +288,11 @@ def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"):

grid = UxGrid(ds.uxgrid, z=ds.coords["zf"], mesh=mesh)
ds = _discover_ux_U_and_V(ds)
model = cls(ds, grid)

vector_fields = resolve_vector_fields(ds, vector_fields)
assert_valid_vector_fields(ds, vector_fields)

model = cls(ds, grid, vector_fields)
model._fields = model.construct_fields()
for f in model._fields:
if isinstance(f, Field):
Expand Down Expand Up @@ -276,6 +320,17 @@ def _get_mesh_type_from_sgrid_dataset(ds_sgrid: xr.Dataset) -> Mesh:
return "spherical" if _is_coordinate_in_degrees(ds_sgrid[fpoint_x]) else "flat"


def _default_vector_field_components(data_vars: Sequence[Hashable]) -> ptyping.VectorFields:
vars = set(data_vars)
ret: ptyping.VectorFields = {}

if {"U", "V"}.issubset(vars):
ret["UV"] = ("U", "V")
if {"U", "V", "W"}.issubset(vars):
ret["UVW"] = ("U", "V", "W")
return ret


def _is_coordinate_in_degrees(da: xr.DataArray) -> bool:
units = da.attrs.get("units")
if units is None:
Expand Down Expand Up @@ -366,10 +421,10 @@ def _select_uxinterpolator(da: ux.UxDataArray):
return None


def _is_agrid(ds: xr.Dataset) -> bool:
def _is_agrid(ds: xr.Dataset, u: str, v: str) -> bool:
# check if U and V are defined on the same dimensions
# if yes, interpret as A grid
return set(ds["U"].dims) == set(ds["V"].dims)
return set(ds[u].dims) == set(ds[v].dims)


def _get_time_interval(data: xr.DataArray | ux.UxDataArray) -> TimeInterval | None:
Expand Down
4 changes: 4 additions & 0 deletions src/parcels/_python.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# Generic Python helpers
import enum
import inspect
from collections.abc import Callable, Mapping
from typing import TypeVar

K = TypeVar("K")
V = TypeVar("V")

NotSetType = enum.Enum("NotSetType", "VALUE")
NOTSET = NotSetType.VALUE

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this PR, but would there be a benefit in also using this NOTSET when postponing setting the particle.time?

if time is None or len(time) == 0:
# do not set a time yet (because sign_dt not known)
time = np.array(np.nan)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not quite sure what you mean by this. Do you mean to use NOTSET instead of None for the time parameter in the init of the particleset?



def isinstance_noimport(obj, class_or_tuple):
"""A version of isinstance that does not require importing the class.
Expand Down
1 change: 1 addition & 0 deletions src/parcels/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
CfAxis = XgcmAxisDirection
XgcmAxisPosition = Literal["center", "left", "right", "inner", "outer"]
XgcmAxes = Mapping[XgcmAxisDirection, "xgcm.Axis"]
VectorFields = dict[str, tuple[str, str] | tuple[str, str, str]]


def _is_xarray_object(obj): # with no imports
Expand Down
Loading