Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,9 @@ def _eval_at(self, func):
# compare staggering
if self.expr.staggered == func.staggered and self.expr.is_Function:
return self
# Time derivatives are not affected by space staggering
if all(d.is_Time for d in self.dims):
return self

# Check if x0's keys come from a DerivedDimension
x0 = func.indices_ref.getters
Expand Down
185 changes: 110 additions & 75 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from contextlib import suppress
from functools import cached_property, wraps
from itertools import groupby

import numpy as np
import sympy
Expand Down Expand Up @@ -67,12 +69,9 @@ def _extract_subdomain(variables):
"""
sdms = set()
for v in variables:
try:
with suppress(AttributeError):
if v.grid.is_SubDomain:
sdms.add(v.grid)
except AttributeError:
# Variable not on a grid (Indexed for example)
pass

if len(sdms) > 1:
raise NotImplementedError("Sparse operation on multiple Functions defined on"
Expand Down Expand Up @@ -230,7 +229,7 @@ def r(self):
return self.sfunction.r

@memoized_meth
def _weights(self, subdomain=None):
def _weights(self, subdomain=None, shifts=None):
raise NotImplementedError

@property
Expand All @@ -243,8 +242,22 @@ def _cdim(self):
dims = [self.sfunction._crdim(d) for d in self._gdims]
return dims

def _field_shifts(self, field):
"""
Per-grid-Dimension half-cell shift induced by `field`'s staggering
(e.g. `h_x/2` for a field staggered in `x`). Returns None for
unstaggered fields. SubDomain-induced origin offsets are deliberately
ignored — they are not staggering.
"""
staggered = field.staggered
if not staggered or staggered.on_node:
return ()
return tuple((d.spacing / 2) if s else 0
for d, s in zip(field.dimensions, staggered, strict=True)
if d.is_Space)

@memoized_meth
def _rdim(self, subdomain=None):
def _rdim(self, subdomain=None, shifts=None):
# If the interpolation operation is limited to a SubDomain,
# use the SubDimensions of that SubDomain
if subdomain:
Expand All @@ -254,7 +267,7 @@ def _rdim(self, subdomain=None):

# Make radius dimension conditional to avoid OOB
rdims = []
pos = self.sfunction._position_map.values()
pos = self.sfunction._position_map(shifts=shifts).values()

for (d, rd, p) in zip(gdims, self._cdim, pos, strict=True):
# Add conditional to avoid OOB
Expand All @@ -279,12 +292,10 @@ def _augment_implicit_dims(self, implicit_dims, extras=None):
# dimensions of that SubDomain from any extra dimensions found
edims = []
for v in extras:
try:
with suppress(AttributeError):
if v.grid.is_SubDomain:
edims.extend([d for d in v.grid.dimensions
if d.is_Sub and d.root in self._gdims])
except AttributeError:
pass

gdims = filter_ordered(edims + list(self._gdims))
extra = filter_ordered([i for v in extras for i in v.dimensions
Expand All @@ -300,27 +311,34 @@ def _augment_implicit_dims(self, implicit_dims, extras=None):
idims = extra + as_tuple(implicit_dims) + self.sfunction.dimensions
return tuple(idims)

def _coeff_temps(self, implicit_dims):
def _coeff_temps(self, implicit_dims, shifts=None):
return []

def _positions(self, implicit_dims):
def _positions(self, implicit_dims, shifts=None):
return [Eq(v, INT(floor(k)), implicit_dims=implicit_dims)
for k, v in self.sfunction._position_map.items()]
for k, v in self.sfunction._position_map(shifts=shifts).items()]

def _interp_idx(self, variables, implicit_dims=None, subdomain=None):
def _interp_idx(self, variables, implicit_dims=None, subdomain=None,
shifts=None):
"""
Generate interpolation indices for the DiscreteFunctions in ``variables``.
Generate interpolation indices for the DiscreteFunctions in `variables`.

`shifts` is a per-Dimension physical offset for the target field's
origin: it only affects the integer position symbol via the position
map (`pos = floor((c - o - shift)/h)`). The index substitution itself
is unchanged — any staggered offset in a field's own symbolic access is
absorbed by Devito's normal indexing.
"""
pos = self.sfunction._position_map.values()
pos = self.sfunction._position_map(shifts=shifts).values()

# Temporaries for the position
temps = self._positions(implicit_dims)
temps = self._positions(implicit_dims, shifts=shifts)

# Coefficient symbol expression
temps.extend(self._coeff_temps(implicit_dims))
temps.extend(self._coeff_temps(implicit_dims, shifts=shifts))

# Substitution mapper for variables
mapper = self._rdim(subdomain=subdomain).getters
mapper = self._rdim(subdomain=subdomain, shifts=shifts).getters

# Index substitution to make in variables
subs = {
Expand All @@ -337,7 +355,7 @@ def _interp_idx(self, variables, implicit_dims=None, subdomain=None):
@check_coords
def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
"""
Generate equations interpolating an arbitrary expression into ``self``.
Generate equations interpolating an arbitrary expression into `self`.

