Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/underworld3/discretisation/discretisation_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1832,6 +1832,19 @@ def _deform_mesh(self, new_coords: numpy.ndarray, verbose=False):
if solver is not None and hasattr(solver, "is_setup"):
solver.is_setup = False

# Invalidate caches whose contents become stale when mesh
# coordinates change. Matches the cache hygiene already
# performed by mesh.adapt() and _legacy_access. Without
# these, uw.function.evaluate (and any user code that keys
# lookups off _topology_version) can return values
# computed against the pre-deform mesh.
self._evaluation_hash = None
self._evaluation_interpolated_results = None
if hasattr(self, '_dminterpolation_cache'):
self._dminterpolation_cache.invalidate_all(
reason="mesh deformed")
self._topology_version += 1

# Propagate coordinate changes to registered submeshes
for submesh in self._registered_submeshes:
submesh.sync_coordinates_from_parent()
Expand Down
91 changes: 91 additions & 0 deletions tests/test_0825_deform_mesh_cache_invalidation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Regression test for cache invalidation on Mesh._deform_mesh.

After a coordinate change, three mesh-level caches need to be marked
stale so subsequent lookups recompute against the new geometry:

- self._evaluation_hash + self._evaluation_interpolated_results
(the legacy uw.function.evaluate fast-path cache)
- self._dminterpolation_cache (DMInterpolation structures keyed on
coord-hash; encode cell residency and reference->physical maps)
- self._topology_version (a counter that downstream caches consult)

mesh.adapt() and _legacy_access already invalidate these. Before this
fix, _deform_mesh did not, leaving a gap where user code that re-uses
the same coord array across a deformation could hit a stale cache
built on the pre-deform mesh.

This test warms each cache (by performing an evaluate that populates
them), then calls _deform_mesh and verifies the caches are cleared
and the topology version has been bumped.
"""

import numpy as np
import pytest

import underworld3 as uw


pytestmark = [pytest.mark.level_1, pytest.mark.tier_a]


def _build_mesh_with_field():
"""A small structured-quad mesh plus a scalar field to drive evals."""
mesh = uw.meshing.StructuredQuadBox(
elementRes=(8, 8),
minCoords=(0.0, 0.0),
maxCoords=(1.0, 1.0),
)
s = uw.discretisation.MeshVariable(
"S_cache_test", mesh, 1, degree=2, continuous=True)
s.data[:, 0] = 1.0 # constant — exact value irrelevant
return mesh, s


def _warm_caches(mesh, var):
"""Trigger an evaluate that populates _evaluation_hash and
_dminterpolation_cache."""
coords = np.asarray(mesh.X.coords, dtype=np.double)
# uw.function.evaluate goes through the petsc_interpolate path,
# which populates _dminterpolation_cache.
_ = uw.function.evaluate(var.sym, coords)


def test_topology_version_bumps_on_deform():
mesh, s = _build_mesh_with_field()
_warm_caches(mesh, s)
before = mesh._topology_version
coords = np.asarray(mesh.X.coords)
perturbed = coords.copy()
perturbed[:, 0] += 1.0e-3 * coords[:, 1]
mesh._deform_mesh(perturbed)
assert mesh._topology_version > before, (
"_topology_version should be incremented by _deform_mesh")


def test_evaluation_hash_invalidated_on_deform():
mesh, s = _build_mesh_with_field()
_warm_caches(mesh, s)
# uw.function.evaluate may have populated _evaluation_hash
coords = np.asarray(mesh.X.coords)
perturbed = coords.copy()
perturbed[:, 1] += 5.0e-4 * coords[:, 0]
mesh._deform_mesh(perturbed)
assert mesh._evaluation_hash is None
assert mesh._evaluation_interpolated_results is None


def test_dminterpolation_cache_invalidated_on_deform():
mesh, s = _build_mesh_with_field()
_warm_caches(mesh, s)
# The cache should have at least one entry after the warm-up evaluate.
n_before = len(getattr(
mesh._dminterpolation_cache, "_cache", {}))
coords = np.asarray(mesh.X.coords)
perturbed = coords.copy()
perturbed[:, 1] += 1.0e-3 * coords[:, 0]
mesh._deform_mesh(perturbed)
n_after = len(getattr(
mesh._dminterpolation_cache, "_cache", {}))
assert n_after == 0, (
f"Expected DMInterpolation cache to be empty after deform; "
f"got {n_after} entries (before deform: {n_before})")
Loading