-
Notifications
You must be signed in to change notification settings - Fork 254
api: Fix handling of staggering in injection/interpolation #2936
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,7 @@ | ||
| from abc import ABC, abstractmethod | ||
| from contextlib import suppress | ||
| from functools import cached_property, wraps | ||
| from itertools import groupby | ||
|
|
||
| import numpy as np | ||
| import sympy | ||
|
|
@@ -67,12 +69,9 @@ def _extract_subdomain(variables): | |
| """ | ||
| sdms = set() | ||
| for v in variables: | ||
| try: | ||
| with suppress(AttributeError): | ||
| if v.grid.is_SubDomain: | ||
| sdms.add(v.grid) | ||
| except AttributeError: | ||
| # Variable not on a grid (Indexed for example) | ||
| pass | ||
|
|
||
| if len(sdms) > 1: | ||
| raise NotImplementedError("Sparse operation on multiple Functions defined on" | ||
|
|
@@ -230,7 +229,7 @@ def r(self): | |
| return self.sfunction.r | ||
|
|
||
| @memoized_meth | ||
| def _weights(self, subdomain=None): | ||
| def _weights(self, subdomain=None, shifts=None): | ||
| raise NotImplementedError | ||
|
|
||
| @property | ||
|
|
@@ -243,8 +242,22 @@ def _cdim(self): | |
| dims = [self.sfunction._crdim(d) for d in self._gdims] | ||
| return dims | ||
|
|
||
| def _field_shifts(self, field): | ||
| """ | ||
| Per-grid-Dimension half-cell shift induced by `field`'s staggering | ||
| (e.g. `h_x/2` for a field staggered in `x`). Returns None for | ||
| unstaggered fields. SubDomain-induced origin offsets are deliberately | ||
| ignored — they are not staggering. | ||
| """ | ||
| staggered = field.staggered | ||
| if not staggered or staggered.on_node: | ||
| return () | ||
| return tuple((d.spacing / 2) if s else 0 | ||
| for d, s in zip(field.dimensions, staggered, strict=True) | ||
| if d.is_Space) | ||
|
|
||
| @memoized_meth | ||
| def _rdim(self, subdomain=None): | ||
| def _rdim(self, subdomain=None, shifts=None): | ||
| # If the interpolation operation is limited to a SubDomain, | ||
| # use the SubDimensions of that SubDomain | ||
| if subdomain: | ||
|
|
@@ -254,7 +267,7 @@ def _rdim(self, subdomain=None): | |
|
|
||
| # Make radius dimension conditional to avoid OOB | ||
| rdims = [] | ||
| pos = self.sfunction._position_map.values() | ||
| pos = self.sfunction._position_map(shifts=shifts).values() | ||
|
|
||
| for (d, rd, p) in zip(gdims, self._cdim, pos, strict=True): | ||
| # Add conditional to avoid OOB | ||
|
|
@@ -279,12 +292,10 @@ def _augment_implicit_dims(self, implicit_dims, extras=None): | |
| # dimensions of that SubDomain from any extra dimensions found | ||
| edims = [] | ||
| for v in extras: | ||
| try: | ||
| with suppress(AttributeError): | ||
| if v.grid.is_SubDomain: | ||
| edims.extend([d for d in v.grid.dimensions | ||
| if d.is_Sub and d.root in self._gdims]) | ||
| except AttributeError: | ||
| pass | ||
|
|
||
| gdims = filter_ordered(edims + list(self._gdims)) | ||
| extra = filter_ordered([i for v in extras for i in v.dimensions | ||
|
|
@@ -300,27 +311,34 @@ def _augment_implicit_dims(self, implicit_dims, extras=None): | |
| idims = extra + as_tuple(implicit_dims) + self.sfunction.dimensions | ||
| return tuple(idims) | ||
|
|
||
| def _coeff_temps(self, implicit_dims): | ||
| def _coeff_temps(self, implicit_dims, shifts=None): | ||
| return [] | ||
|
|
||
| def _positions(self, implicit_dims): | ||
| def _positions(self, implicit_dims, shifts=None): | ||
| return [Eq(v, INT(floor(k)), implicit_dims=implicit_dims) | ||
| for k, v in self.sfunction._position_map.items()] | ||
| for k, v in self.sfunction._position_map(shifts=shifts).items()] | ||
|
|
||
| def _interp_idx(self, variables, implicit_dims=None, subdomain=None): | ||
| def _interp_idx(self, variables, implicit_dims=None, subdomain=None, | ||
| shifts=None): | ||
| """ | ||
| Generate interpolation indices for the DiscreteFunctions in ``variables``. | ||
| Generate interpolation indices for the DiscreteFunctions in `variables`. | ||
|
|
||
| `shifts` is a per-Dimension physical offset for the target field's | ||
| origin: it only affects the integer position symbol via the position | ||
| map (`pos = floor((c - o - shift)/h)`). The index substitution itself | ||
| is unchanged — any staggered offset in a field's own symbolic access is | ||
| absorbed by Devito's normal indexing. | ||
| """ | ||
| pos = self.sfunction._position_map.values() | ||
| pos = self.sfunction._position_map(shifts=shifts).values() | ||
|
|
||
| # Temporaries for the position | ||
| temps = self._positions(implicit_dims) | ||
| temps = self._positions(implicit_dims, shifts=shifts) | ||
|
|
||
| # Coefficient symbol expression | ||
| temps.extend(self._coeff_temps(implicit_dims)) | ||
| temps.extend(self._coeff_temps(implicit_dims, shifts=shifts)) | ||
|
|
||
| # Substitution mapper for variables | ||
| mapper = self._rdim(subdomain=subdomain).getters | ||
| mapper = self._rdim(subdomain=subdomain, shifts=shifts).getters | ||
|
|
||
| # Index substitution to make in variables | ||
| subs = { | ||
|
|
@@ -337,7 +355,7 @@ def _interp_idx(self, variables, implicit_dims=None, subdomain=None): | |
| @check_coords | ||
| def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None): | ||
| """ | ||
| Generate equations interpolating an arbitrary expression into ``self``. | ||
| Generate equations interpolating an arbitrary expression into `self`. | ||
|
|
||
| Parameters | ||
| ---------- | ||
|
|
@@ -375,7 +393,7 @@ def inject(self, field, expr, implicit_dims=None): | |
|
|
||
| def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None): | ||
| """ | ||
| Generate equations interpolating an arbitrary expression into ``self``. | ||
| Generate equations interpolating an arbitrary expression into `self`. | ||
|
|
||
| Parameters | ||
| ---------- | ||
|
|
@@ -389,16 +407,13 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None | |
| the operator. | ||
| """ | ||
| # Derivatives must be evaluated before the introduction of indirect accesses | ||
| try: | ||
| _expr = expr.evaluate | ||
| except AttributeError: | ||
| # E.g., a generic SymPy expression or a number | ||
| _expr = expr | ||
| with suppress(AttributeError): | ||
| expr = expr._eval_at(self.sfunction).evaluate | ||
|
|
||
| if self_subs is None: | ||
| self_subs = {} | ||
|
|
||
| variables = list(retrieve_function_carriers(_expr)) | ||
| variables = list(retrieve_function_carriers(expr)) | ||
| subdomain = _extract_subdomain(variables) | ||
|
|
||
| # Implicit dimensions | ||
|
|
@@ -413,7 +428,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None | |
| summands = [Eq(rhs, 0., implicit_dims=implicit_dims)] | ||
| # Substitute coordinate base symbols into the interpolation coefficients | ||
| weights = self._weights(subdomain=subdomain) | ||
| summands.extend([Inc(rhs, (weights * _expr).xreplace(idx_subs), | ||
| summands.extend([Inc(rhs, (weights * expr).xreplace(idx_subs), | ||
| implicit_dims=implicit_dims)]) | ||
|
|
||
| # Write/Incr `self` | ||
|
|
@@ -451,35 +466,48 @@ def _inject(self, field, expr, implicit_dims=None): | |
|
|
||
| subdomain = _extract_subdomain(fields) | ||
|
|
||
| # Derivatives must be evaluated before the introduction of indirect accesses | ||
| try: | ||
| _exprs = tuple(e.evaluate for e in exprs) | ||
| except AttributeError: | ||
| # E.g., a generic SymPy expression or a number | ||
| _exprs = exprs | ||
|
|
||
| variables = list(v for e in _exprs for v in retrieve_function_carriers(e)) | ||
|
|
||
| # Implicit dimensions | ||
| implicit_dims = self._augment_implicit_dims(implicit_dims, variables) | ||
| # Move all temporaries inside inner loop to improve parallelism | ||
| # Can only be done for inject as interpolation need a temporary | ||
| # summing temp that wouldn't allow collapsing | ||
| implicit_dims = implicit_dims + tuple(r.parent for r in | ||
| self._rdim(subdomain=subdomain)) | ||
|
|
||
| # List of indirection indices for all adjacent grid points | ||
| finterp = fields + as_tuple(variables) | ||
| idx_subs, temps = self._interp_idx(finterp, implicit_dims=implicit_dims, | ||
| subdomain=subdomain) | ||
|
|
||
| # Substitute coordinate base symbols into the interpolation coefficients | ||
| eqns = [Inc(_field.xreplace(idx_subs), | ||
| (self._weights(subdomain=subdomain) * _expr).xreplace(idx_subs), | ||
| implicit_dims=implicit_dims) | ||
| for (_field, _expr) in zip(fields, _exprs, strict=True)] | ||
|
|
||
| return temps + eqns | ||
| # Derivatives must be evaluated before the introduction of indirect | ||
| # accesses. Variables are sampled at their own grid location; the | ||
| # position map for the target field carries the staggering so the | ||
| # field's stencil neighbors land on the right indices. | ||
| with suppress(AttributeError): | ||
| exprs = tuple(e._eval_at(f).evaluate | ||
| for e, f in zip(exprs, fields, strict=True)) | ||
|
|
||
| eqns = [] | ||
| temps = [] | ||
| # We need to create one set of equations (temps and and coeffs) per staggering | ||
| # field in which we inject as the reference index depends on the field's origin | ||
| for _, g in groupby(zip(fields, exprs, strict=True), lambda f: f[0].staggered): | ||
| g_fields, g_exprs = zip(*g, strict=True) | ||
| variables = list(v for e in g_exprs for v in retrieve_function_carriers(e)) | ||
|
|
||
| implicit_dims = self._augment_implicit_dims(implicit_dims, variables) | ||
|
|
||
| # All fields in a single injection share the same staggering by | ||
| # construction (they are written together at the same indices), so | ||
| # derive shifts from the first field. | ||
| shifts = self._field_shifts(g_fields[0]) | ||
|
|
||
| # Move all temporaries inside inner loop to improve parallelism | ||
| # Can only be done for inject as interpolation needs a summing temp | ||
| # that wouldn't allow collapsing | ||
| implicit_dims = implicit_dims + tuple(r.parent for r in | ||
| self._rdim(subdomain=subdomain, | ||
| shifts=shifts)) | ||
|
|
||
| # List of indirection indices for all adjacent grid points | ||
| idx_subs, _temps = self._interp_idx(list(g_fields) + variables, | ||
| implicit_dims=implicit_dims, | ||
| subdomain=subdomain, shifts=shifts) | ||
|
|
||
| w = self._weights(subdomain=subdomain, shifts=shifts) | ||
| temps.extend(_temps) | ||
| eqns.extend([Inc(f.xreplace(idx_subs), (w * e).xreplace(idx_subs), | ||
| implicit_dims=implicit_dims) | ||
| for f, e in zip(g_fields, g_exprs, strict=True)]) | ||
|
|
||
| return filter_ordered(temps) + eqns | ||
|
|
||
|
|
||
| class LinearInterpolator(WeightedInterpolator): | ||
|
|
@@ -495,24 +523,30 @@ class LinearInterpolator(WeightedInterpolator): | |
| _name = 'linear' | ||
|
|
||
| @memoized_meth | ||
| def _weights(self, subdomain=None): | ||
| rdim = self._rdim(subdomain=subdomain) | ||
| def _weights(self, subdomain=None, shifts=None): | ||
| rdim = self._rdim(subdomain=subdomain, shifts=shifts) | ||
| c = [(1 - p) * (1 - r) + p * r | ||
| for (p, d, r) in zip(self._point_symbols, self._gdims, rdim, strict=True)] | ||
| for (p, d, r) in zip(self._point_symbols(shifts), self._gdims, rdim, | ||
| strict=True)] | ||
| return Mul(*c) | ||
|
|
||
| @cached_property | ||
| def _point_symbols(self): | ||
| @memoized_meth | ||
| def _point_symbols(self, shifts=None): | ||
| """Symbol for coordinate value in each Dimension of the point.""" | ||
| dtype = self.sfunction.coordinates.dtype | ||
| return DimensionTuple(*(Symbol(name=f'p{d}', dtype=dtype) | ||
| for d in self.grid.dimensions), | ||
| getters=self.grid.dimensions) | ||
| symbols = [] | ||
| for d in self.grid.dimensions: | ||
| if shifts and shifts[self.grid.dimensions.index(d)] != 0: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thinking about this -- if you make
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is no guaranty this is called with |
||
| symbols.append(Symbol(name=f'p{d}_s1', dtype=dtype)) | ||
| else: | ||
| symbols.append(Symbol(name=f'p{d}', dtype=dtype)) | ||
| return DimensionTuple(*symbols, getters=self.grid.dimensions) | ||
|
|
||
| def _coeff_temps(self, implicit_dims): | ||
| def _coeff_temps(self, implicit_dims, shifts=None): | ||
| # Positions | ||
| pmap = self.sfunction._position_map | ||
| poseq = [Eq(self._point_symbols[d], pos - floor(pos), | ||
| pmap = self.sfunction._position_map(shifts=shifts) | ||
| psyms = self._point_symbols(shifts) | ||
| poseq = [Eq(psyms[d], pos - floor(pos), | ||
| implicit_dims=implicit_dims) | ||
| for (d, pos) in zip(self._gdims, pmap.keys(), strict=True)] | ||
| return poseq | ||
|
|
@@ -531,23 +565,24 @@ class PrecomputedInterpolator(WeightedInterpolator): | |
|
|
||
| _name = 'precomp' | ||
|
|
||
| def _positions(self, implicit_dims): | ||
| def _positions(self, implicit_dims, shifts=None): | ||
| if self.sfunction.gridpoints_data is None: | ||
| return super()._positions(implicit_dims) | ||
| return super()._positions(implicit_dims, shifts=shifts) | ||
| else: | ||
| # No position temp as we have directly the gridpoints | ||
| return[Eq(p, k, implicit_dims=implicit_dims) | ||
| for (k, p) in self.sfunction._position_map.items()] | ||
| for (k, p) in self.sfunction._position_map(shifts=shifts).items()] | ||
|
|
||
| @property | ||
| def interpolation_coeffs(self): | ||
| return self.sfunction.interpolation_coeffs | ||
|
|
||
| @memoized_meth | ||
| def _weights(self, subdomain=None): | ||
| def _weights(self, subdomain=None, shifts=None): | ||
| ddim, cdim = self.interpolation_coeffs.dimensions[1:] | ||
| mappers = [{ddim: ri, cdim: rd-rd.parent.symbolic_min} | ||
| for (ri, rd) in enumerate(self._rdim(subdomain=subdomain))] | ||
| for (ri, rd) in enumerate(self._rdim(subdomain=subdomain, | ||
| shifts=shifts))] | ||
| return Mul(*[self.interpolation_coeffs.subs(mapper) | ||
| for mapper in mappers]) | ||
|
|
||
|
|
@@ -592,8 +627,8 @@ def interpolation_coeffs(self): | |
| return tuple(coeffs) | ||
|
|
||
| @memoized_meth | ||
| def _weights(self, subdomain=None): | ||
| rdims = self._rdim(subdomain=subdomain) | ||
| def _weights(self, subdomain=None, shifts=None): | ||
| rdims = self._rdim(subdomain=subdomain, shifts=shifts) | ||
| return Mul(*[ | ||
| w._subs(rd, rd-rd.parent.symbolic_min) | ||
| for (rd, w) in zip(rdims, self.interpolation_coeffs, strict=True) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably you can leave a blank line here