Parameters
----------
Expand Down Expand Up @@ -375,7 +393,7 @@ def inject(self, field, expr, implicit_dims=None):

def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
"""
Generate equations interpolating an arbitrary expression into ``self``.
Generate equations interpolating an arbitrary expression into `self`.

Parameters
----------
Expand All @@ -389,16 +407,13 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
the operator.
"""
# Derivatives must be evaluated before the introduction of indirect accesses
try:
_expr = expr.evaluate
except AttributeError:
# E.g., a generic SymPy expression or a number
_expr = expr
with suppress(AttributeError):
expr = expr._eval_at(self.sfunction).evaluate

if self_subs is None:
self_subs = {}

variables = list(retrieve_function_carriers(_expr))
variables = list(retrieve_function_carriers(expr))
subdomain = _extract_subdomain(variables)

# Implicit dimensions
Expand All @@ -413,7 +428,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
summands = [Eq(rhs, 0., implicit_dims=implicit_dims)]
# Substitute coordinate base symbols into the interpolation coefficients
weights = self._weights(subdomain=subdomain)
summands.extend([Inc(rhs, (weights * _expr).xreplace(idx_subs),
summands.extend([Inc(rhs, (weights * expr).xreplace(idx_subs),
implicit_dims=implicit_dims)])

# Write/Incr `self`
Expand Down Expand Up @@ -451,35 +466,48 @@ def _inject(self, field, expr, implicit_dims=None):

subdomain = _extract_subdomain(fields)

# Derivatives must be evaluated before the introduction of indirect accesses
try:
_exprs = tuple(e.evaluate for e in exprs)
except AttributeError:
# E.g., a generic SymPy expression or a number
_exprs = exprs

variables = list(v for e in _exprs for v in retrieve_function_carriers(e))

# Implicit dimensions
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)
# Move all temporaries inside inner loop to improve parallelism
# Can only be done for inject as interpolation need a temporary
# summing temp that wouldn't allow collapsing
implicit_dims = implicit_dims + tuple(r.parent for r in
self._rdim(subdomain=subdomain))

# List of indirection indices for all adjacent grid points
finterp = fields + as_tuple(variables)
idx_subs, temps = self._interp_idx(finterp, implicit_dims=implicit_dims,
subdomain=subdomain)

# Substitute coordinate base symbols into the interpolation coefficients
eqns = [Inc(_field.xreplace(idx_subs),
(self._weights(subdomain=subdomain) * _expr).xreplace(idx_subs),
implicit_dims=implicit_dims)
for (_field, _expr) in zip(fields, _exprs, strict=True)]

return temps + eqns
# Derivatives must be evaluated before the introduction of indirect
# accesses. Variables are sampled at their own grid location; the
# position map for the target field carries the staggering so the
# field's stencil neighbors land on the right indices.
with suppress(AttributeError):
exprs = tuple(e._eval_at(f).evaluate
for e, f in zip(exprs, fields, strict=True))

eqns = []
temps = []
# We need to create one set of equations (temps and and coeffs) per staggering
# field in which we inject as the reference index depends on the field's origin
for _, g in groupby(zip(fields, exprs, strict=True), lambda f: f[0].staggered):
g_fields, g_exprs = zip(*g, strict=True)
variables = list(v for e in g_exprs for v in retrieve_function_carriers(e))

implicit_dims = self._augment_implicit_dims(implicit_dims, variables)

# All fields in a single injection share the same staggering by
# construction (they are written together at the same indices), so
# derive shifts from the first field.
shifts = self._field_shifts(g_fields[0])

# Move all temporaries inside inner loop to improve parallelism
# Can only be done for inject as interpolation needs a summing temp
# that wouldn't allow collapsing
implicit_dims = implicit_dims + tuple(r.parent for r in
self._rdim(subdomain=subdomain,
shifts=shifts))

# List of indirection indices for all adjacent grid points
idx_subs, _temps = self._interp_idx(list(g_fields) + variables,
implicit_dims=implicit_dims,
subdomain=subdomain, shifts=shifts)

w = self._weights(subdomain=subdomain, shifts=shifts)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably you can leave a blank line here

temps.extend(_temps)
eqns.extend([Inc(f.xreplace(idx_subs), (w * e).xreplace(idx_subs),
implicit_dims=implicit_dims)
for f, e in zip(g_fields, g_exprs, strict=True)])

return filter_ordered(temps) + eqns


class LinearInterpolator(WeightedInterpolator):
Expand All @@ -495,24 +523,30 @@ class LinearInterpolator(WeightedInterpolator):
_name = 'linear'

@memoized_meth
def _weights(self, subdomain=None):
rdim = self._rdim(subdomain=subdomain)
def _weights(self, subdomain=None, shifts=None):
rdim = self._rdim(subdomain=subdomain, shifts=shifts)
c = [(1 - p) * (1 - r) + p * r
for (p, d, r) in zip(self._point_symbols, self._gdims, rdim, strict=True)]
for (p, d, r) in zip(self._point_symbols(shifts), self._gdims, rdim,
strict=True)]
return Mul(*c)

@cached_property
def _point_symbols(self):
@memoized_meth
def _point_symbols(self, shifts=None):
"""Symbol for coordinate value in each Dimension of the point."""
dtype = self.sfunction.coordinates.dtype
return DimensionTuple(*(Symbol(name=f'p{d}', dtype=dtype)
for d in self.grid.dimensions),
getters=self.grid.dimensions)
symbols = []
for d in self.grid.dimensions:
if shifts and shifts[self.grid.dimensions.index(d)] != 0:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thinking about this -- if you make _field_shift return a (0, 0, 0, 0) then

  1. you simplify _field_shift's implementation itself
  2. don't need extra checks here
  3. make everything ... "smoother"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no guaranty this is called with _field_shift it can be called with just _point_symbols() so None is a valid input, changing _field_shift won't change the need for this

symbols.append(Symbol(name=f'p{d}_s1', dtype=dtype))
else:
symbols.append(Symbol(name=f'p{d}', dtype=dtype))
return DimensionTuple(*symbols, getters=self.grid.dimensions)

def _coeff_temps(self, implicit_dims):
def _coeff_temps(self, implicit_dims, shifts=None):
# Positions
pmap = self.sfunction._position_map
poseq = [Eq(self._point_symbols[d], pos - floor(pos),
pmap = self.sfunction._position_map(shifts=shifts)
psyms = self._point_symbols(shifts)
poseq = [Eq(psyms[d], pos - floor(pos),
implicit_dims=implicit_dims)
for (d, pos) in zip(self._gdims, pmap.keys(), strict=True)]
return poseq
Expand All @@ -531,23 +565,24 @@ class PrecomputedInterpolator(WeightedInterpolator):

_name = 'precomp'

def _positions(self, implicit_dims):
def _positions(self, implicit_dims, shifts=None):
if self.sfunction.gridpoints_data is None:
return super()._positions(implicit_dims)
return super()._positions(implicit_dims, shifts=shifts)
else:
# No position temp as we have directly the gridpoints
return[Eq(p, k, implicit_dims=implicit_dims)
for (k, p) in self.sfunction._position_map.items()]
for (k, p) in self.sfunction._position_map(shifts=shifts).items()]

@property
def interpolation_coeffs(self):
return self.sfunction.interpolation_coeffs

@memoized_meth
def _weights(self, subdomain=None):
def _weights(self, subdomain=None, shifts=None):
ddim, cdim = self.interpolation_coeffs.dimensions[1:]
mappers = [{ddim: ri, cdim: rd-rd.parent.symbolic_min}
for (ri, rd) in enumerate(self._rdim(subdomain=subdomain))]
for (ri, rd) in enumerate(self._rdim(subdomain=subdomain,
shifts=shifts))]
return Mul(*[self.interpolation_coeffs.subs(mapper)
for mapper in mappers])

Expand Down Expand Up @@ -592,8 +627,8 @@ def interpolation_coeffs(self):
return tuple(coeffs)

@memoized_meth
def _weights(self, subdomain=None):
rdims = self._rdim(subdomain=subdomain)
def _weights(self, subdomain=None, shifts=None):
rdims = self._rdim(subdomain=subdomain, shifts=shifts)
return Mul(*[
w._subs(rd, rd-rd.parent.symbolic_min)
for (rd, w) in zip(rdims, self.interpolation_coeffs, strict=True)
Expand Down
4 changes: 2 additions & 2 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,12 +960,12 @@ def indices(self):
"""The indices of the object."""
return DimensionTuple(*self.args, getters=self.dimensions)

@property
@cached_property
def indices_ref(self):
"""The reference indices of the object (indices at first creation)."""
return DimensionTuple(*self.function.indices, getters=self.dimensions)

@property
@cached_property
def origin(self):
"""
Origin of the AbstractFunction in term of Dimension
Expand Down
3 changes: 2 additions & 1 deletion devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,8 @@ def _eval_at(self, func):
for d in self.dimensions:
try:
if self.indices_ref[d] is not func.indices_ref[d]:
f_idx = func.indices_ref[d]._subs(func.dimensions[d], d)
d0 = func.dimensions.get(d, d)
f_idx = func.indices_ref[d]._subs(d0, d)
mapper[self.indices_ref[d]] = f_idx
except KeyError:
pass
Expand Down
Loading
Loading