diff --git a/docs/beginner/parameters.md b/docs/beginner/parameters.md index 1d18e301..a82b48e2 100644 --- a/docs/beginner/parameters.md +++ b/docs/beginner/parameters.md @@ -251,6 +251,47 @@ mpirun -np 256 python convection.py \ -uw_max_iterations 100 ``` +## Using Parameters with Solvers: Expressions + +When passing parameters to solver constitutive models, wrap them in +`uw.expression()`. This creates a named symbolic container that the JIT +compiler can update efficiently — changing the value between time steps +does **not** trigger recompilation of the C extension. + +```python +import underworld3 as uw + +# Define parameters +VISCOSITY = 1e21 # Named constant +MODULUS = 1e10 # Named constant +DT = 0.01 # Timestep + +params = uw.Params( + uw_viscosity = VISCOSITY, + uw_modulus = MODULUS, +) + +# Wrap in expressions for solver use +eta = uw.expression("eta", params.uw_viscosity) +mu = uw.expression("mu", params.uw_modulus) +dt_e = uw.expression("dt_e", DT) + +# Pass expressions to constitutive model +stokes.constitutive_model.Parameters.shear_viscosity_0 = eta +stokes.constitutive_model.Parameters.shear_modulus = mu +stokes.constitutive_model.Parameters.dt_elastic = dt_e + +# Time-stepping: change dt without recompilation +for step in range(100): + dt_e.sym = compute_new_timestep() # Updates value, ~0ms + stokes.solve() # No JIT rebuild needed +``` + +Without expressions, changing a solver parameter between steps requires +setting `_force_setup=True` on the solve call, which triggers a full JIT +recompilation (~5–15 seconds). With expressions, parameter updates go +through PETSc's `constants[]` array and cost essentially nothing. + ## Angle Units Angles work naturally - you can define in degrees and provide radians (or vice versa): diff --git a/docs/developer/UW3_Developers_MathematicalObjects.md b/docs/developer/UW3_Developers_MathematicalObjects.md index 0f4a323a..95551a1f 100644 --- a/docs/developer/UW3_Developers_MathematicalObjects.md +++ b/docs/developer/UW3_Developers_MathematicalObjects.md @@ -428,13 +428,13 @@ The JIT compilation system needs to: 1. Identify SymPy Function atoms in expressions 2. Map them to PETSc vector components 3. Generate C code with appropriate substitutions +4. Allow constant parameters to change without recompilation ## The Solution: Transparent SymPy Objects The mathematical object system preserves JIT compatibility by ensuring all operations return pure SymPy objects: ```python - # User writes natural syntax momentum = density * velocity @@ -447,23 +447,50 @@ atoms = momentum.atoms(sympy.Function) # Finds V_0, V_1 # Maps to velocity.fn for PETSc substitution ``` -## Expression Unwrapping +## Expression Unwrapping and Constants -The `unwrap()` function resolves nested expressions before compilation: +The JIT compiler performs a **two-phase unwrap** on expressions: -```python +1. **Phase 1 — Constants extraction**: `UWexpression` atoms that resolve to + pure numbers (no spatial/field dependencies) are replaced with + `_JITConstant` symbols that render as `constants[i]` in C code. + +2. **Phase 2 — Full unwrap**: Remaining `UWexpression` atoms are expanded + to their numerical values and baked into the C code. +```python # Expression with nested UWexpressions complex_expr = alpha * (temperature - T0) * velocity -# unwrap() substitutes all UWexpression.sym values -unwrapped = unwrap(complex_expr) -# Result: 2e-5 * (T(x,y,z) - 293) * Matrix([[V_0(x,y,z)], [V_1(x,y,z)]]) +# Phase 1: alpha and T0 are constants → constants[0], constants[1] +# Phase 2: temperature and velocity are field variables → petsc_a[], petsc_u[] + +# Generated C code: +# result = constants[0] * (petsc_a[0] - constants[1]) * petsc_u[0]; +``` + +This means **changing `alpha` or `T0` between solves does not require +recompilation** — only `PetscDSSetConstants()` is called (~0ms vs ~10s). + +## Why Use UWexpressions for Solver Parameters + +**UWexpressions are the preferred way to define solver parameters.** +Using raw numbers forces recompilation when values change: -# JIT compilation proceeds normally -compiled = uw.systems.compile(unwrapped) +```python +# PREFERRED — expression parameter, efficient for time-stepping +eta = uw.expression("eta", 1e21) +stokes.constitutive_model.Parameters.shear_viscosity_0 = eta +# Changing eta.sym later → no recompilation + +# AVOID for time-varying parameters — raw number, requires rebuild +stokes.constitutive_model.Parameters.shear_viscosity_0 = 1e21 +# Changing this later → full JIT rebuild (~10s) ``` +This is particularly important for viscoelastic and Navier-Stokes +solvers where parameters like `dt_elastic` change every time step. + # Migration from Legacy Patterns ## Automatic Migration Patterns diff --git a/docs/developer/guides/HOW-TO-WRITE-UW3-SCRIPTS.md b/docs/developer/guides/HOW-TO-WRITE-UW3-SCRIPTS.md index 4070330f..1646bd49 100644 --- a/docs/developer/guides/HOW-TO-WRITE-UW3-SCRIPTS.md +++ b/docs/developer/guides/HOW-TO-WRITE-UW3-SCRIPTS.md @@ -610,16 +610,34 @@ def test_swarm_functionality(): ### JIT Compilation Issues -If you see generated C code with symbolic expressions instead of numbers: +**Slow time-stepping loops**: If each `solver.solve()` call takes 10+ seconds +in a time-stepping loop, the solver is probably recompiling the JIT extension +every step. Use `UWexpression` objects for any parameter that changes between +steps: + +```python +# FAST — expression parameter, no recompilation on change +dt_e = uw.expression("dt_e", 0.01) +model.Parameters.dt_elastic = dt_e + +for step in range(100): + dt_e.sym = compute_timestep() # Updates constants[], ~0ms + solver.solve() # No JIT rebuild +``` + +**Symbolic names in generated C code**: If you see LaTeX-like names in the +generated C code (e.g., `\eta` instead of a number or `constants[i]`): ```text -// ERROR symptom in generated code: -out[0] = 1.0/{ \eta \hspace{ 0.0006pt } }; // Should be numeric! +// ERROR symptom — expression not unwrapped: +out[0] = 1.0/{ \eta \hspace{ 0.0006pt } }; ``` -**Cause**: `unwrap(fn, keep_constants=False)` not properly unwrapping constants. +**Cause**: A `UWexpression` was not properly detected as constant or unwrapped. -**Solution**: Check that constants (like UWQuantity) are being unwrapped to numeric values. +**Solution**: Check that the expression resolves to a pure number when fully +unwrapped. Composite expressions containing mesh variables or coordinates +cannot be routed through `constants[]`. ### PETSc DM Errors with Swarms diff --git a/docs/developer/subsystems/expressions-functions.md b/docs/developer/subsystems/expressions-functions.md index 5fe02453..9abad664 100644 --- a/docs/developer/subsystems/expressions-functions.md +++ b/docs/developer/subsystems/expressions-functions.md @@ -2,67 +2,197 @@ title: "Expressions & Functions System" --- -# Expressions & Functions Documentation +# Expressions & Functions -```{important} Critical Documentation Gap -**Module**: `function/expressions.py` (606 lines) -**Priority**: 🔴 Critical - highest priority for documentation -**Current Status**: Minimal documentation ❌ +## Overview + +The `UWexpression` class is the symbolic backbone of Underworld3. It wraps +SymPy symbols with metadata (units, values, descriptions) while remaining +fully compatible with SymPy arithmetic and the JIT compilation pipeline. + +**Key files:** + +| File | Purpose | +|------|---------| +| `function/expressions.py` | `UWexpression`, unwrapping, constant detection | +| `function/_function.pyx` | `UnderworldFunction` — mesh variable symbols | +| `utilities/_jitextension.py` | JIT compiler, constants extraction, C code generation | + +## Creating Expressions + +```python +import underworld3 as uw -This is user-facing but severely underdocumented - **immediate attention needed**. +# Scalar constant +viscosity = uw.expression("eta", 1e21) + +# With units (when scaling is active) +viscosity = uw.expression("eta", uw.quantity(1e21, "Pa*s")) + +# Composite expression — built from other expressions +Ra = uw.expression("Ra", rho * alpha * g * DeltaT * L**3 / (eta * kappa)) ``` -## Overview +Expressions are SymPy `Symbol` subclasses, so they work naturally in +equations: + +```python +# Arithmetic produces new SymPy expressions (not raw floats) +flux = viscosity * strain_rate # viscosity stays symbolic +buoyancy = Ra * temperature * unit_z # Ra stays symbolic +``` -The expressions and functions subsystem handles symbolic expression management and mathematical function definition. +## Why Expressions Matter for Performance -### Current State -- **Files**: - - `expressions.py`: 606 lines - Symbolic expression handling - - `analytic.py`: 379 lines - Analytic functions - - `utilities.py`: 207 lines - Function utilities -- **Complexity**: High - sympy integration, expression manipulation -- **Documentation Quality**: Minimal ❌ +**Expressions are the preferred way to pass parameters to solvers.** +When a solver parameter is a `UWexpression`, changing its value between +time steps does not trigger JIT recompilation. When the parameter is a +raw Python number, changing it requires a full rebuild of the compiled +C extension (~5–15 seconds per solve). -### Key Components -- `UWExpression`: Base symbolic expression class -- Expression registry with unique naming -- SymPy integration for mathematical operations -- JIT compilation support +```python +# GOOD — expression parameter, no recompilation on change +eta = uw.expression("eta", 1e21) +stokes.constitutive_model.Parameters.shear_viscosity_0 = eta -## Critical Documentation Needs +for step in range(100): + eta.sym = compute_new_viscosity(step) # Just updates constants[] + stokes.solve() # ~0.3s per solve -### Missing Essential Content -- ❌ Limited usage examples -- ❌ Expression building patterns missing -- ❌ JIT compilation workflow undocumented -- ❌ Integration with mathematical objects unclear -- ❌ Performance implications unknown +# SLOW — raw number, forces recompilation every step +for step in range(100): + stokes.constitutive_model.Parameters.shear_viscosity_0 = new_value + stokes.solve(_force_setup=True) # ~15s per solve (JIT rebuild) +``` + +This is especially important for: -### User Impact -This system is central to user workflows but lacks documentation, creating a significant barrier to adoption and effective use. +- **Viscoelastic solvers** — `dt_elastic` changes every step +- **Parameter sweeps** — varying viscosity, yield stress, etc. +- **Time-dependent BCs** — oscillatory or ramped boundary conditions +- **Navier-Stokes** — any time-varying forcing or material property -## Implementation Tasks +## How It Works: The Constants Mechanism -```{tip} Urgent - For Contributors -This section desperately needs: +### The Problem -1. **Complete API reference** with examples for every function -2. **Expression building cookbook** with common patterns -3. **JIT compilation guide** showing workflow from expression to compiled code -4. **20+ usage examples** covering typical user scenarios -5. **Integration documentation** showing how expressions work with variables -6. **Performance guidance** for optimal expression construction -7. **Debugging help** for common expression issues +The JIT compiler translates SymPy expressions into C code for PETSc's +pointwise function interface. Previously, all constant values were baked +as C literals: -**Estimated effort**: Substantial development time for comprehensive documentation +```c +// Old: value baked into compiled code +double result = 1e+21 * velocity_gradient; // must recompile to change ``` -## Related Systems +### The Solution -- Works closely with [Mathematical Objects](../UW3_Developers_MathematicalObjects.md) -- Used by [Solvers](solvers.md) for symbolic equation definition +Every PETSc pointwise function signature includes `numConstants` and +`constants[]` parameters that were previously unused. Now, `UWexpression` +atoms that are spatially constant (no dependence on coordinates or field +variables) are automatically routed through this array: ---- +```c +// New: value read from constants array at runtime +double result = constants[0] * velocity_gradient; // update via PetscDSSetConstants() +``` + +### What Happens Automatically + +1. **Constant detection** — Before JIT compilation, `_extract_constants()` + scans all expression trees for `UWexpression` atoms whose fully-unwrapped + value is a pure number. This works at any nesting depth (user expression → + constitutive model parameter → solver template). + +2. **Structural hashing** — The JIT cache key is computed from the + *structural* form of expressions (constants replaced with placeholders). + Changing a constant value produces the same hash → cache hit → no + recompilation. + +3. **Two-phase unwrap** — During code generation: + - Phase 1: constant UWexpressions → `_JITConstant` symbols (render as `constants[i]`) + - Phase 2: remaining UWexpressions → numerical values (baked into C code) + +4. **Runtime update** — Before every `snes.solve()`, the solver calls + `_update_constants()` which packs current values from the manifest + and calls `PetscDSSetConstants()`. This propagates to all levels + of the multigrid hierarchy. + +### What Goes Through Constants + +Any `UWexpression` that resolves to a number when fully unwrapped: + +| Example | In constants[]? | Why | +|---------|-----------------|-----| +| `uw.expression("eta", 1e21)` | Yes | Pure number | +| `uw.expression("Ra", rho*g*alpha*...)` | Yes | Composite of numbers | +| `constitutive_model.Parameters.shear_viscosity_0` | Yes | Wraps user expression | +| `uw.expression("f", sin(x))` | No | Depends on coordinate `x` | +| `velocity.sym[0]` | No | Mesh variable (field dependency) | + +### Inspecting the Constants Manifest + +After the first solve, the constants manifest is available: + +```python +stokes.solve() +for idx, expr in stokes.constants_manifest: + print(f"constants[{idx}] = {expr.name} = {expr.sym}") +``` + +## Expression Unwrapping + +The `unwrap()` function resolves nested `UWexpression` atoms to their +underlying values. Two modes are used internally: + +| Mode | Purpose | Used by | +|------|---------|---------| +| `nondimensional` | Numeric values for JIT/evaluate | `_createext()`, `evaluate()` | +| `dimensional` | Display values with units | `print()`, notebooks | + +```python +# Nested expressions +alpha = uw.expression("alpha", 3e-5) +DeltaT = uw.expression("DeltaT", 1000) +buoyancy = alpha * DeltaT # SymPy expression, not a float + +# Unwrap reveals the numeric value +from underworld3.function.expressions import unwrap +unwrap(buoyancy, keep_constants=False) # → 0.03 +``` + +## Integration with Constitutive Models + +Constitutive model parameters are themselves `UWexpression` objects. +When you assign a user expression to a parameter, it becomes nested: + +``` +User: K = uw.expression("K", 1.0) + ↓ assign to constitutive model +Model: \upkappa.sym = K (UWexpression wrapping UWexpression) + ↓ used in solver template +Solver: F1.sym = \upkappa * grad(u) + ↓ constants extraction finds \upkappa +JIT: F1 → constants[0] * petsc_u_x[0] +``` + +Changing `K.sym = 2.0` propagates through the chain: `\upkappa` still +wraps `K`, so `_pack_constants()` reads the new value automatically. + +## Key Functions + +| Function | Location | Purpose | +|----------|----------|---------| +| `uw.expression(name, value)` | `expressions.py` | Create a named expression | +| `unwrap(expr, mode)` | `expressions.py` | Resolve nested expressions | +| `is_constant_expr(expr)` | `expressions.py` | Check for spatial dependencies | +| `_extract_constants(fns)` | `_jitextension.py` | Find constants in expression trees | +| `_pack_constants(manifest)` | `_jitextension.py` | Get current values for PetscDS | +| `getext(...)` | `_jitextension.py` | Full JIT pipeline: extract → compile → return | + +## Related Systems -*This document represents the highest priority documentation gap in Underworld3.* \ No newline at end of file +- [Mathematical Objects](../UW3_Developers_MathematicalObjects.md) — `MathematicalMixin` for natural syntax +- [Template Expressions](../TEMPLATE_EXPRESSION_PATTERN.md) — `ExpressionProperty` for solver templates +- [Solvers](solvers.md) — consume compiled expressions via PetscDS +- [Constitutive Models](constitutive-models.md) — parameter expressions feed into solver templates diff --git a/src/underworld3/cython/petsc_extras.pxi b/src/underworld3/cython/petsc_extras.pxi index c4cad934..c39d9c71 100644 --- a/src/underworld3/cython/petsc_extras.pxi +++ b/src/underworld3/cython/petsc_extras.pxi @@ -46,6 +46,7 @@ cdef extern from "petsc_compat.h": PetscErrorCode UW_DMPlexComputeBdIntegral( PetscDM, PetscVec, PetscDMLabel, PetscInt, const PetscInt*, void*, PetscScalar*, void*) cdef extern from "petsc.h" nogil: + PetscErrorCode PetscDSSetConstants(PetscDS, PetscInt, const PetscScalar[]) PetscErrorCode DMPlexSNESComputeBoundaryFEM( PetscDM, void *, void *) # PetscErrorCode DMPlexSetSNESLocalFEM( PetscDM, void *, void *, void *) # PetscErrorCode DMPlexSetSNESLocalFEM( PetscDM, PetscBool, void *) diff --git a/src/underworld3/cython/petsc_generic_snes_solvers.pyx b/src/underworld3/cython/petsc_generic_snes_solvers.pyx index 20d5c7e2..1f9cebea 100644 --- a/src/underworld3/cython/petsc_generic_snes_solvers.pyx +++ b/src/underworld3/cython/petsc_generic_snes_solvers.pyx @@ -37,6 +37,7 @@ class SolverBaseClass(uw_object): self.mesh = mesh self.mesh_dm_coordinate_hash = None self.compiled_extensions = None + self.constants_manifest = [] self.Unknowns = self._Unknowns(self) @@ -454,7 +455,7 @@ class SolverBaseClass(uw_object): # to let the rest of the machinery work. if len(self.natural_bcs) > 0: - if not "Null_Boundary" in self.natural_bcs: + if not any(bc.boundary == "Null_Boundary" for bc in self.natural_bcs): bc = (0,)*self.Unknowns.u.shape[1] self.add_natural_bc(bc, "Null_Boundary") @@ -473,6 +474,48 @@ class SolverBaseClass(uw_object): return + def _set_constants_on_ds(self, ds): + """Pack current constant values and call PetscDSSetConstants. + + Parameters + ---------- + ds : PETSc DS object + The PetscDS to set constants on. + """ + if not self.constants_manifest: + return + + from underworld3.utilities._jitextension import _pack_constants + import numpy as np + + values = _pack_constants(self.constants_manifest) + + cdef DS cds = ds + cdef int n_constants = len(values) + cdef double[::1] vals_view = np.ascontiguousarray(values, dtype=np.float64) + CHKERRQ(PetscDSSetConstants(cds.ds, n_constants, &vals_view[0])) + + def _update_constants(self): + """Re-pack current UWexpression values and call PetscDSSetConstants. + + Called before each solve() to ensure constants are current without + requiring JIT recompilation. + """ + if not self.constants_manifest or self.dm is None: + return + + ds = self.dm.getDS() + self._set_constants_on_ds(ds) + + # Also propagate to coarse DMs in multigrid hierarchy + if hasattr(self, 'dm_hierarchy') and self.dm_hierarchy: + for coarse_dm in self.dm_hierarchy[:-1]: + try: + coarse_ds = coarse_dm.getDS() + self._set_constants_on_ds(coarse_ds) + except Exception: + pass + # Deprecate in favour of properties for solver.F0, solver.F1 @timing.routine_timer_decorator def _setup_problem_description(self): @@ -1396,8 +1439,10 @@ class SNES_Scalar(SolverBaseClass): # f0 = sympy.Array(uw.function.fn_substitute_expressions(self.F0.sym)).reshape(1).as_immutable() # F1 = sympy.Array(uw.function.fn_substitute_expressions(self.F1.sym)).reshape(dim).as_immutable() - f0 = sympy.Array(uw.function.expressions._unwrap_for_compilation(self.F0.sym, keep_constants=False, return_self=False)).reshape(1).as_immutable() - F1 = sympy.Array(uw.function.expressions._unwrap_for_compilation(self.F1.sym, keep_constants=False, return_self=False)).reshape(dim).as_immutable() + # Don't unwrap here — let getext()'s two-phase unwrap handle it. + # This preserves constant UWexpressions as symbols for the constants[] mechanism. + f0 = sympy.Array(self.F0.sym).reshape(1).as_immutable() + F1 = sympy.Array(self.F1.sym).reshape(dim).as_immutable() self._u_f0 = f0 self._u_F1 = F1 @@ -1492,7 +1537,7 @@ class SNES_Scalar(SolverBaseClass): print(f"Scalar SNES: Jacobians complete, now compile", flush=True) prim_field_list = [self.u] - self.compiled_extensions, self.ext_dict = getext(self.mesh, + _getext_result = getext(self.mesh, tuple(fns_residual), tuple(fns_jacobian), [x.fn for x in self.essential_bcs], @@ -1501,6 +1546,9 @@ class SNES_Scalar(SolverBaseClass): primary_field_list=prim_field_list, verbose=verbose, debug=debug,) + self.compiled_extensions = _getext_result.ptrobj + self.ext_dict = _getext_result.fn_dicts + self.constants_manifest = _getext_result.constants_manifest return @@ -1577,6 +1625,9 @@ class SNES_Scalar(SolverBaseClass): NULL, ) + # Set constants on DS before copying to coarse levels + self._set_constants_on_ds(ds) + # Rebuild this lot for coarse_dm in self.dm_hierarchy: @@ -1691,6 +1742,9 @@ class SNES_Scalar(SolverBaseClass): ierr = DMSetAuxiliaryVec_UW(dm.dm, NULL, 0, 0, cmesh_lvec.vec); CHKERRQ(ierr) + # Update constants (e.g. changed material params) before solve + self._update_constants() + # solve self.snes.solve(None, gvec) @@ -2133,8 +2187,10 @@ class SNES_Vector(SolverBaseClass): # f0 = sympy.Array(uw.function.fn_substitute_expressions(self.F0.sym)).reshape(dim).as_immutable() # F1 = sympy.Array(uw.function.fn_substitute_expressions(self.F1.sym)).reshape(dim,dim).as_immutable() - f0 = sympy.Array(uw.function.expressions._unwrap_for_compilation(self.F0.sym, keep_constants=False, return_self=False)).reshape(dim).as_immutable() - F1 = sympy.Array(uw.function.expressions._unwrap_for_compilation(self.F1.sym, keep_constants=False, return_self=False)).reshape(dim,dim).as_immutable() + # Don't unwrap here — let getext()'s two-phase unwrap handle it. + # This preserves constant UWexpressions as symbols for the constants[] mechanism. + f0 = sympy.Array(self.F0.sym).reshape(dim).as_immutable() + F1 = sympy.Array(self.F1.sym).reshape(dim,dim).as_immutable() self._u_f0 = f0 @@ -2230,7 +2286,7 @@ class SNES_Vector(SolverBaseClass): # note also that the order here is important. prim_field_list = [self.u,] - self.compiled_extensions, self.ext_dict = getext(self.mesh, + _getext_result = getext(self.mesh, tuple(fns_residual), tuple(fns_jacobian), [x.fn for x in self.essential_bcs], @@ -2239,6 +2295,9 @@ class SNES_Vector(SolverBaseClass): primary_field_list=prim_field_list, verbose=verbose, debug=debug,) + self.compiled_extensions = _getext_result.ptrobj + self.ext_dict = _getext_result.fn_dicts + self.constants_manifest = _getext_result.constants_manifest cdef PtrContainer ext = self.compiled_extensions @@ -2331,6 +2390,9 @@ class SNES_Vector(SolverBaseClass): for boundary in self.natural_bcs: UW_PetscDSViewBdWF(ds.ds, boundary.PETScID) + # Set constants on DS before copying to coarse levels + self._set_constants_on_ds(ds) + # Rebuild this lot for coarse_dm in self.dm_hierarchy: @@ -2355,8 +2417,6 @@ class SNES_Vector(SolverBaseClass): self.constitutive_model._solver_is_setup = True - - @timing.routine_timer_decorator def solve(self, zero_init_guess: bool =True, @@ -2454,6 +2514,9 @@ class SNES_Vector(SolverBaseClass): cmesh_lvec = self.mesh.lvec ierr = DMSetAuxiliaryVec_UW(dm.dm, NULL, 0, 0, cmesh_lvec.vec); CHKERRQ(ierr) + # Update constants (e.g. changed material params) before solve + self._update_constants() + # solve self.snes.solve(None,gvec) @@ -3173,9 +3236,11 @@ class SNES_Stokes_SaddlePt(SolverBaseClass): ## and do these one by one as required by PETSc. However, at the moment, this ## is working .. so be careful !! - F0 = sympy.Array(uw.function.expressions._unwrap_for_compilation(self.F0.sym, keep_constants=False, return_self=False)) - F1 = sympy.Array(uw.function.expressions._unwrap_for_compilation(self.F1.sym, keep_constants=False, return_self=False)) - PF0 = sympy.Array(uw.function.expressions._unwrap_for_compilation(self.PF0.sym, keep_constants=False, return_self=False)) + # Don't unwrap here — let getext()'s two-phase unwrap handle it. + # This preserves constant UWexpressions as symbols for the constants[] mechanism. + F0 = sympy.Array(self.F0.sym) + F1 = sympy.Array(self.F1.sym) + PF0 = sympy.Array(self.PF0.sym) # JIT compilation needs immutable, matrix input (not arrays) self._u_F0 = sympy.ImmutableDenseMatrix(F0) @@ -3336,7 +3401,7 @@ class SNES_Stokes_SaddlePt(SolverBaseClass): print(f"Stokes: Jacobians complete, now compile", flush=True) prim_field_list = [self.u, self.p] - self.compiled_extensions, self.ext_dict = getext(self.mesh, + _getext_result = getext(self.mesh, tuple(fns_residual), tuple(fns_jacobian), [x.fn for x in self.essential_bcs], @@ -3347,7 +3412,9 @@ class SNES_Stokes_SaddlePt(SolverBaseClass): debug=debug, debug_name=debug_name, cache=False) - + self.compiled_extensions = _getext_result.ptrobj + self.ext_dict = _getext_result.fn_dicts + self.constants_manifest = _getext_result.constants_manifest self.is_setup = False @@ -3647,6 +3714,8 @@ class SNES_Stokes_SaddlePt(SolverBaseClass): # self.dm.setUp() # self.dm.ds.setUp() + # Set constants on DS before copying to coarse levels + self._set_constants_on_ds(ds) # Rebuild this lot @@ -3762,6 +3831,9 @@ class SNES_Stokes_SaddlePt(SolverBaseClass): self.mesh.update_lvec() self.dm.setAuxiliaryVec(self.mesh.lvec, None) + # Update constants (e.g. changed material params) before solve + self._update_constants() + gvec = self.dm.getGlobalVec() gvec.setArray(0.0) diff --git a/src/underworld3/cython/petsc_maths.pyx b/src/underworld3/cython/petsc_maths.pyx index da12acdf..fe2d19f8 100644 --- a/src/underworld3/cython/petsc_maths.pyx +++ b/src/underworld3/cython/petsc_maths.pyx @@ -89,8 +89,8 @@ class Integral: self.dm = self.mesh.dm # .clone() mesh=self.mesh - compiled_extns, dictionaries = getext(self.mesh, [self.fn,], [], [], [], [], self.mesh.vars.values(), verbose=verbose) - cdef PtrContainer ext = compiled_extns + _getext_result = getext(self.mesh, [self.fn,], [], [], [], [], self.mesh.vars.values(), verbose=verbose) + cdef PtrContainer ext = _getext_result.ptrobj # Pull out vec for variables, and go ahead with the integral @@ -273,7 +273,7 @@ class CellWiseIntegral: elif isinstance(self.fn, sympy.vector.Dyadic): raise RuntimeError("Integral evaluation for Dyadic integrands not supported.") - cdef PtrContainer ext = getext(self.mesh, [self.fn,], [], [], self.mesh.vars.values()) + cdef PtrContainer ext = getext(self.mesh, [self.fn,], [], [], [], [], self.mesh.vars.values()).ptrobj # Pull out vec for variables, and go ahead with the integral self.mesh.update_lvec() @@ -387,10 +387,10 @@ class BdIntegral: mesh = self.mesh # Compile integrand using the boundary residual slot (includes petsc_n[] in signature) - compiled_extns, dictionaries = getext( + _getext_result = getext( self.mesh, [], [], [], [self.fn,], [], self.mesh.vars.values(), verbose=verbose ) - cdef PtrContainer ext = compiled_extns + cdef PtrContainer ext = _getext_result.ptrobj # Prepare the solution vector self.mesh.update_lvec() diff --git a/src/underworld3/utilities/_jitextension.py b/src/underworld3/utilities/_jitextension.py index e88697dc..5eddb90f 100644 --- a/src/underworld3/utilities/_jitextension.py +++ b/src/underworld3/utilities/_jitextension.py @@ -33,6 +33,197 @@ _ext_dict = {} +# ============================================================================ +# JIT Constants Support +# ============================================================================ +# +# UWexpressions that are "constant" (no spatial/field dependencies) are routed +# through PETSc's constants[] array instead of being baked as C literals. +# This allows parameter changes without JIT recompilation. +# ============================================================================ + +class _JITConstant(sympy.Symbol): + """Symbol subclass that renders as constants[i] in generated C code. + + Used by the JIT compiler to route constant UWexpressions through + PETSc's PetscDSSetConstants() mechanism instead of baking values + as C literals. + """ + + def __new__(cls, index, name=None): + if name is None: + name = f"_jit_const_{index}" + obj = super().__new__(cls, name) + obj._const_index = index + obj._ccodestr = f"constants[{index}]" + return obj + + def _ccode(self, printer): + return self._ccodestr + + +def _extract_constants(all_fns, mesh): + """Extract constant UWexpressions from a list of pre-unwrap functions. + + Scans all expressions for UWexpression atoms where is_constant_expr() + is True (no spatial/field dependencies). Assigns deterministic indices + sorted by expression name for MPI consistency. + + Parameters + ---------- + all_fns : tuple of sympy expressions + The raw (pre-unwrap) function list. + mesh : underworld3.discretisation.Mesh + The mesh (currently unused, reserved for future mesh.t support). + + Returns + ------- + list of (int, UWexpression) + Ordered mapping from constants[] index to UWexpression reference. + dict + Mapping from UWexpression to _JITConstant symbol for substitution. + """ + from underworld3.function.expressions import ( + is_constant_expr, + extract_expressions, + UWexpression, + ) + + constant_exprs = set() + + for fn in all_fns: + if fn is None: + continue + + # Handle Matrix expressions + if isinstance(fn, sympy.MatrixBase): + for elem in fn: + _collect_constant_atoms(elem, constant_exprs, is_constant_expr, UWexpression) + else: + _collect_constant_atoms(fn, constant_exprs, is_constant_expr, UWexpression) + + if not constant_exprs: + return [], {} + + # Sort by name for deterministic MPI-consistent ordering + sorted_constants = sorted(constant_exprs, key=lambda e: str(e)) + + manifest = [] + subs_map = {} + for i, expr in enumerate(sorted_constants): + jit_const = _JITConstant(i, name=f"_jit_const_{str(expr)}") + manifest.append((i, expr)) + subs_map[expr] = jit_const + + return manifest, subs_map + + +def _is_truly_constant(expr, UWexpression): + """Check if a UWexpression resolves to a pure constant (no spatial deps). + + Unlike is_constant_expr(), this handles nested UWexpressions correctly + by fully unwrapping and checking if the result has any spatial/field + symbols (BaseScalar, UnderworldFunction, etc.). + """ + try: + unwrapped = underworld3.function.expressions.unwrap_expression( + expr, mode='nondimensional' + ) + except Exception: + return False + + # If it unwraps to a plain number, it's constant + if isinstance(unwrapped, (int, float)): + return True + if isinstance(unwrapped, sympy.Number): + return True + + if not hasattr(unwrapped, 'free_symbols'): + try: + float(unwrapped) + return True + except (TypeError, ValueError): + return False + + # Check remaining free symbols — any spatial/field dependency makes it non-constant + from sympy.vector.scalar import BaseScalar + for sym in unwrapped.free_symbols: + if isinstance(sym, BaseScalar): + return False + if isinstance(sym, sympy.Function): + return False + # UnderworldFunction symbols have _ccodestr pointing to petsc arrays + if hasattr(sym, '_ccodestr') and not isinstance(sym, _JITConstant): + ccode = sym._ccodestr + if 'petsc_u' in ccode or 'petsc_a' in ccode or 'petsc_x' in ccode or 'petsc_n' in ccode: + return False + # Other UWexpressions that didn't fully unwrap — not constant + if isinstance(sym, UWexpression): + return False + + return True + + +def _collect_constant_atoms(expr, result_set, is_constant_expr, UWexpression): + """Recursively collect constant UWexpression atoms from an expression.""" + + if isinstance(expr, UWexpression): + if _is_truly_constant(expr, UWexpression): + result_set.add(expr) + return # Don't recurse into constant expressions + # Non-constant UWexpression: check its inner sym for nested constants + if hasattr(expr, '_sym') and expr._sym is not None: + _collect_constant_atoms(expr._sym, result_set, is_constant_expr, UWexpression) + return + + if not hasattr(expr, 'atoms'): + return + + # Check all UWexpression atoms + for atom in expr.atoms(sympy.Symbol): + if isinstance(atom, UWexpression) and _is_truly_constant(atom, UWexpression): + result_set.add(atom) + elif isinstance(atom, UWexpression): + # Non-constant UWexpression: recurse into its sym + if hasattr(atom, '_sym') and atom._sym is not None: + _collect_constant_atoms(atom._sym, result_set, is_constant_expr, UWexpression) + + +def _pack_constants(manifest): + """Pack current values from a constants manifest into a flat array. + + Parameters + ---------- + manifest : list of (int, UWexpression) + The constants manifest from _extract_constants(). + + Returns + ------- + list of float + Current nondimensional values in index order. + """ + import numpy as np + + if not manifest: + return np.array([], dtype=np.float64) + + values = np.zeros(len(manifest), dtype=np.float64) + for idx, uw_expr in manifest: + try: + values[idx] = float( + underworld3.function.expressions.unwrap_expression( + uw_expr, mode='nondimensional' + ) + ) + except (TypeError, ValueError): + # Fallback: try .data property + try: + values[idx] = float(uw_expr.data) + except Exception: + values[idx] = 0.0 + return values + + # Generates the C debugging string for the compiled function block def debugging_text(randstr, fn, fn_type, eqn_no): try: @@ -78,6 +269,9 @@ def debugging_text_bd(randstr, fn, fn_type, eqn_no): return debug_str +_GextResult = namedtuple("GextResult", ["ptrobj", "fn_dicts", "constants_manifest"]) + + @timing.routine_timer_decorator def getext( mesh, @@ -95,6 +289,13 @@ def getext( """ Check if we've already created an equivalent extension and use if available. + + Returns + ------- + GextResult + Named tuple with fields (ptrobj, fn_dicts, constants_manifest). + constants_manifest is a list of (index, uw_expression_ref) tuples + for use with PetscDSSetConstants(). """ import time @@ -108,14 +309,26 @@ def getext( + tuple(fns_bd_jacobian) ) - ## Expand all functions to ensure that changes in constants are recognised - ## in the caching process. + # Extract constant UWexpressions that will go through constants[] array + constants_manifest, constants_subs_map = _extract_constants(raw_fns, mesh) + # Build structurally-expanded functions for cache hashing. + # Constants are replaced with placeholder symbols (value-independent), + # so changing a constant value won't cause a cache miss. expanded_fns = [] - for fn in raw_fns: + # Phase 1: Substitute constants with _JITConstant placeholders + if constants_subs_map and fn is not None: + try: + fn_structural = fn.xreplace(constants_subs_map) if hasattr(fn, 'xreplace') else fn + except Exception: + fn_structural = fn + else: + fn_structural = fn + + # Phase 2: Unwrap remaining (non-constant) expressions expanded_fns.append( - underworld3.function.expressions.unwrap(fn, keep_constants=False, return_self=False) + underworld3.function.expressions.unwrap(fn_structural, keep_constants=False, return_self=False) ) fns = tuple(expanded_fns) @@ -124,13 +337,13 @@ def getext( print(f"Expanded functions for compilation:") for i, fn in enumerate(fns): print(f"{i}: {fn}") + if constants_manifest: + print(f"Constants manifest ({len(constants_manifest)} entries):") + for idx, expr in constants_manifest: + print(f" constants[{idx}] = {expr} (current value: {expr.data})") import os - # if verbose and uw.mpi.rank == 0: - # for i, fn in enumerate(fns): - # print(f"JIT: [{i:3d}] -> {fn}", flush=True) - if debug_name is not None: jitname = debug_name @@ -140,7 +353,7 @@ def getext( # unique modules. jitname += "_" + str(len(_ext_dict.keys())) - else: # Else name from fns hash + else: # Else name from fns hash — uses structural form (constants as placeholders) jitname = abs(hash((mesh, fns, tuple(mesh.vars.keys())))) # Create the module if not in dictionary @@ -154,6 +367,7 @@ def getext( fns_bd_residual, fns_bd_jacobian, primary_field_list, + constants_subs_map=constants_subs_map, verbose=verbose, debug=debug, debug_name=debug_name, @@ -162,13 +376,8 @@ def getext( if verbose and underworld3.mpi.rank == 0: print(f"JIT compiled module cached ... {jitname} ", flush=True) - ## TODO: Return a dictionary to recover the function pointers from the compiled - ## functions. Note, keep these by category as the same sympy function has - ## different compiled form depending on the function signature - module = _ext_dict[jitname] ptrobj = module.getptrobj() - # print(f"jit time {time.time()-time_s}", flush=True) i_res = {} for index, fn in enumerate(fns_residual): @@ -197,7 +406,7 @@ def getext( extensions_functions_dicts = extn_fn_dict(i_res, i_jac, i_ebc, i_bd_res, i_bd_jac) - return ptrobj, extensions_functions_dicts + return _GextResult(ptrobj, extensions_functions_dicts, constants_manifest) @timing.routine_timer_decorator @@ -210,6 +419,7 @@ def _createext( fns_bd_residual: List[sympy.Basic], fns_bd_jacobian: List[sympy.Basic], primary_field_list: List[underworld3.discretisation.MeshVariable], + constants_subs_map: Optional[dict] = None, verbose: Optional[bool] = False, debug: Optional[bool] = False, debug_name=None, @@ -432,6 +642,16 @@ def _basescalar_ccode(self, printer): # Save original for debugging fn_original = fn + # Two-phase unwrap: + # Phase 1: Substitute constant UWexpressions with _JITConstant symbols + # These survive into C code as constants[i] + if constants_subs_map and fn is not None: + try: + fn = fn.xreplace(constants_subs_map) if hasattr(fn, 'xreplace') else fn + except Exception: + pass + + # Phase 2: Unwrap remaining non-constant UWexpressions to numerical values fn = underworld3.function.expressions.unwrap(fn, keep_constants=False, return_self=False) if isinstance(fn, sympy.vector.Vector): diff --git a/tests/test_0004_pointwise_fns.py b/tests/test_0004_pointwise_fns.py index 6e88da1c..7c5f3b60 100644 --- a/tests/test_0004_pointwise_fns.py +++ b/tests/test_0004_pointwise_fns.py @@ -62,7 +62,7 @@ def test_getext_simple(): bd_jac_fn = sympy.ImmutableDenseMatrix([sympy.sympify(1), sympy.sympify(2)]) with uw.utilities.CaptureStdout(split=True) as captured_setup_solver: - compiled_extns, dictionaries = getext( + _getext_result = getext( mesh, [res_fn, res_fn], [jac_fn], @@ -107,7 +107,7 @@ def test_getext_sympy_fns(): ) with uw.utilities.CaptureStdout(split=True) as captured_setup_solver: - compiled_extns, dictionaries = getext( + _getext_result = getext( mesh, [res_fn, res_fn], [jac_fn], @@ -161,7 +161,7 @@ def test_getext_meshVar(): ) with uw.utilities.CaptureStdout(split=True) as captured_setup_solver: - compiled_extns, dictionaries = getext( + _getext_result = getext( mesh, [res_fn, res_fn], [jac_fn], diff --git a/tests/test_1001_poisson_constants.py b/tests/test_1001_poisson_constants.py new file mode 100644 index 00000000..c6beb0b3 --- /dev/null +++ b/tests/test_1001_poisson_constants.py @@ -0,0 +1,314 @@ +# Tests for the PetscDS constants array mechanism. +# +# These tests verify that UWexpression parameters routed through +# PETSc's constants[] array work correctly: +# 1. Solver produces correct results with constants +# 2. Changing a constant value and re-solving WITHOUT _force_setup +# gives correct results for the new value +# 3. No JIT recompilation occurs when only constant values change + +import numpy as np +import pytest +import sympy + +import underworld3 as uw +from underworld3.utilities._jitextension import _ext_dict + +# Solver-level tests +pytestmark = pytest.mark.level_1 + + +@pytest.fixture(autouse=True) +def reset_model_state(): + """Reset model state before each test.""" + uw.reset_default_model() + uw.use_strict_units(False) + uw.use_nondimensional_scaling(False) + yield + uw.reset_default_model() + uw.use_strict_units(False) + uw.use_nondimensional_scaling(False) + + +def test_poisson_constant_diffusivity_expression(): + """Poisson with UWexpression diffusivity solves correctly.""" + + mesh = uw.meshing.UnstructuredSimplexBox(cellSize=0.2) + x, y = mesh.X + + u = uw.discretisation.MeshVariable("u_cK", mesh, 1, degree=2) + + K = uw.expression("K_diff", 1.0) + + poisson = uw.systems.Poisson(mesh, u_Field=u) + poisson.constitutive_model = uw.constitutive_models.DiffusionModel + poisson.constitutive_model.Parameters.diffusivity = K + poisson.f = 0.0 + + poisson.add_dirichlet_bc(1.0, "Bottom") + poisson.add_dirichlet_bc(0.0, "Top") + poisson.solve() + + assert poisson.snes.getConvergedReason() > 0 + + # Check linear profile u(y) = 1 - y + sample_y = np.linspace(0.05, 0.95, 10) + sample_x = np.full_like(sample_y, 0.5) + sample_points = np.column_stack([sample_x, sample_y]) + + u_num = uw.function.evaluate(u.sym[0], sample_points, rbf=False).squeeze() + u_exact = 1 - sample_y + + error = np.sqrt(np.mean((u_num - u_exact) ** 2)) + assert error < 1e-3, f"K=1 linear profile error {error:.3e} too large" + + del poisson + + +def test_poisson_change_constant_no_recompile(): + """Change a constant UWexpression value and re-solve without recompilation. + + This is the key test for the constants[] mechanism: + - Solve with K=1, source f=-2 → u(y) = y^2 + - Change K to 2 (but same structural expression) + - Re-solve WITHOUT _force_setup + - Verify: correct result for K=2, and no new JIT module compiled + """ + + mesh = uw.meshing.UnstructuredSimplexBox(cellSize=0.15) + x, y = mesh.X + + u = uw.discretisation.MeshVariable("u_recomp", mesh, 1, degree=2) + + K = uw.expression("K_recomp", 1.0) + + poisson = uw.systems.Poisson(mesh, u_Field=u) + poisson.constitutive_model = uw.constitutive_models.DiffusionModel + poisson.constitutive_model.Parameters.diffusivity = K + # Source term: f = -2 with K=1 and BCs u(0)=0, u(1)=1 + # gives u(y) = y^2 (since -K * u'' = f → -1 * 2 = -2 ✓, u(0)=0, u(1)=1) + poisson.f = -2.0 + + poisson.add_dirichlet_bc(0.0, "Bottom") + poisson.add_dirichlet_bc(1.0, "Top") + + # --- First solve with K=1 --- + poisson.solve() + assert poisson.snes.getConvergedReason() > 0 + + sample_y = np.linspace(0.05, 0.95, 15) + sample_x = np.full_like(sample_y, 0.5) + sample_points = np.column_stack([sample_x, sample_y]) + + u_num_1 = uw.function.evaluate(u.sym[0], sample_points, rbf=False).squeeze() + u_exact_1 = sample_y ** 2 + error_1 = np.sqrt(np.mean((u_num_1 - u_exact_1) ** 2)) + assert error_1 < 5e-3, f"K=1 solve error {error_1:.3e} too large" + + # Record JIT cache size + n_modules_before = len(_ext_dict) + + # --- Change K to 2 and re-solve --- + # With K=2 and f=-2: -K*u'' = f → -2*u'' = -2 → u'' = 1 → u(y) = y²/2 + Ay + B + # BCs: u(0)=0 → B=0, u(1)=1 → 1/2 + A = 1 → A = 1/2 + # So u(y) = y²/2 + y/2 = y(y+1)/2 + K.sym = 2.0 + + # Re-solve — should use _update_constants(), NOT recompile + poisson.solve() + assert poisson.snes.getConvergedReason() > 0 + + n_modules_after = len(_ext_dict) + + u_num_2 = uw.function.evaluate(u.sym[0], sample_points, rbf=False).squeeze() + u_exact_2 = sample_y * (sample_y + 1) / 2 + error_2 = np.sqrt(np.mean((u_num_2 - u_exact_2) ** 2)) + assert error_2 < 5e-3, f"K=2 solve error {error_2:.3e} too large" + + # Verify no new JIT compilation occurred + assert n_modules_after == n_modules_before, ( + f"JIT recompilation detected: {n_modules_before} → {n_modules_after} modules. " + f"Constants mechanism should have avoided recompilation." + ) + + # Verify the two solutions are actually different + diff = np.max(np.abs(u_num_2 - u_num_1)) + assert diff > 0.01, ( + f"Solutions with K=1 and K=2 are suspiciously similar (max diff={diff:.3e}). " + f"Constants update may not be working." + ) + + del poisson + + +def test_poisson_constant_source_expression(): + """Poisson with UWexpression source term routed through constants[].""" + + mesh = uw.meshing.UnstructuredSimplexBox(cellSize=0.2) + x, y = mesh.X + + u = uw.discretisation.MeshVariable("u_cS", mesh, 1, degree=2) + + S = uw.expression("S_source", 1.0) + + poisson = uw.systems.Poisson(mesh, u_Field=u) + poisson.constitutive_model = uw.constitutive_models.DiffusionModel + poisson.constitutive_model.Parameters.diffusivity = 1 + # -u'' = S with u(0)=0, u(1)=0 → u(y) = S/2 * y * (1-y) + poisson.f = S + + poisson.add_dirichlet_bc(0.0, "Bottom") + poisson.add_dirichlet_bc(0.0, "Top") + poisson.solve() + + assert poisson.snes.getConvergedReason() > 0 + + sample_y = np.linspace(0.05, 0.95, 15) + sample_x = np.full_like(sample_y, 0.5) + sample_points = np.column_stack([sample_x, sample_y]) + + u_num = uw.function.evaluate(u.sym[0], sample_points, rbf=False).squeeze() + u_exact = 0.5 * sample_y * (1 - sample_y) + + error = np.sqrt(np.mean((u_num - u_exact) ** 2)) + assert error < 5e-3, f"Constant source error {error:.3e} too large" + + del poisson + + +def test_stokes_constant_viscosity_expression(): + """Stokes with UWexpression viscosity routed through constants[].""" + + mesh = uw.meshing.UnstructuredSimplexBox(cellSize=0.2) + + v = uw.discretisation.MeshVariable("v_cV", mesh, mesh.dim, degree=2) + p = uw.discretisation.MeshVariable("p_cV", mesh, 1, degree=1, continuous=True) + + eta = uw.expression("eta_const", 1.0) + + stokes = uw.systems.Stokes(mesh, velocityField=v, pressureField=p) + stokes.constitutive_model = uw.constitutive_models.ViscousFlowModel + stokes.constitutive_model.Parameters.shear_viscosity_0 = eta + stokes.bodyforce = sympy.Matrix([0, 0]) + + stokes.add_dirichlet_bc((0.0, 0.0), "Bottom") + stokes.add_dirichlet_bc((1.0, 0.0), "Top") + stokes.add_dirichlet_bc((sympy.oo, 0.0), "Left") + stokes.add_dirichlet_bc((sympy.oo, 0.0), "Right") + + stokes.solve() + assert stokes.snes.getConvergedReason() > 0 + + # Simple shear: v_x should be linear in y (from 0 at bottom to 1 at top) + sample_y = np.linspace(0.05, 0.95, 10) + sample_x = np.full_like(sample_y, 0.5) + sample_points = np.column_stack([sample_x, sample_y]) + + vx_num = uw.function.evaluate(v.sym[0], sample_points, rbf=False).squeeze() + vx_exact = sample_y + + error = np.sqrt(np.mean((vx_num - vx_exact) ** 2)) + assert error < 1e-2, f"Stokes simple shear error {error:.3e} too large" + + del stokes + + +def test_poisson_constant_in_essential_bc(): + """Essential BC with UWexpression amplitude — change without recompile.""" + + mesh = uw.meshing.UnstructuredSimplexBox(cellSize=0.2) + x, y = mesh.X + + u = uw.discretisation.MeshVariable("u_ebc", mesh, 1, degree=2) + + A = uw.expression("A_bc", 1.0) + + poisson = uw.systems.Poisson(mesh, u_Field=u) + poisson.constitutive_model = uw.constitutive_models.DiffusionModel + poisson.constitutive_model.Parameters.diffusivity = 1 + poisson.f = 0.0 + + # u = A*sin(pi*x) on top, u = 0 on bottom + poisson.add_dirichlet_bc(0.0, "Bottom") + poisson.add_dirichlet_bc(sympy.Matrix([A * sympy.sin(sympy.pi * x)]), "Top") + poisson.solve() + + assert poisson.snes.getConvergedReason() > 0 + + # Check BC value at top + pts_top = np.array([[0.5, 1.0]]) + u_top_1 = uw.function.evaluate(u.sym[0], pts_top, rbf=False).squeeze() + assert abs(u_top_1 - 1.0) < 1e-3, f"A=1 BC at top: {u_top_1:.4f} != 1.0" + + # Change A, re-solve without recompilation + n_before = len(_ext_dict) + A.sym = 3.0 + poisson.solve() + n_after = len(_ext_dict) + + assert poisson.snes.getConvergedReason() > 0 + assert n_after == n_before, ( + f"Recompiled when changing BC constant: {n_before} → {n_after}" + ) + + u_top_2 = uw.function.evaluate(u.sym[0], pts_top, rbf=False).squeeze() + assert abs(u_top_2 - 3.0) < 1e-3, f"A=3 BC at top: {u_top_2:.4f} != 3.0" + + del poisson + + +def test_stokes_natural_bc_constant_no_recompile(): + """Stokes natural BC with constant traction — change without recompile. + + Also verifies the Null_Boundary bug fix: _build() must not re-add + Null_Boundary on every call, which would reset is_setup and force + recompilation. + """ + + mesh = uw.meshing.UnstructuredSimplexBox(cellSize=0.25) + + v = uw.discretisation.MeshVariable("v_nbc", mesh, mesh.dim, degree=2) + p = uw.discretisation.MeshVariable("p_nbc", mesh, 1, degree=1, continuous=True) + + tau = uw.expression("tau_nbc", 1.0) + + stokes = uw.systems.Stokes(mesh, velocityField=v, pressureField=p) + stokes.constitutive_model = uw.constitutive_models.ViscousFlowModel + stokes.constitutive_model.Parameters.shear_viscosity_0 = 1.0 + stokes.bodyforce = sympy.Matrix([0, 0]) + + stokes.add_dirichlet_bc((0.0, 0.0), "Bottom") + stokes.add_dirichlet_bc((sympy.oo, 0.0), "Left") + stokes.add_dirichlet_bc((sympy.oo, 0.0), "Right") + stokes.add_natural_bc((tau, 0.0), "Top") + + stokes.solve() + assert stokes.snes.getConvergedReason() > 0 + + # tau should be in the constants manifest + const_names = [expr.name for _, expr in stokes.constants_manifest] + assert "tau_nbc" in const_names, ( + f"tau_nbc not found in constants manifest: {const_names}" + ) + + # Null_Boundary should appear exactly once + null_count = sum(1 for bc in stokes.natural_bcs if bc.boundary == "Null_Boundary") + assert null_count == 1, f"Null_Boundary appears {null_count} times (expected 1)" + + # Change traction and re-solve — no recompilation + n_before = len(_ext_dict) + tau.sym = 5.0 + stokes.solve() + n_after = len(_ext_dict) + + assert stokes.snes.getConvergedReason() > 0 + assert n_after == n_before, ( + f"Recompiled when changing natural BC constant: {n_before} → {n_after}. " + f"Null_Boundary count: {sum(1 for bc in stokes.natural_bcs if bc.boundary == 'Null_Boundary')}" + ) + + # Null_Boundary should still be exactly one + null_count_2 = sum(1 for bc in stokes.natural_bcs if bc.boundary == "Null_Boundary") + assert null_count_2 == 1, f"Null_Boundary duplicated to {null_count_2} after second solve" + + del stokes