diff --git a/docs/advanced/snapshot-restore.md b/docs/advanced/snapshot-restore.md index 2acaf276..997090e6 100644 --- a/docs/advanced/snapshot-restore.md +++ b/docs/advanced/snapshot-restore.md @@ -6,7 +6,7 @@ title: "State Snapshots & Restore" ## Overview -`Model.snapshot()` and `Model.restore()` are a *stash for timesteps* — +`Model.save_state()` and `Model.load_state()` are a *stash for timesteps* — a quick "hold that thought, I might need to come back" mechanism for time-stepping code. Take a snapshot, try a step, and if you don't like the result, restore and try again. The system is put back exactly as it @@ -23,11 +23,13 @@ Typical uses: - **Multi-stage time integration** (RK-style) — restore to the start of a step between stages. -This is intentionally *not* archival checkpointing. It is fast, -in-memory, and meant to be used freely within a run. For long-term, -on-disk restart files, use the existing `mesh.write_timestep()` / -`read_timestep()` path, which is unchanged and serves a different -purpose. +The same entry points serve two storage modes — in-memory for fast +intra-run stash, and on-disk for persistent restart / postprocessing. +You pick by giving (or not giving) a ``file=``. The existing +``mesh.write_timestep()`` / ``read_timestep()`` and +``mesh.write_checkpoint()`` / ``MeshVariable.read_checkpoint()`` +paths are unchanged and continue to serve their existing roles (see +"Choosing between paths" at the bottom). ## The API @@ -38,20 +40,37 @@ model = uw.get_default_model() # ... set up mesh, variables, swarm, solvers, step a few times ... -token = model.snapshot() # capture everything, return a token +# In-memory: the "stash for timesteps" use case. +token = model.save_state() # capture everything, return a token # ... take a speculative step you might regret ... -model.restore(token) # put everything back exactly +model.load_state(token) # put everything back exactly ``` -`snapshot()` returns a plain in-memory token. 
`save_state()` returns a `Snapshot` token when called without `file`;
you can hold several at once and restore any of them. With `file=...`
it writes an on-disk snapshot (a wrapper file plus a sibling `.bulk/`
directory — see "On-disk file layout" below) and returns the wrapper
path.
model.tracker.time = 0.0 model.tracker.step = 0 -token = model.snapshot() +token = model.save_state() model.tracker.time = 5.0 model.tracker.step = 100 -model.restore(token) +model.load_state(token) model.tracker.time # 0.0 — reverted automatically model.tracker.step # 0 — reverted automatically @@ -109,7 +128,7 @@ model.tracker.energy_history = np.zeros(3) These now travel with every snapshot and revert on every restore — no extra code, no special handling in your solvers. Using the tracker is optional; solvers do not depend on it. It is simply the place to keep -the things you want `restore()` to manage. +the things you want `load_state()` to manage. ```{note} Reserved name `state` is reserved on the tracker (it is the snapshot mechanism's own @@ -139,7 +158,7 @@ cfl_limit = mesh.get_min_radius() dt = 0.5 while model.tracker.time < t_end: - token = model.snapshot() + token = model.save_state() coords_before = swarm._particle_coordinates.data.copy() # Speculative step at the current Δt. @@ -153,7 +172,7 @@ while model.tracker.time < t_end: if moved > cfl_limit: # Too big — discard and retry with a smaller Δt. - model.restore(token) + model.load_state(token) dt *= 0.5 continue @@ -180,24 +199,56 @@ attempt starts from precisely where the failed one began. ``` ```{warning} Limitations -- **In-memory only.** Snapshots live in process memory and are not - written to disk; they do not survive the process exiting. They are - also a full copy of model state — holding many large snapshots at - once costs memory. -- **Same rank count.** A snapshot taken on *N* MPI ranks is restored - on *N* ranks. Changing the rank count is not supported by this - mechanism (use the `write_timestep` restart path for that). +- **In-memory tokens are intra-run only.** Tokens returned by + ``save_state()`` (no ``file``) live in process memory and do not + survive the process exiting. They are also a full copy of model + state — holding many large tokens at once costs memory. 
```text
my_run.snap.h5                         (~tens of KB; metadata, group structure)
my_run.snap.bulk/                      (per-mesh + per-swarm sidecars)
    {mesh}.mesh.00000.h5
    {mesh}.{var}.00000.h5              (one per mesh-variable)
    {swarm}.swarm.rankRRRRofSSSS.h5    (one per swarm per MPI rank; RRRR = writer's
                                        rank, SSSS = rank count, both zero-padded)
```
| Persist whole-model state across runs (crash recovery, bisection studies, full restart) | ``save_state(file=…)`` / ``load_state(path)`` |
File structure of the wrapper file (bulk data lives in the sibling
``.bulk/`` directory, not inside the wrapper):

    my_run.snap.h5/
    ├── /metadata      (attrs: uw3_version, schema_version,
    │                   created_at, step, sim_time, dt, dim,
    │                   mesh_type, coordinate_system,
    │                   mpi_ranks_at_write, variables_summary, ...)
    ├── /meshes        (one subgroup per mesh; variables nest under
    │                   /meshes/{name}/variables/{var})
    ├── /swarms        (one subgroup per swarm, pointing at its
    │                   per-rank sidecar files via ``sidecar_pattern``)
    └── /python_state  (one subgroup per registered state-bearer)
Swarms similarly carry their own +# variables when phase 3 lands. +_GROUP_METADATA = "metadata" +_GROUP_MESHES = "meshes" +_GROUP_SWARMS = "swarms" +_GROUP_PYTHON_STATE = "python_state" + +_TOP_LEVEL_GROUPS = ( + _GROUP_METADATA, + _GROUP_MESHES, + _GROUP_SWARMS, + _GROUP_PYTHON_STATE, +) + + +def _collect_metadata(model) -> dict: + """Build the metadata dict that gets written into ``/metadata`` attrs. + + Stable, h5-friendly types only: strings, ints, floats, lists of + strings (stored as JSON for compactness and to keep h5 attrs + scalar-typed where possible). No pickling, no repr of UW3 objects. + """ + now_iso = datetime.datetime.now(datetime.timezone.utc).isoformat( + timespec="seconds" + ) + + # Mesh-derived info (gracefully absent if no mesh registered). + meshes = list(model._meshes.values()) + first_mesh = meshes[0] if meshes else None + mesh_names = [m.name for m in meshes] + if first_mesh is not None: + dim = int(first_mesh.dim) + mesh_type = type(first_mesh).__name__ + coord_system = ( + first_mesh.CoordinateSystem.type + if hasattr(first_mesh, "CoordinateSystem") + else "unknown" + ) + else: + dim = -1 + mesh_type = "" + coord_system = "" + + # Swarm names (WeakValueDictionary, snapshot it). + swarm_names = [] + for s in list(model._swarms.values()): + name = getattr(s, "name", None) or f"swarm_{s.instance_number}" + swarm_names.append(name) + + # State-bearer class summary (just the class names — useful in + # h5ls without needing UW3 to interpret). + state_bearer_classes = sorted( + {type(o).__name__ for o in list(model._state_bearers)} + ) + + # Tracker conventions (the model-dwelling record). 
+ tracker = getattr(model, "tracker", None) + if tracker is not None: + sim_time = float(tracker.time) if tracker.time is not None else 0.0 + step = int(tracker.step) if tracker.step is not None else 0 + dt_val = tracker.dt + dt = float(dt_val) if dt_val is not None else float("nan") + else: + sim_time, step, dt = 0.0, 0, float("nan") + + # Per-variable summary across all registered meshes. + var_entries = [] + for m in meshes: + for var in m.vars.values(): + kind = "vector" if var.num_components > 1 else "scalar" + var_entries.append( + f"{m.name}.{var.clean_name} ({kind}, " + f"components={var.num_components}, degree={var.degree})" + ) + variables_summary = "; ".join(var_entries) if var_entries else "" + + md = { + # Versioning + "uw3_version": str(getattr(uw, "__version__", "0.0.0")), + "schema_version": int(DISK_SNAPSHOT_SCHEMA_VERSION), + "created_at": now_iso, + # Identity + "run_name": str(getattr(model, "name", "default")), + # Time / step (from the tracker — pre-seeded conventions) + "step": step, + "sim_time": sim_time, + "dt": dt, + # Geometry / topology + "dim": dim, + "mesh_type": mesh_type, + "coordinate_system": str(coord_system), + # MPI + "mpi_ranks_at_write": int(uw.mpi.size), + # Inventories — JSON for list-typed values so h5 attrs stay scalar. + "mesh_names_json": json.dumps(mesh_names), + "swarm_names_json": json.dumps(swarm_names), + "state_bearer_classes_json": json.dumps(state_bearer_classes), + "variables_summary": variables_summary, + } + return md + + +def _write_metadata_attrs(h5group, metadata: dict) -> None: + """Write a metadata dict as HDF5 attrs on a group. Plain types only.""" + for k, v in metadata.items(): + h5group.attrs[k] = v + + +def write_snapshot_skeleton(model, path: str) -> str: + """Phase 1: write the metadata + empty skeleton group structure. + + Returns the path written. 
Subsequent phases (2: mesh + meshvar + bulk; 3: swarms + python_state) populate the empty top-level + groups using PETSc primitives and dataclass serialisation + respectively. Writing is rank-0-only at this phase since no + collective PETSc operations are involved yet. + """ + import h5py + + with uw.selective_ranks(0) as should_execute: + if not should_execute: + uw.mpi.barrier() + return path + + metadata = _collect_metadata(model) + + with h5py.File(path, "w") as f: + md_group = f.create_group(_GROUP_METADATA) + _write_metadata_attrs(md_group, metadata) + + # Stub the other top-level groups so external readers can + # see the file's intended shape from day one — phases 2/3 + # populate them. + for name in ( + _GROUP_MESHES, + _GROUP_SWARMS, + _GROUP_PYTHON_STATE, + ): + grp = f.create_group(name) + grp.attrs["filled_by"] = "" # set to "phase2" / "phase3" later + + uw.mpi.barrier() + return path + + +def read_snapshot_metadata(path: str) -> dict: + """Read the ``/metadata`` group's attrs back as a plain dict. + + Validates the schema version. Lists stored as ``*_json`` are + decoded back into Python lists for caller convenience but the + on-disk form stays JSON for h5-tool friendliness. + """ + import h5py + + with h5py.File(path, "r") as f: + if _GROUP_METADATA not in f: + raise ValueError( + f"{path}: not a UW3 snapshot file (no /{_GROUP_METADATA} group)" + ) + md_group = f[_GROUP_METADATA] + md = {} + for k in md_group.attrs.keys(): + v = md_group.attrs[k] + # h5py returns bytes for some string attrs; normalise to str. 
+ if isinstance(v, bytes): + v = v.decode() + elif isinstance(v, np.ndarray) and v.dtype.kind in ("S", "U"): + v = [x.decode() if isinstance(x, bytes) else str(x) for x in v] + md[k] = v + + schema = int(md.get("schema_version", -1)) + if schema != DISK_SNAPSHOT_SCHEMA_VERSION: + raise ValueError( + f"{path}: snapshot schema version {schema} does not match " + f"current {DISK_SNAPSHOT_SCHEMA_VERSION}; on-disk schema " + f"migration will land with phase 6 (not yet implemented)" + ) + + # Decode JSON-encoded list fields for caller convenience. + for key in list(md.keys()): + if key.endswith("_json"): + try: + decoded = json.loads(md[key]) + md[key[:-5]] = decoded # e.g. "mesh_names" alongside "mesh_names_json" + except (TypeError, ValueError, json.JSONDecodeError): + pass + + return md + + +# ----- Phase 2: mesh + meshvar bulk via #146's PETSc primitives ----- +# +# Layout convention: +# +# /path/to/run.snap.h5 wrapper (metadata, h5py-readable) +# /path/to/run.snap.bulk/ companion directory (one per snapshot) +# {mesh_safe}.mesh.00000.h5 mesh DM dump (PETSc HDF5) +# {mesh_safe}.{var_clean}.00000.h5 per-variable section + vec (PETSc HDF5) +# ... one set per (mesh, var) ... +# +# The bulk-dir path is derived from the wrapper path by convention, so a +# user opening just the wrapper file can find the bulk. They are a unit +# for portability — move them together. + + +def _bulk_dir_for(wrapper_path: str) -> str: + """Convention: wrapper at `run.snap.h5` ⇒ bulk at `run.snap.bulk/`.""" + base = wrapper_path[:-3] if wrapper_path.endswith(".h5") else wrapper_path + return base + ".bulk" + + +def _sanitise(name: str) -> str: + """Sanitise a mesh / variable name for use as a filename component. + + Replaces anything that isn't alphanumeric or in ``._-`` with ``_``. + Falls back to ``unnamed`` if the result is empty. The original name + is preserved in HDF5 group attrs as the ``@name`` field. 
+ """ + safe = "".join(c if c.isalnum() or c in "._-" else "_" for c in name) + return safe or "unnamed" + + +def write_snapshot(model, path: str) -> str: + """Write a complete on-disk snapshot of the model's mesh + mesh-variable + state (phase 2 scope; swarms and python_state land in phase 3). + + Produces two artifacts: + + - ``path`` — the wrapper HDF5 file with rich metadata and the group + structure inspectable via ``h5ls``. + - ``_bulk_dir_for(path)`` — companion directory containing the + PETSc HDF5 files (mesh DM + per-variable section/vec) produced + by #146's :meth:`Mesh.write_checkpoint`. + + The two are a unit; move them together. Returns the wrapper path. + """ + import h5py + + # Phase-1 layer: metadata + skeleton groups. + write_snapshot_skeleton(model, path) + bulk_dir = _bulk_dir_for(path) + + # rank-0 creates the bulk directory; collective ops below need it + # to exist on the rank doing the PETSc-HDF5 write (which is rank 0 + # in this single-file write — actually PETSc's HDF5 viewer is + # collective, so all ranks participate). + with uw.selective_ranks(0) as rank0: + if rank0: + os.makedirs(bulk_dir, exist_ok=True) + uw.mpi.barrier() + + # For each registered mesh, drive #146's write_checkpoint into the + # bulk directory. write_checkpoint is collective (PETSc HDF5 + # viewer), so all ranks must participate. + mesh_records: list[dict] = [] + for mesh in list(model._meshes.values()): + mesh_safe = _sanitise(mesh.name) + mesh_vars = list(mesh.vars.values()) + # Filter to allocated variables — same skip rule as the in-memory + # path: lazy-allocated vars with _gvec == None have no data. 
+ mesh_vars = [v for v in mesh_vars if v._gvec is not None] + + mesh.write_checkpoint( + mesh_safe, + outputPath=bulk_dir, + meshVars=mesh_vars, + index=0, + ) + + mesh_records.append({ + "name": mesh.name, + "safe_name": mesh_safe, + "mesh_file": f"{mesh_safe}.mesh.00000.h5", + "vars": [ + { + "name": v.clean_name, + "components": int(v.num_components), + "degree": int(v.degree), + "continuous": bool(v.continuous), + # Per-variable file produced by Mesh.write_checkpoint + # at outputPath: "{base}.{var.clean_name}.{index:05}.h5". + "external_file": ( + f"{mesh_safe}.{v.clean_name}.00000.h5" + ), + } + for v in mesh_vars + ], + }) + + # Reopen the wrapper to populate /meshes with the per-mesh records + # and to mark the groups filled. + with uw.selective_ranks(0) as rank0: + if rank0: + with h5py.File(path, "a") as f: + meshes_group = f[_GROUP_MESHES] + meshes_group.attrs["filled_by"] = "phase2" + meshes_group.attrs["bulk_dir"] = os.path.basename(bulk_dir) + + for rec in mesh_records: + g = meshes_group.create_group(rec["safe_name"]) + g.attrs["name"] = rec["name"] + g.attrs["mesh_file"] = rec["mesh_file"] + + vars_g = g.create_group("variables") + for var_rec in rec["vars"]: + v = vars_g.create_group(_sanitise(var_rec["name"])) + v.attrs["name"] = var_rec["name"] + v.attrs["components"] = var_rec["components"] + v.attrs["degree"] = var_rec["degree"] + v.attrs["continuous"] = var_rec["continuous"] + v.attrs["external_file"] = var_rec["external_file"] + + # Phase 3a: state-bearer dataclass serialisation. + ps_group = f[_GROUP_PYTHON_STATE] + ps_group.attrs["filled_by"] = "phase3a" + for obj in list(model._state_bearers): + key = f"{type(obj).__name__}_{obj.instance_number}" + if key in ps_group: + continue # idempotent if write_snapshot called twice + bg = ps_group.create_group(key) + _write_state_bearer_to_group(bg, obj) + + # Phase 3b + 6: swarms — per-rank sidecar files in the bulk dir, + # referenced from /swarms/{swarm_safe}/ via a sidecar_pattern attr. 
def read_snapshot(model, path: str) -> None:
    """Load an on-disk snapshot into an already-configured model.

    The model must already have registered the meshes (matched by
    name), mesh variables (matched by ``clean_name``), state-bearers
    (matched by ``"{ClassName}_{instance_number}"``) and swarms
    (matched by sanitised name) that the snapshot contains, and must be
    running on the same number of MPI ranks as the writer.

    The wrapper's inventory is resolved against the model *before* any
    data is loaded, so a rank-count mismatch or an unregistered mesh /
    variable / state-bearer / swarm is reported before any model state
    has been modified.  Data is then loaded in this order: mesh
    variables (via ``var.read_checkpoint``), state-bearers, swarms.

    Parameters
    ----------
    model
        Model whose registered objects receive the stored state.
    path : str
        Path of the wrapper HDF5 file.  Bulk data is expected in the
        directory returned by ``_bulk_dir_for(path)``.

    Raises
    ------
    ValueError
        If the wrapper metadata is rejected by
        ``read_snapshot_metadata``, if the snapshot was written on a
        different MPI rank count, or if it names an object that is not
        registered on ``model``.
    FileNotFoundError
        If the companion bulk directory does not exist.
    """
    import h5py

    md = read_snapshot_metadata(path)

    # Check the rank count recorded in /metadata up front.  Previously
    # the only check lived in the swarm section, i.e. it ran after mesh
    # variables and state-bearers had already been overwritten, and it
    # never ran for a snapshot without swarm information.
    ranks_at_write = md.get("mpi_ranks_at_write")
    if ranks_at_write is not None and int(ranks_at_write) != int(uw.mpi.size):
        raise ValueError(
            f"snapshot at {path} was written on {int(ranks_at_write)} "
            f"MPI rank(s); this run is on {uw.mpi.size}. "
            f"Cross-rank-count restore is not supported by the "
            f"snapshot mechanism — use mesh.write_timestep for "
            f"that case."
        )

    bulk_dir = _bulk_dir_for(path)
    if not os.path.isdir(bulk_dir):
        raise FileNotFoundError(
            f"snapshot bulk directory missing: {bulk_dir} (expected next "
            f"to wrapper {path})"
        )

    # Build {original_name -> registered Mesh} for lookup.
    meshes_by_name = {m.name: m for m in model._meshes.values()}

    with h5py.File(path, "r") as f:
        # ---- Pass 1: resolve the whole inventory; touch no state. ----

        # (variable, bulk file path, dataset name) per mesh variable.
        var_loads = []
        meshes_group = f[_GROUP_MESHES]
        for mesh_safe in meshes_group.keys():
            g = meshes_group[mesh_safe]
            mesh_name = str(g.attrs.get("name", mesh_safe))
            mesh = meshes_by_name.get(mesh_name)
            if mesh is None:
                raise ValueError(
                    f"snapshot at {path} contains mesh {mesh_name!r} which "
                    f"is not registered on this model "
                    f"(registered: {sorted(meshes_by_name.keys())})"
                )

            current_vars = {v.clean_name: v for v in mesh.vars.values()}
            vars_g = g["variables"]
            for var_safe in vars_g.keys():
                v_attrs = vars_g[var_safe].attrs
                var_name = str(v_attrs["name"])
                external_file = str(v_attrs["external_file"])
                var = current_vars.get(var_name)
                if var is None:
                    raise ValueError(
                        f"snapshot variable {var_name!r} not registered on "
                        f"mesh {mesh_name!r}"
                    )
                var_loads.append(
                    (var, os.path.join(bulk_dir, external_file), var_name)
                )

        # (h5 group, state-bearer object) pairs.
        bearer_loads = []
        if _GROUP_PYTHON_STATE in f:
            ps_group = f[_GROUP_PYTHON_STATE]
            bearers_by_key = {
                f"{type(o).__name__}_{o.instance_number}": o
                for o in list(model._state_bearers)
            }
            for key in ps_group.keys():
                obj = bearers_by_key.get(key)
                if obj is None:
                    raise ValueError(
                        f"snapshot at {path} contains state-bearer {key!r} "
                        f"that is not registered on this model"
                    )
                bearer_loads.append((ps_group[key], obj))

        # (swarm, sidecar path) pairs for this rank.
        swarm_loads = []
        if _GROUP_SWARMS in f:
            sw_group = f[_GROUP_SWARMS]
            # Kept in addition to the /metadata check above: this is the
            # rank count recorded by the swarm writer itself.
            if "mpi_size_at_write" in sw_group.attrs:
                write_size = int(sw_group.attrs["mpi_size_at_write"])
                if write_size != int(uw.mpi.size):
                    raise ValueError(
                        f"snapshot at {path} was written on {write_size} "
                        f"MPI rank(s); this run is on {uw.mpi.size}. "
                        f"Cross-rank-count restore is not supported by the "
                        f"snapshot mechanism — use mesh.write_timestep for "
                        f"that case."
                    )

            swarms_by_safe = {
                _swarm_safe_name(s): s for s in list(model._swarms.values())
            }
            rank = int(uw.mpi.rank)
            size = int(uw.mpi.size)
            for swarm_safe in sw_group.keys():
                g = sw_group[swarm_safe]
                swarm = swarms_by_safe.get(swarm_safe)
                if swarm is None:
                    raise ValueError(
                        f"snapshot at {path} contains swarm {swarm_safe!r} "
                        f"that is not registered on this model"
                    )
                # Resolve this rank's sidecar name from the stored pattern.
                pattern = str(g.attrs["sidecar_pattern"])
                sidecar_name = pattern.format(rank=rank, size=size)
                swarm_loads.append(
                    (swarm, os.path.join(bulk_dir, sidecar_name))
                )

        # ---- Pass 2: apply, in the original order. ----
        for var, bulk_file, var_name in var_loads:
            var.read_checkpoint(bulk_file, data_name=var_name)

        for bearer_group, obj in bearer_loads:
            _read_state_bearer_into(bearer_group, obj)

        for swarm, sidecar_path in swarm_loads:
            _read_swarm_from_sidecar(swarm, sidecar_path)
Swarms aren't DMPlex +# section/vec; they're per-particle numpy arrays. +# +# Single-rank now; MPI gets phase 6 treatment (per-rank sidecars or +# parallel-HDF5). + + +def _swarm_safe_name(swarm) -> str: + """Stable name for a swarm in the snapshot layout. Mirrors the + in-memory snapshot's `_snapshot_stable_name`.""" + raw = getattr(swarm, "name", None) or f"swarm_{swarm.instance_number}" + return _sanitise(raw) + + +def _swarm_sidecar_filename(swarm_safe: str, rank: int, size: int) -> str: + """Per-rank sidecar (phase 6 — parallel-safe). + + Each rank writes its own swarm sidecar; restoring requires the + same rank count. The filename carries both the writer's rank and + the total rank count so each file is self-describing. + """ + return f"{swarm_safe}.swarm.rank{rank:04d}of{size:04d}.h5" + + +def _swarm_sidecar_pattern(swarm_safe: str) -> str: + """Pattern stored in the wrapper for readers to fill in their own + (rank, size) when locating their sidecar.""" + return f"{swarm_safe}.swarm.rank{{rank:04d}}of{{size:04d}}.h5" + + +def _write_swarm_to_sidecar(swarm, sidecar_path: str) -> dict: + """Write a swarm's local-particle state to a sidecar h5 file. + + Returns a record dict the caller uses to populate the wrapper's + /swarms/{name}/ group. + """ + import h5py + + coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape( + (-1, swarm.dim) + ) + coords = np.asarray(coord_field).copy() + swarm.dm.restoreField("DMSwarmPIC_coor") + + var_records: list[dict] = [] + with h5py.File(sidecar_path, "w") as f: + # File-level metadata — h5ls -v on the sidecar tells you what + # it is without needing UW3 (matches the wrapper's bar). 
def _read_swarm_from_sidecar(swarm, sidecar_path: str) -> None:
    """Restore a swarm's local particles from one per-rank sidecar file.

    Steps: read coordinates and variable arrays from the sidecar;
    validate provenance (rank count, writer rank, parent mesh name) and
    the variable set; remove all local particles; add the saved number
    of particles back at the saved coordinates; write the saved
    variable arrays into the rebuilt swarm.

    All checks that do not need the rebuilt swarm run *before* the
    local population is cleared, so a wrong sidecar or a missing
    variable does not destroy the swarm's current particles.  The
    per-variable shape check can only run after the rebuild.

    Raises
    ------
    ValueError
        On rank-count, rank, or mesh-name mismatch; when the sidecar
        holds a variable the target swarm lacks; or when a restored
        variable's array shape differs from the saved one.
    """
    # NOTE(review): the original docstring states this mirrors
    # Swarm.apply_snapshot_payload — that method is not visible here.
    import h5py

    with h5py.File(sidecar_path, "r") as f:
        saved_coords = np.asarray(f["coordinates"][...])
        captured_mesh_name = str(f.attrs.get("mesh_name", ""))
        captured_rank = int(f.attrs.get("mpi_rank", 0))
        captured_size = int(f.attrs.get("mpi_size_at_write", 1))
        var_data: dict[str, np.ndarray] = {}
        if "variables" in f:
            for name in f["variables"].keys():
                var_data[name] = np.asarray(f["variables"][name][...])

    if captured_size != uw.mpi.size:
        raise ValueError(
            f"sidecar at {sidecar_path}: was written on "
            f"{captured_size} MPI rank(s); this run is on {uw.mpi.size}. "
            f"Cross-rank-count snapshot restore is out of scope — restart "
            f"on the same rank count or use mesh.write_timestep for the "
            f"flexible-restart path."
        )
    if captured_rank != uw.mpi.rank:
        raise ValueError(
            f"sidecar at {sidecar_path}: written by rank {captured_rank}; "
            f"this is rank {uw.mpi.rank}. Wrong per-rank sidecar opened."
        )
    if captured_mesh_name and captured_mesh_name != swarm.mesh.name:
        raise ValueError(
            f"sidecar at {sidecar_path}: parent mesh was "
            f"{captured_mesh_name!r}, target swarm is on {swarm.mesh.name!r}"
        )

    # Variable-set check, hoisted ahead of the destructive clear below
    # (it previously ran only after the particles had been replaced).
    current_vars = {v.clean_name: v for v in swarm._vars.values()}
    for var_name in var_data:
        if var_name not in current_vars:
            raise ValueError(
                f"sidecar variable {var_name!r} is not present on the "
                f"target swarm; restore requires the same variable set"
            )

    # Clear the local population one point at a time.
    while swarm.dm.getLocalSize() > 0:
        swarm.dm.removePoint()

    n_saved = int(saved_coords.shape[0])
    if n_saved > 0:
        swarm.dm.finalizeFieldRegister()
        swarm.dm.addNPoints(npoints=n_saved)

        coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape(
            (-1, swarm.dim)
        )
        coord_field[...] = saved_coords
        swarm.dm.restoreField("DMSwarmPIC_coor")

        # Every re-added particle is assigned to the current rank.
        rank_field = swarm.dm.getField("DMSwarm_rank")
        rank_field[...] = uw.mpi.rank
        swarm.dm.restoreField("DMSwarm_rank")

    # Invalidate canonical-data caches so subsequent var.data reads
    # re-resolve to the rebuilt PETSc fields.
    if hasattr(swarm._particle_coordinates, "_canonical_data"):
        swarm._particle_coordinates._canonical_data = None
    for var in swarm._vars.values():
        if hasattr(var, "_canonical_data"):
            var._canonical_data = None

    # A restore counts as a population change.
    swarm._population_generation += 1

    # Write per-variable captured data back into the rebuilt swarm.
    for var_name, saved in var_data.items():
        var = current_vars[var_name]
        current = np.asarray(var.data)
        if current.shape != saved.shape:
            raise ValueError(
                f"swarm variable {var_name!r} shape mismatch on restore: "
                f"sidecar {saved.shape} vs current {current.shape}"
            )
        current[...] = saved
# Attribute value standing in for ``None`` (HDF5 attributes have no null).
_NULL_SENTINEL = "__none__"
# Attribute names recording the bearer's class and its state dataclass.
_TYPE_ATTR = "__bearer_class__"
_DATACLASS_ATTR = "__state_class__"


def _is_h5_attr_scalar(value: Any) -> bool:
    """Return True if ``value`` is a bool, int, float or str."""
    return isinstance(value, (bool, int, float, str))


def _json_default(obj: Any) -> Any:
    """``json.dumps`` fallback: unwrap NumPy scalars, reject everything else."""
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"unserialisable type {type(obj).__name__}")


def _serialise_field(h5group, name: str, value: Any) -> None:
    """Write a Python value into an HDF5 group as attr/dataset/subgroup.

    The shape of the storage records the type:

    - attr ``{name}`` holding the value, for int/float/bool/str scalars
    - attr ``{name}`` = ``'__none__'`` for None
    - attr ``{name}__json`` for JSON-encodable lists and tuples
    - dataset ``{name}`` for numpy arrays
    - subgroup ``{name}`` for dict values, recursing
    - attr ``{name}__skipped`` = a short reason, for anything else

    NumPy scalars (``np.int64``, ``np.bool_``, ...) are unwrapped to
    the matching builtin first, both at top level and inside lists, so
    they are stored rather than skipped.

    Notes
    -----
    Tuples are stored as JSON lists, so a tuple written here is a list
    in the file. A string equal to ``'__none__'`` is stored exactly as
    ``None`` is.

    Parameters
    ----------
    h5group
        Group-like target exposing ``attrs``, ``create_dataset``,
        ``create_group`` and ``in`` / ``del`` by name.
    name : str
        Field name used as the attribute, dataset or subgroup key.
    value
        The value to store.
    """
    if isinstance(value, np.generic):
        # e.g. np.int64 is not an ``int``; without this it would be skipped.
        value = value.item()

    if value is None:
        h5group.attrs[name] = _NULL_SENTINEL
        return
    if isinstance(value, (bool, int, float, str)):
        h5group.attrs[name] = value
        return
    if isinstance(value, np.ndarray):
        if name in h5group:
            del h5group[name]
        h5group.create_dataset(name, data=value)
        return
    if isinstance(value, dict):
        if name in h5group:
            del h5group[name]
        sub = h5group.create_group(name)
        for k, v in value.items():
            _serialise_field(sub, str(k), v)
        return
    if isinstance(value, (list, tuple)):
        try:
            encoded = json.dumps(list(value), default=_json_default)
        except (TypeError, ValueError):
            h5group.attrs[name + "__skipped"] = (
                f"unserialisable list (len={len(value)}, "
                f"first-type={type(value[0]).__name__ if value else 'empty'})"
            )
        else:
            h5group.attrs[name + "__json"] = encoded
        return
    h5group.attrs[name + "__skipped"] = (
        f"unserialisable type {type(value).__name__}"
    )
def _decode_attr(raw: Any) -> Any:
    """Turn a stored attribute back into a Python value.

    The None sentinel becomes ``None``, NumPy scalars become builtins,
    and anything else is returned unchanged.
    """
    if isinstance(raw, str) and raw == _NULL_SENTINEL:
        return None
    if isinstance(raw, np.generic):
        return raw.item()
    return raw


def _group_to_dict(h5group) -> dict:
    """Read a subgroup back as a plain dict.

    Counterpart of the dict branch of :func:`_serialise_field`:
    attributes become scalar / None / JSON-decoded entries, child
    groups recurse, and child datasets become numpy arrays. Attributes
    ending in ``__skipped`` are ignored.
    """
    import h5py

    json_suffix = "__json"
    result: dict = {}

    for key in h5group.attrs.keys():
        if key.endswith("__skipped"):
            continue
        raw = h5group.attrs[key]
        if key.endswith(json_suffix):
            result[key[: -len(json_suffix)]] = json.loads(raw)
        else:
            result[key] = _decode_attr(raw)

    for key in h5group.keys():
        child = h5group[key]
        if isinstance(child, h5py.Group):
            result[key] = _group_to_dict(child)
        else:
            result[key] = np.asarray(child[...])

    return result


def _deserialise_field(h5group, name: str, fallback: Any) -> Any:
    """Inverse of :func:`_serialise_field`.

    Lookup order: a child group or dataset called ``name``, then an
    attribute called ``name``, then the ``name + "__json"`` attribute.
    If none exists — the field was skipped at write time or was never
    written — ``fallback`` is returned, so a sensible live value is
    not clobbered by a placeholder.
    """
    import h5py

    if name in h5group:
        child = h5group[name]
        if isinstance(child, h5py.Group):
            return _group_to_dict(child)
        return np.asarray(child[...])  # h5py.Dataset

    attrs = h5group.attrs
    if name in attrs:
        return _decode_attr(attrs[name])

    json_key = name + "__json"
    if json_key in attrs:
        return json.loads(attrs[json_key])

    # Skipped at write time, or absent altogether: keep the current value.
    return fallback


def _write_state_bearer_to_group(group, obj) -> None:
    """Serialise a Snapshottable's ``.state`` into the given HDF5 group.

    Records the bearer's class name, the state dataclass name and the
    instance number as attributes, then stores every dataclass field
    through :func:`_serialise_field`.
    """
    state = obj.state
    attrs = group.attrs
    attrs[_TYPE_ATTR] = type(obj).__name__
    attrs[_DATACLASS_ATTR] = type(state).__name__
    attrs["instance_number"] = int(obj.instance_number)

    for field in dataclasses.fields(state):
        _serialise_field(group, field.name, getattr(state, field.name))
def _read_state_bearer_into(group, obj) -> None:
    """Restore a Snapshottable's ``.state`` from a group written by
    :func:`_write_state_bearer_to_group`.

    The live ``obj.state`` acts as the template: its dataclass fields
    decide what is read, and each field's current value is the
    fallback for anything skipped or absent in the file.

    Raises
    ------
    ValueError
        If the group records a state-dataclass name different from
        the live state's class name.
    """
    live = obj.state
    live_class = type(live).__name__
    stored_class = str(group.attrs.get(_DATACLASS_ATTR, ""))
    if stored_class and stored_class != live_class:
        raise ValueError(
            f"state-bearer class mismatch: snapshot expects "
            f"{stored_class}, current is {live_class}"
        )

    restored = {
        fld.name: _deserialise_field(group, fld.name, getattr(live, fld.name))
        for fld in dataclasses.fields(live)
    }
    obj.state = dataclasses.replace(live, **restored)


def is_snapshot_wrapper(path: str) -> bool:
    """Quick check whether ``path`` is a v1.1 snapshot wrapper file.

    Used by :meth:`MeshVariable.read_timestep` to dispatch between
    the legacy per-variable layout and the v1.1 sidecar format.
    Returns False when the file cannot be opened as HDF5 or lacks the
    metadata and meshes groups.
    """
    import h5py

    try:
        with h5py.File(path, "r") as f:
            if _GROUP_METADATA not in f:
                return False
            return _GROUP_MESHES in f
    except (OSError, KeyError):
        return False


def extract_var_via_bridge(wrapper_path: str, var_name: str):
    """Bridge for selective per-variable reads of v1.1 snapshots.

    Given the wrapper path and a variable name, returns
    ``(coords, values)`` numpy arrays, where ``values`` is reshaped to
    ``(coords.shape[0], components)``.

    Mechanism: look the variable up in the wrapper's mesh groups,
    build a throwaway source mesh from the mesh sidecar, create a
    matching source variable, load its DOFs with ``read_checkpoint``,
    then copy out ``var.coords`` and ``var.array``.

    Raises
    ------
    ValueError
        If no mesh in the wrapper lists a variable called ``var_name``.
    """
    import h5py

    # NOTE(review): this constructs a Mesh and calls read_checkpoint,
    # which may be collective operations, while read_timestep invokes
    # the bridge from rank 0 only — confirm this is safe in parallel.
    bulk_dir = _bulk_dir_for(wrapper_path)

    found = None
    with h5py.File(wrapper_path, "r") as f:
        meshes = f[_GROUP_MESHES]
        for mesh_key in meshes.keys():
            mesh_group = meshes[mesh_key]
            var_groups = mesh_group["variables"]
            for var_key in var_groups.keys():
                attrs = var_groups[var_key].attrs
                if str(attrs["name"]) != var_name:
                    continue
                found = (
                    str(mesh_group.attrs["mesh_file"]),
                    str(attrs["external_file"]),
                    int(attrs["degree"]),
                    int(attrs["components"]),
                    bool(attrs["continuous"]),
                )
                break
            if found is not None:
                break

    if found is None:
        raise ValueError(
            f"variable {var_name!r} not found in v1.1 snapshot {wrapper_path}"
        )

    mesh_file_rel, var_file_rel, degree, components, continuous = found

    # Transient source mesh + variable to read DOFs into. They are not
    # handed to the live model here and go out of scope on return.
    mesh_path = os.path.join(bulk_dir, mesh_file_rel)
    var_path = os.path.join(bulk_dir, var_file_rel)
    src_mesh = uw.discretisation.Mesh(mesh_path)
    src_var = uw.discretisation.MeshVariable(
        var_name, src_mesh, components, degree=degree, continuous=continuous,
    )
    src_var.read_checkpoint(var_path, data_name=var_name)

    coords = np.asarray(src_var.coords).copy()
    flat_values = np.asarray(src_var.array[...])
    values = flat_values.reshape(coords.shape[0], components).copy()
    return coords, values


def inspect_snapshot(path: str) -> str:
    """Human-readable one-shot summary of a snapshot file's metadata.

    Formats the mapping returned by ``read_snapshot_metadata(path)``
    into a newline-joined block of labelled lines. Missing scalar
    entries show as ``?``, missing name lists as ``[]``. Intended for
    ``print(uw.checkpoint.inspect_snapshot(path))`` at a notebook
    prompt.
    """
    md = read_snapshot_metadata(path)
    get = md.get

    summary = [f"UW3 snapshot: {path}"]
    summary.append(f" run_name : {get('run_name', '?')}")
    summary.append(f" created_at : {get('created_at', '?')}")
    summary.append(f" uw3_version : {get('uw3_version', '?')}")
    summary.append(f" schema_version : {get('schema_version', '?')}")
    summary.append(
        f" step / sim_time / dt : {get('step', '?')} / "
        f"{get('sim_time', '?')} / {get('dt', '?')}"
    )
    summary.append(
        f" dim / mesh_type : {get('dim', '?')} / "
        f"{get('mesh_type', '?')}"
    )
    summary.append(f" coordinate_system : {get('coordinate_system', '?')}")
    summary.append(f" mpi_ranks_at_write : {get('mpi_ranks_at_write', '?')}")
    summary.append(f" meshes : {get('mesh_names', [])}")
    summary.append(f" swarms : {get('swarm_names', [])}")
    summary.append(f" state_bearer_classes : {get('state_bearer_classes', [])}")
    summary.append(f" variables_summary : {get('variables_summary', '')}")
    return "\n".join(summary)
+ import h5py + import numpy as np + from underworld3.checkpoint.disk_snapshot import ( + is_snapshot_wrapper as _is_snapshot_wrapper, + extract_var_via_bridge as _extract_var_via_bridge, + ) + output_base_name = os.path.join(outputPath, data_filename) - data_file = output_base_name + f".mesh.{data_name}.{index:05}.h5" + legacy_file = output_base_name + f".mesh.{data_name}.{index:05}.h5" - if not os.path.isfile(os.path.abspath(data_file)): - raise RuntimeError(f"{os.path.abspath(data_file)} does not exist") + is_v1_1 = ( + os.path.isfile(data_filename) + and not data_filename.endswith( + f".mesh.{data_name}.{index:05}.h5" + ) + and _is_snapshot_wrapper(data_filename) + ) - import h5py - import numpy as np + if is_v1_1: + data_file = data_filename + else: + data_file = legacy_file + if not os.path.isfile(os.path.abspath(data_file)): + raise RuntimeError( + f"{os.path.abspath(data_file)} does not exist" + ) # ``self.num_components`` is correct for SCALAR (1), VECTOR (dim), # TENSOR (dim**2) and SYM_TENSOR (dim*(dim+1)/2). 
``self.shape[1]`` @@ -1236,10 +1260,21 @@ def read_timestep( if uw.mpi.rank == 0: if verbose: - print(f"Reading data file {data_file}", flush=True) - with h5py.File(data_file, "r") as h5f: - X_src = h5f["fields"]["coordinates"][()].reshape(-1, dim) - D_src = h5f["fields"][data_name][()].reshape(-1, n_components) + print( + f"Reading data file {data_file} " + f"(format: {'v1.1 snapshot' if is_v1_1 else 'legacy timestep'})", + flush=True, + ) + if is_v1_1: + X_src, D_src = _extract_var_via_bridge(data_file, data_name) + X_src = X_src.reshape(-1, dim) + D_src = D_src.reshape(-1, n_components) + else: + with h5py.File(data_file, "r") as h5f: + X_src = h5f["fields"]["coordinates"][()].reshape(-1, dim) + D_src = h5f["fields"][data_name][()].reshape( + -1, n_components + ) else: X_src = np.empty((0, dim), dtype=np.double) D_src = np.empty((0, n_components), dtype=np.double) diff --git a/src/underworld3/model.py b/src/underworld3/model.py index 547e690a..f2b485f3 100644 --- a/src/underworld3/model.py +++ b/src/underworld3/model.py @@ -610,35 +610,78 @@ def _register_state_bearer(self, obj) -> None: """ self._state_bearers.add(obj) - def snapshot(self, *, path: Optional[str] = None): - """Capture a unitary in-memory snapshot of the model's state. + def save_state(self, *, file: Optional[str] = None): + """Save the model's current state — memory or disk, one entry point. - v1 covers mesh coordinates and mesh-variable DOFs across every - registered mesh. Subsequent PRs extend coverage to swarms and - solver-internal Python state. + Without ``file``, captures an in-memory :class:`Snapshot` + token suitable for in-run backtracking ("stash for + timesteps"). The token is plain Python / numpy — fast to + produce, fast to restore, does not survive the process. - Pass ``path=...`` to write to an HDF5 file once the on-disk - backend lands (v1.1); v1 raises ``NotImplementedError``. 
+ With ``file=``, writes a persistent on-disk snapshot at + that path (plus a sibling ``.bulk/`` directory holding the + bulk PETSc + swarm sidecars). Survives the process; suitable + for restart, postprocessing, transferring runs. - See ``docs/developer/design/in_memory_checkpoint_design.md`` - for the full design. + Either way the captured state is the full model: all + registered meshes and mesh-variables, all swarms with + per-particle data, all solver-internal state-bearers + (:class:`ModelTracker`, ``DDt`` instances, anything else + exposing the ``Snapshottable`` contract). + + Parameters + ---------- + file : str, optional + Path to write a disk snapshot to. If omitted, an in-memory + token is returned instead. + + Returns + ------- + Snapshot + When called without ``file`` — pass to + :meth:`load_state` to restore. + str + When ``file`` is given — the path the snapshot was + written to (same as ``file``). """ from underworld3.checkpoint import snapshot as _snapshot + from underworld3.checkpoint import write_snapshot as _write_snapshot - return _snapshot(self, path=path) + if file is None: + return _snapshot(self) + return _write_snapshot(self, file) - def restore(self, snap) -> None: - """Restore the model from a :class:`underworld3.checkpoint.Snapshot`. + def load_state(self, source) -> None: + """Restore the model from a previously saved state — memory or disk. - Within-process restore: ``snap`` must have been produced by - :meth:`snapshot` on this same ``Model`` instance. Raises - :class:`underworld3.checkpoint.SnapshotInvalidatedError` if - the mesh has been adapted, or a captured mesh / variable is - no longer registered. + ``source`` is either: + + - a :class:`Snapshot` token returned by an earlier + ``save_state()`` call (in-memory restore — bit-exact), + - a path string to a disk-snapshot file (disk restore — + same-rank, same-model contract; mesh-rebuild on read is + v1.2 scope). 
+ + Raises + ------ + TypeError + ``source`` is neither a :class:`Snapshot` nor a string. + :class:`SnapshotInvalidatedError` + The captured state no longer matches what is registered + (mesh adapted, state-bearer missing, ...). """ + from underworld3.checkpoint import Snapshot from underworld3.checkpoint import restore as _restore - - return _restore(self, snap) + from underworld3.checkpoint import read_snapshot as _read_snapshot + + if isinstance(source, Snapshot): + return _restore(self, source) + if isinstance(source, (str, os.PathLike)): + return _read_snapshot(self, str(source)) + raise TypeError( + f"load_state expects a Snapshot token or a path string, " + f"got {type(source).__name__}" + ) def define_parameter(self, name: str, ptype=None, **kwargs): """ diff --git a/tests/parallel/mpi_runner.sh b/tests/parallel/mpi_runner.sh index 3b479dbe..1c74260f 100755 --- a/tests/parallel/mpi_runner.sh +++ b/tests/parallel/mpi_runner.sh @@ -31,3 +31,10 @@ echo "ptest 0007 snapshot in-memory -np 3 (uneven partition)" mpirun -np 3 $PYTHON ./ptest_0007_snapshot_inmemory.py echo "ptest 0007 snapshot in-memory -np 4" mpirun -np 4 $PYTHON ./ptest_0007_snapshot_inmemory.py + +echo "ptest 0010 snapshot on-disk -np 1" +mpirun -np 1 $PYTHON ./ptest_0010_snapshot_disk.py +echo "ptest 0010 snapshot on-disk -np 3 (uneven)" +mpirun -np 3 $PYTHON ./ptest_0010_snapshot_disk.py +echo "ptest 0010 snapshot on-disk -np 4" +mpirun -np 4 $PYTHON ./ptest_0010_snapshot_disk.py diff --git a/tests/parallel/ptest_0007_snapshot_inmemory.py b/tests/parallel/ptest_0007_snapshot_inmemory.py index d533d046..ffe3185e 100644 --- a/tests/parallel/ptest_0007_snapshot_inmemory.py +++ b/tests/parallel/ptest_0007_snapshot_inmemory.py @@ -112,12 +112,12 @@ def main(): pre_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) pre_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) - snap = model.snapshot() + snap = model.save_state() # --- Property 1 + 2: a migration-inducing step, then 
restore --- step(uw, V_fn, T, swarm, ddt, 0.3) # bigger dt -> more migration mid_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) - model.restore(snap) + model.load_state(snap) post = global_sorted_particles(swarm, gid, material) post_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) post_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) @@ -126,15 +126,15 @@ def main(): ddt_ok = pre_ddt == post_ddt # --- Property 3: bit-identical continuation across a stash --- - snap2 = model.snapshot() + snap2 = model.save_state() for _ in range(4): step(uw, V_fn, T, swarm, ddt, 0.1) ctrl = global_sorted_particles(swarm, gid, material) ctrl_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) - model.restore(snap2) + model.load_state(snap2) step(uw, V_fn, T, swarm, ddt, 0.5) # the regretted step - model.restore(snap2) + model.load_state(snap2) for _ in range(4): step(uw, V_fn, T, swarm, ddt, 0.1) stash = global_sorted_particles(swarm, gid, material) diff --git a/tests/parallel/ptest_0010_snapshot_disk.py b/tests/parallel/ptest_0010_snapshot_disk.py new file mode 100644 index 00000000..b9682313 --- /dev/null +++ b/tests/parallel/ptest_0010_snapshot_disk.py @@ -0,0 +1,178 @@ +"""Parallel (MPI) test of the on-disk snapshot path (v1.1). + +Phase 6 of the snapshot toolkit: per-rank swarm sidecars. The mesh ++ mesh-variable disk path is already parallel-correct via #146's +PETSc-collective HDF5 viewer; the swarm sidecar layer needs its +own per-rank file per swarm. This ptest exercises both together at +multi-rank. + +Run (4 ranks exercises cross-rank distribution of swarm particles): + + cd tests/parallel + mpirun -np 4 python ./ptest_0010_snapshot_disk.py + +Asserts (collective, checked on rank 0): + + 1. Disk write produces one wrapper file + one swarm sidecar per + rank in the bulk dir, each with the rank+size in its filename. + 2. 
import os
import shutil
import tempfile

import numpy as np
import sympy
from mpi4py import MPI

import underworld3 as uw

comm = MPI.COMM_WORLD
rank = uw.mpi.rank
size = uw.mpi.size


def build():
    """Build the fixture: mesh, one mesh variable, a populated swarm.

    ``gid`` gets a globally unique per-particle id (rank offset plus
    local index), ``material`` gets each particle's x coordinate, and
    the model tracker is set to time 1.5, step 7.

    Returns ``(uw, model, mesh, V_fn, T, swarm, gid, material)``.
    """
    uw.reset_default_model()
    uw.use_strict_units(False)
    uw.use_nondimensional_scaling(False)
    model = uw.get_default_model()
    mesh = uw.meshing.UnstructuredSimplexBox(
        minCoords=(0.0, 0.0), maxCoords=(4.0, 1.0), cellSize=1.0 / 6.0
    )
    x_sym, y_sym = mesh.X
    V_fn = sympy.Matrix([[-(y_sym - 0.5), 0.25 * (x_sym - 2.0)]]).T

    T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1)
    T.array[:, 0, 0] = mesh.X.coords[:, 0] - mesh.X.coords[:, 1]

    swarm = uw.swarm.Swarm(mesh)
    gid = swarm.add_variable("gid", 1, dtype=float)
    material = swarm.add_variable("material", 1, dtype=float)
    swarm.populate(fill_param=2)

    # Globally unique ids: offset by the particle counts of lower ranks.
    local_n = swarm.dm.getLocalSize()
    counts = comm.allgather(local_n)
    offset = int(np.sum(counts[:rank]))
    gid.data[:, 0] = offset + np.arange(local_n, dtype=float)
    material.data[:, 0] = swarm._particle_coordinates.data[:, 0]

    model.tracker.time = 1.5
    model.tracker.step = 7
    return uw, model, mesh, V_fn, T, swarm, gid, material


def global_sorted_state(T, swarm, gid, material):
    """Return a rank-independent view of the swarm and of ``T``.

    Returns
    -------
    swarm_state : ndarray
        One row ``(gid, x, y, material)`` per particle, gathered from
        all ranks and sorted by gid.
    t_extrema : tuple of float
        ``(t_max, t_min)`` of ``T`` reduced across all ranks.
    """
    g = gid.data[:, 0].copy()
    coords = swarm._particle_coordinates.data.copy()
    m = material.data[:, 0].copy()
    local = np.column_stack([g, coords[:, 0], coords[:, 1], m])
    gathered = comm.allgather(local)
    non_empty = [a for a in gathered if a.size]
    full = np.vstack(non_empty) if non_empty else np.empty((0, 4))
    order = np.argsort(full[:, 0], kind="stable")
    swarm_state = full[order]

    # T round-trip check: compare partition-invariant scalars (max,
    # min) rather than the full (coord, value) table — DOFs at
    # partition boundaries are visible to multiple ranks and would
    # appear duplicated/reordered in a gathered table. A cross-rank
    # float sum is not used because it is not bit-reproducible in
    # general (floating-point addition is non-associative).
    t_arr = np.asarray(T.array[...]).reshape(-1)
    t_max = comm.allreduce(float(t_arr.max()) if t_arr.size else -np.inf,
                           op=MPI.MAX)
    t_min = comm.allreduce(float(t_arr.min()) if t_arr.size else np.inf,
                           op=MPI.MIN)
    return swarm_state, (t_max, t_min)


def main():
    """Write a disk snapshot, scribble over the state, restore, compare."""
    _, model, _, _, T, swarm, gid, material = build()
    pre_swarm, pre_T = global_sorted_state(T, swarm, gid, material)
    pre_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM)

    # Rank 0 creates the directory and broadcasts its path.
    # NOTE(review): mkdtemp uses the local temp location, so this
    # presumably assumes all ranks share a filesystem — confirm for
    # multi-node runs.
    if rank == 0:
        tmp = tempfile.mkdtemp(prefix="uw3_ptest_0010_")
    else:
        tmp = None
    tmp = comm.bcast(tmp, root=0)

    try:
        wrapper = os.path.join(tmp, "parrun.snap.h5")
        model.save_state(file=wrapper)
        comm.Barrier()

        # Check files on rank 0
        files_ok = True
        if rank == 0:
            bulk = os.path.join(tmp, "parrun.snap.bulk")
            files = sorted(os.listdir(bulk))
            per_rank = [f for f in files if ".swarm.rank" in f]
            # Expect one swarm sidecar per rank
            if len(per_rank) != size:
                print(
                    f"!! expected {size} swarm sidecars, got {len(per_rank)}: "
                    f"{per_rank}",
                    flush=True,
                )
                files_ok = False
            else:
                print(
                    f" swarm sidecars OK: {per_rank[0]} ... ({size} total)",
                    flush=True,
                )

        # Scribble everything
        T.array[...] = -99.0
        coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape(
            (-1, swarm.dim)
        )
        coord_field[...] = -99.0
        swarm.dm.restoreField("DMSwarmPIC_coor")
        material.data[...] = -99.0
        model.tracker.time = -1.0
        model.tracker.step = -1

        model.load_state(wrapper)
        post_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM)
        post_swarm, post_T = global_sorted_state(T, swarm, gid, material)

        swarm_ok = np.array_equal(pre_swarm, post_swarm)
        # T is checked via the partition-invariant (max, min) pair
        # returned by global_sorted_state.
        T_ok = (pre_T == post_T)
        count_ok = pre_count == post_count
        tracker_ok = (model.tracker.time == 1.5 and model.tracker.step == 7)

        if rank == 0:
            print(f"[ranks={size}] particles total = {pre_count}", flush=True)
            print(f" P1 disk wrapper + per-rank sidecars present: {files_ok}",
                  flush=True)
            print(f" P2 particle count preserved: {count_ok}",
                  flush=True)
            print(f" P3 swarm (coords + gid + material) exact: {swarm_ok}",
                  flush=True)
            print(f" P4 T (mesh-variable DOFs) exact: {T_ok}",
                  flush=True)
            print(f" P5 tracker state restored: {tracker_ok}",
                  flush=True)

        assert files_ok
        assert count_ok
        assert swarm_ok
        assert T_ok
        assert tracker_ok
        if rank == 0:
            print(f"[ranks={size}] PASS", flush=True)
    finally:
        # Do not leave the snapshot tree behind. On the success path
        # every rank has finished load_state by now, because the
        # collectives after it could not complete otherwise.
        if rank == 0:
            shutil.rmtree(tmp, ignore_errors=True)


if __name__ == "__main__":
    main()
@@ -82,7 +82,7 @@ def take_step(dt: float): cfl_kept.append(disp / cfl_threshold) t_snap = t - snap = model.snapshot() + snap = model.save_state() # --- Phase 2: speculative big step --- disp_bad = take_step(candidate_dt) @@ -90,7 +90,7 @@ def take_step(dt: float): cfl_bad = disp_bad / cfl_threshold # --- CFL violated → restore --- - model.restore(snap) + model.load_state(snap) # --- Phase 3: substep replay --- times_recovered = [] @@ -171,7 +171,7 @@ def take_step(dt: float): ax.text( 0.5 * (t_snap + t_bad_end), 0.4 * cfl_bad, - "model.restore(snap)", + "model.load_state(snap)", ha="center", va="center", fontsize=10, color="0.35", style="italic", bbox=dict(facecolor="white", edgecolor="0.7", boxstyle="round,pad=0.25"), @@ -200,7 +200,7 @@ def take_step(dt: float): ax.set_xlabel("simulation time t") ax.set_ylabel("CFL ratio = max per-step displacement / cell radius") ax.set_title( - "Adaptive-Δt back-stepping • model.snapshot() / model.restore()", + "Adaptive-Δt back-stepping • model.save_state() / model.load_state()", pad=22, ) ax.legend(loc="upper right", frameon=False) diff --git a/tests/run_snapshot_backstepping_spatial.py b/tests/run_snapshot_backstepping_spatial.py index fb155747..fe06061f 100644 --- a/tests/run_snapshot_backstepping_spatial.py +++ b/tests/run_snapshot_backstepping_spatial.py @@ -5,14 +5,14 @@ spatial panels at the four moments that matter: [initial state (snapshot taken here)] [after speculative bad step] - [after model.restore(snap)] [after substep recovery to same t] + [after model.load_state(snap)] [after substep recovery to same t] Each panel shows the swarm particles coloured by their carried material value (initial radial position), with the domain boundary drawn as context. The diagonal pairs tell two stories: - top-left vs. bottom-left should be **visually identical**. That's - the proof that model.restore(snap) put the captured state back + the proof that model.load_state(snap) put the captured state back exactly. 
If the figure ever stops showing two identical panels in that diagonal, the snapshot mechanism has broken. @@ -69,7 +69,7 @@ def main(out_path: str = "snapshot_backstepping_spatial.png"): # Take the snapshot — this is the state that bottom-left will # have to match after restore. - snap = model.snapshot() + snap = model.save_state() # --- Speculative big step --- swarm.advection(V_fn, delta_t=candidate_dt, step_limit=False) @@ -79,8 +79,8 @@ def main(out_path: str = "snapshot_backstepping_spatial.png"): ) cfl_bad = max_disp_bad / cfl_threshold - # --- model.restore(snap) --- - model.restore(snap) + # --- model.load_state(snap) --- + model.load_state(snap) state_after_restore = _capture(swarm, material) # --- Substep recovery to the same target time --- @@ -121,7 +121,7 @@ def main(out_path: str = "snapshot_backstepping_spatial.png"): ( axes[1, 0], state_after_restore, - "After model.restore(snap)", + "After model.load_state(snap)", "t = 0.00 • visually identical to top-left", ), ( diff --git a/tests/test_0007_snapshot_inmemory.py b/tests/test_0007_snapshot_inmemory.py index 1af2af6f..d52fee33 100644 --- a/tests/test_0007_snapshot_inmemory.py +++ b/tests/test_0007_snapshot_inmemory.py @@ -24,12 +24,12 @@ def test_meshvariable_in_memory_roundtrip(): T.array[:, 0, 0] = T.coords[:, 0] + 2.0 * T.coords[:, 1] pre_array = np.asarray(T.array[...]).copy() - snap = model.snapshot() + snap = model.save_state() T.array[...] = -42.0 assert not np.allclose(np.asarray(T.array[...]), pre_array), "scribble didn't take" - model.restore(snap) + model.load_state(snap) assert np.allclose(np.asarray(T.array[...]), pre_array, atol=0.0, rtol=0.0), ( "MeshVariable.array is not bit-equivalent after restore" @@ -49,12 +49,12 @@ def test_multiple_meshvariables_roundtrip(): T_pre = np.asarray(T.array[...]).copy() V_pre = np.asarray(V.array[...]).copy() - snap = model.snapshot() + snap = model.save_state() T.array[...] = 0.0 V.array[...] 
= 0.0 - model.restore(snap) + model.load_state(snap) assert np.allclose(np.asarray(T.array[...]), T_pre) assert np.allclose(np.asarray(V.array[...]), V_pre) @@ -66,7 +66,7 @@ def test_snapshot_is_independent_of_subsequent_writes(): T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) T.array[:, 0, 0] = 5.0 - snap = model.snapshot() + snap = model.save_state() T.array[...] = -1.0 # The backend still holds the captured value, not the post-write value. @@ -87,14 +87,14 @@ def test_mesh_version_invalidates_restore(): T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) T.array[:, 0, 0] = 1.0 - snap = model.snapshot() + snap = model.save_state() # Simulate a mesh-mutation event (e.g. adapt(), or any deformation # routed through the high-level callback that bumps _mesh_version). mesh._mesh_version += 1 with pytest.raises(SnapshotInvalidatedError, match="_mesh_version"): - model.restore(snap) + model.load_state(snap) def test_restore_rejects_non_snapshot(): @@ -103,16 +103,7 @@ def test_restore_rejects_non_snapshot(): _ = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) with pytest.raises(TypeError): - model.restore({"not": "a snapshot"}) - - -def test_snapshot_path_is_v1_1_scope(): - """Passing path= raises NotImplementedError until the on-disk backend lands.""" - uw, model, mesh = _fresh_model_and_mesh() - _ = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) - - with pytest.raises(NotImplementedError): - model.snapshot(path="/tmp/should_not_be_written.h5") + model.load_state({"not": "a snapshot"}) # ----- Swarm coverage (rebuild-on-restore semantics) ----- @@ -150,14 +141,14 @@ def test_swarm_no_change_roundtrip(): coords_pre = _swarm_coords(swarm) material_pre = np.asarray(material.data).copy() - snap = model.snapshot() + snap = model.save_state() coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape((-1, swarm.dim)) coord_field[...] = -99.0 swarm.dm.restoreField("DMSwarmPIC_coor") material.data[...] 
= -99.0 - model.restore(snap) + model.load_state(snap) assert np.allclose(_swarm_coords(swarm), coords_pre) assert np.allclose(np.asarray(material.data), material_pre) @@ -172,14 +163,14 @@ def test_swarm_restore_after_migrate(): material_pre = np.asarray(material.data).copy() pop_gen_pre = swarm._population_generation - snap = model.snapshot() + snap = model.save_state() # Mutate: migrate() will bump the counter regardless of whether # particles actually moved. Restore must succeed anyway. swarm.migrate(remove_sent_points=True) assert swarm._population_generation > pop_gen_pre, "migrate didn't bump counter" - model.restore(snap) + model.load_state(snap) assert np.allclose(_swarm_coords(swarm), coords_pre), ( "restore did not recover particle coords across a migrate event" @@ -197,14 +188,14 @@ def test_swarm_restore_after_add_particles(): material_pre = np.asarray(material.data).copy() npre = swarm.dm.getLocalSize() - snap = model.snapshot() + snap = model.save_state() swarm.add_particles_with_coordinates( np.array([[0.5, 0.5], [0.25, 0.75]]) ) assert swarm.dm.getLocalSize() != npre, "add_particles didn't grow swarm" - model.restore(snap) + model.load_state(snap) assert swarm.dm.getLocalSize() == npre, ( "restore did not roll back to the captured particle count" @@ -218,7 +209,7 @@ def test_swarm_population_generation_is_informational_not_a_gate(): uw, model, mesh, swarm, _ = _fresh_model_mesh_and_swarm() gen_at_capture = swarm._population_generation - snap = model.snapshot() + snap = model.save_state() swarm.migrate(remove_sent_points=True) swarm.add_particles_with_coordinates(np.array([[0.5, 0.5]])) @@ -226,7 +217,7 @@ def test_swarm_population_generation_is_informational_not_a_gate(): assert gen_during > gen_at_capture # Restore is expected to *succeed*, not raise. - model.restore(snap) + model.load_state(snap) # And the counter has moved on from where it was at capture, # because restore itself counts as a population change. 
@@ -237,7 +228,7 @@ def test_swarm_internal_variables_are_not_captured(): """Internal DMSwarm_* variables stay out of the snapshot key list.""" uw, model, mesh, swarm, material = _fresh_model_mesh_and_swarm() - snap = model.snapshot() + snap = model.save_state() keys = snap.backend.list_vectors() swarm_name = swarm._snapshot_stable_name() svar_keys = [k for k in keys if k.startswith(f"swarm:{swarm_name}:var:")] @@ -299,14 +290,14 @@ def test_symbolic_ddt_roundtrip_recovers_state(): assert state_pre.n_solves_completed == 2 assert state_pre.dt_history == [0.2, 0.1] - snap = model.snapshot() + snap = model.save_state() # Mutate: take another solve, dt_history changes. ddt.update_pre_solve(dt=0.5) ddt.update_post_solve(dt=0.5) assert ddt.state.dt_history == [0.5, 0.2] - model.restore(snap) + model.load_state(snap) # Primary state is back to captured. state_post = ddt.state @@ -344,7 +335,7 @@ def test_symbolic_ddt_snapshot_is_deep_copy(): ddt.update_pre_solve(dt=0.1) ddt.update_post_solve(dt=0.1) - snap = model.snapshot() + snap = model.save_state() # Find the DDt's captured state by type — state_bearers is # unordered and now also contains the model tracker. 
from underworld3.systems.ddt import DDtSymbolicState @@ -394,14 +385,14 @@ def test_eulerian_ddt_roundtrip(): ddt._dt = 0.2 state_pre = ddt.state - snap = model.snapshot() + snap = model.save_state() ddt._dt_history = [0.99, 0.99] ddt._history_initialised = False ddt._n_solves_completed = 0 ddt._dt = None - model.restore(snap) + model.load_state(snap) assert ddt.state.dt_history == state_pre.dt_history assert ddt.state.history_initialised == state_pre.history_initialised @@ -437,11 +428,11 @@ def test_semilagrangian_ddt_roundtrip(): ddt._dt = 0.3 state_pre = ddt.state - snap = model.snapshot() + snap = model.save_state() ddt._dt_history = [None, None] ddt._history_initialised = False ddt._n_solves_completed = 0 - model.restore(snap) + model.load_state(snap) assert ddt.state.dt_history == state_pre.dt_history assert ddt.state.history_initialised is True @@ -474,11 +465,11 @@ def test_lagrangian_ddt_roundtrip(): ddt._dt = 0.2 state_pre = ddt.state - snap = model.snapshot() + snap = model.save_state() ddt._dt_history = [None, None] ddt._history_initialised = False ddt._n_solves_completed = 0 - model.restore(snap) + model.load_state(snap) assert ddt.state.dt_history == state_pre.dt_history assert ddt.state.history_initialised is True @@ -549,7 +540,7 @@ def test_backstepping_cfl_recovery_end_to_end(): # Take the snapshot *before* the speculative step. Everything that # will be touched gets captured. - snap = model.snapshot() + snap = model.save_state() # Speculative step at the candidate Δt. Bigger than the user # thinks is safe — they'll check after and back-step if it isn't. @@ -569,7 +560,7 @@ def test_backstepping_cfl_recovery_end_to_end(): # Back-step. Everything captured is brought back to the snapshot # point — swarm positions, the material variable carried with the # swarm, and the DDt's BDF history. 
- model.restore(snap) + model.load_state(snap) assert np.allclose(swarm._particle_coordinates.data, coords_initial), ( "particle positions did not roll back after restore" @@ -724,7 +715,7 @@ def test_continuation_deterministic_after_restore(): for _ in range(3): _step(uw, V_fn, T, swarm, ddt, 0.05) - snap = model.snapshot() + snap = model.save_state() # Branch A: K steps straight from S. for _ in range(5): @@ -732,7 +723,7 @@ def test_continuation_deterministic_after_restore(): state_A = _capture_full_state(T, swarm, material, ddt) # Branch B: restore S, then the identical K steps. - model.restore(snap) + model.load_state(snap) for _ in range(5): _step(uw, V_fn, T, swarm, ddt, 0.05) state_B = _capture_full_state(T, swarm, material, ddt) @@ -756,7 +747,7 @@ def test_continuation_bit_identical_across_stash_and_recover(): for _ in range(3): _step(uw, V_fn, T, swarm, ddt, 0.05) - snap = model.snapshot() + snap = model.save_state() # Control: K good steps from S. for _ in range(5): @@ -766,9 +757,9 @@ def test_continuation_bit_identical_across_stash_and_recover(): # Stash scenario: back to S, take a deliberately disruptive step # (10x Δt — large advection, big T jump, DDt history shift), then # discard it via restore and run the intended K good steps. 
- model.restore(snap) + model.load_state(snap) _step(uw, V_fn, T, swarm, ddt, 0.5) # the regretted step - model.restore(snap) + model.load_state(snap) for _ in range(5): _step(uw, V_fn, T, swarm, ddt, 0.05) stash = _capture_full_state(T, swarm, material, ddt) diff --git a/tests/test_0008_snapshot_realsolver.py b/tests/test_0008_snapshot_realsolver.py index 33afe4a3..d37f15de 100644 --- a/tests/test_0008_snapshot_realsolver.py +++ b/tests/test_0008_snapshot_realsolver.py @@ -105,14 +105,14 @@ def test_realsolver_restore_recovers_solution_field(): adv_diff.solve(timestep=1.0e-3) pre_T = _capture(T) - snap = model.snapshot() + snap = model.save_state() adv_diff.solve(timestep=5.0) # absurd Δt: converges, over-diffused assert not np.allclose(_capture(T), pre_T, atol=1e-8), ( "the regretted solve was not actually disruptive" ) - model.restore(snap) + model.load_state(snap) assert np.array_equal(_capture(T), pre_T), ( "solution field not exactly recovered after restore" ) @@ -132,18 +132,18 @@ def test_realsolver_regretted_step_leaves_no_trace(): for _ in range(3): adv_diff.solve(timestep=1.0e-3) - snap = model.snapshot() + snap = model.save_state() # B: restore, then K good solves. - model.restore(snap) + model.load_state(snap) for _ in range(4): adv_diff.solve(timestep=1.0e-3) B = _capture(T) # C: restore, a regretted solve, restore, the same K good solves. - model.restore(snap) + model.load_state(snap) adv_diff.solve(timestep=5.0) - model.restore(snap) + model.load_state(snap) for _ in range(4): adv_diff.solve(timestep=1.0e-3) C = _capture(T) @@ -165,7 +165,7 @@ def test_realsolver_continuation_within_solver_tolerance(): for _ in range(3): adv_diff.solve(timestep=1.0e-3) - snap = model.snapshot() + snap = model.save_state() # Control: never snapshotted/restored — straight K solves. for _ in range(4): @@ -173,9 +173,9 @@ def test_realsolver_continuation_within_solver_tolerance(): ctrl = _capture(T) # Stash path: restore, regretted solve, restore, same K solves. 
- model.restore(snap) + model.load_state(snap) adv_diff.solve(timestep=5.0) - model.restore(snap) + model.load_state(snap) for _ in range(4): adv_diff.solve(timestep=1.0e-3) stash = _capture(T) diff --git a/tests/test_0009_model_tracker.py b/tests/test_0009_model_tracker.py index 35403654..3f3f12ac 100644 --- a/tests/test_0009_model_tracker.py +++ b/tests/test_0009_model_tracker.py @@ -34,13 +34,13 @@ def test_tracker_builtins_revert_on_restore(): model.tracker.step = 7 model.tracker.dt = 0.05 - snap = model.snapshot() + snap = model.save_state() model.tracker.time = 99.0 model.tracker.step = 999 model.tracker.dt = 1.0 - model.restore(snap) + model.load_state(snap) assert model.tracker.time == 3.14 assert model.tracker.step == 7 @@ -53,9 +53,9 @@ def test_tracker_user_quantity_reverts(): uw, model = _fresh_model() model.tracker.my_diagnostic = 42.0 - snap = model.snapshot() + snap = model.save_state() model.tracker.my_diagnostic = -1.0 - model.restore(snap) + model.load_state(snap) assert model.tracker.my_diagnostic == 42.0 @@ -67,11 +67,11 @@ def test_tracker_numpy_quantity_reverts_by_value(): arr = np.array([1.0, 2.0, 3.0]) model.tracker.history = arr - snap = model.snapshot() + snap = model.save_state() model.tracker.history[:] = -9.0 # in-place mutation assert np.allclose(model.tracker.history, -9.0) - model.restore(snap) + model.load_state(snap) assert np.allclose(model.tracker.history, [1.0, 2.0, 3.0]) @@ -81,11 +81,11 @@ def test_tracker_quantity_added_after_snapshot_is_dropped_on_restore(): uw, model = _fresh_model() model.tracker.a = 1.0 - snap = model.snapshot() + snap = model.save_state() model.tracker.b = 2.0 # created after snapshot assert "b" in model.tracker - model.restore(snap) + model.load_state(snap) assert "a" in model.tracker assert "b" not in model.tracker @@ -99,13 +99,13 @@ def test_tracker_is_what_makes_state_revertible(): loose_time = 0.0 model.tracker.time = 0.0 - snap = model.snapshot() + snap = model.save_state() # Advance both the 
loose variable and the tracked one. loose_time = 5.0 model.tracker.time = 5.0 - model.restore(snap) + model.load_state(snap) # The loose variable is untouched by restore (the language can't # know about it); the tracked one rolled back. @@ -122,10 +122,10 @@ def test_tracker_state_roundtrip_is_bit_identical(): model.tracker.payload = np.arange(5).astype(float) state_pre = model.tracker.state - snap = model.snapshot() + snap = model.save_state() model.tracker.time = 12345.0 model.tracker.payload[:] = 0.0 - model.restore(snap) + model.load_state(snap) state_post = model.tracker.state assert state_post.managed["time"] == state_pre.managed["time"] @@ -158,7 +158,7 @@ def do_step(dt): for _ in range(3): do_step(0.05) - snap = model.snapshot() + snap = model.save_state() t_snap, s_snap = model.tracker.time, model.tracker.step # Regretted big step. @@ -166,7 +166,7 @@ def do_step(dt): assert model.tracker.step == s_snap + 1 assert model.tracker.time != t_snap - model.restore(snap) + model.load_state(snap) assert model.tracker.time == t_snap assert model.tracker.step == s_snap diff --git a/tests/test_0010_snapshot_disk_format.py b/tests/test_0010_snapshot_disk_format.py new file mode 100644 index 00000000..56609dfe --- /dev/null +++ b/tests/test_0010_snapshot_disk_format.py @@ -0,0 +1,669 @@ +"""Phase 1 of the on-disk snapshot format (v1.1). + +These tests assert the *inspectability bar* — an external h5 reader +(here, h5py — but the assertions translate directly to ``h5ls`` +output) must see meaningful information about a UW3 snapshot file +without UW3 needing to be in the loop. They do not yet exercise any +PETSc bulk-data writes; that lands in phase 2. 
+""" + +import json + +import pytest +import numpy as np + +pytestmark = [pytest.mark.level_1, pytest.mark.tier_a] + + +def _fresh_model_with_state(tmp_path): + import underworld3 as uw + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + _ = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + _ = uw.discretisation.MeshVariable("V", mesh, 2, degree=2) + swarm = uw.swarm.Swarm(mesh) + swarm.populate(fill_param=2) + + model.tracker.time = 3.14 + model.tracker.step = 42 + model.tracker.dt = 0.05 + return uw, model, mesh, swarm + + +def test_skeleton_writes_expected_group_structure(tmp_path): + """The file an h5 tool would open has exactly the documented + top-level group structure — no surprises for external readers.""" + import h5py + + uw, model, mesh, swarm = _fresh_model_with_state(tmp_path) + path = str(tmp_path / "phase1.snap.h5") + uw.checkpoint.write_snapshot_skeleton(model, path) + + with h5py.File(path, "r") as f: + top = set(f.keys()) + # Variables nest under each mesh (/meshes/{name}/variables/...) so + # they are not a top-level group. + assert top == {"metadata", "meshes", "swarms", "python_state"} + + +def test_metadata_is_inspectable_without_uw3(tmp_path): + """The /metadata attrs h5py reads back are useful — i.e. an + external reader sees the run identity, schema, step/time, geometry, + MPI rank count, and the inventory of meshes/swarms/variables. + + Concretely: an h5py user (the proxy for h5ls/h5dump here) can + answer 'what's in this file?' 
from /metadata alone.""" + import h5py + + uw, model, mesh, swarm = _fresh_model_with_state(tmp_path) + path = str(tmp_path / "phase1.snap.h5") + uw.checkpoint.write_snapshot_skeleton(model, path) + + with h5py.File(path, "r") as f: + md = f["metadata"].attrs + # Identity / versioning + assert int(md["schema_version"]) == 1 + assert isinstance(str(md["created_at"]), str) + assert str(md["run_name"]) != "" + # Tracker conventions surfaced as scalars + assert float(md["sim_time"]) == 3.14 + assert int(md["step"]) == 42 + assert float(md["dt"]) == 0.05 + # Geometry + assert int(md["dim"]) == 2 + assert str(md["mesh_type"]) != "" + # MPI + assert int(md["mpi_ranks_at_write"]) >= 1 + # Inventories (JSON-encoded list-typed values) + var_summary = str(md["variables_summary"]) + assert "T" in var_summary and "V" in var_summary + swarm_names = json.loads(str(md["swarm_names_json"])) + assert len(swarm_names) == 1 + state_classes = json.loads(str(md["state_bearer_classes_json"])) + assert "ModelTracker" in state_classes + + +def test_read_snapshot_metadata_roundtrip(tmp_path): + """write -> read returns the same content, with JSON-encoded list + fields conveniently decoded for the caller.""" + import underworld3 as uw + + uw, model, mesh, swarm = _fresh_model_with_state(tmp_path) + path = str(tmp_path / "phase1.snap.h5") + uw.checkpoint.write_snapshot_skeleton(model, path) + + md = uw.checkpoint.read_snapshot_metadata(path) + assert md["schema_version"] == 1 + assert md["sim_time"] == 3.14 + assert md["step"] == 42 + assert md["dim"] == 2 + # Convenience: JSON-encoded lists are also exposed as plain lists. 
+ assert isinstance(md["mesh_names"], list) and len(md["mesh_names"]) == 1 + assert isinstance(md["swarm_names"], list) + assert "ModelTracker" in md["state_bearer_classes"] + + +def test_read_rejects_non_snapshot_file(tmp_path): + """Pointing at an h5 file that isn't a UW3 snapshot raises + cleanly, not with an obscure h5py error.""" + import h5py + import underworld3 as uw + + path = str(tmp_path / "not-a-snapshot.h5") + with h5py.File(path, "w") as f: + f.create_dataset("random_dataset", data=np.zeros(3)) + + with pytest.raises(ValueError, match="not a UW3 snapshot"): + uw.checkpoint.read_snapshot_metadata(path) + + +def test_read_rejects_wrong_schema_version(tmp_path): + """A future-version snapshot we cannot interpret raises, with a + pointer to the (future) migration path.""" + import h5py + import underworld3 as uw + + path = str(tmp_path / "future.snap.h5") + with h5py.File(path, "w") as f: + md = f.create_group("metadata") + md.attrs["schema_version"] = 999 + + with pytest.raises(ValueError, match="schema version 999"): + uw.checkpoint.read_snapshot_metadata(path) + + +def test_inspect_snapshot_summary_includes_key_facts(tmp_path): + """The human-readable summary surface (intended for notebook + `print(...)` use) covers the same key facts external h5 inspection + would surface.""" + import underworld3 as uw + + uw, model, mesh, swarm = _fresh_model_with_state(tmp_path) + path = str(tmp_path / "phase1.snap.h5") + uw.checkpoint.write_snapshot_skeleton(model, path) + + summary = uw.checkpoint.inspect_snapshot(path) + assert "UW3 snapshot" in summary + assert "schema_version : 1" in summary + assert "sim_time" in summary and "3.14" in summary + assert "step" in summary and "42" in summary + assert "ModelTracker" in summary + + +def test_skeleton_groups_have_filled_by_marker(tmp_path): + """Empty top-level groups carry a `filled_by` attr so a phase-2/3 + reader knows whether their content is populated yet — and an + external inspector sees 'this group is empty 
because phase 2 + hasn't run' rather than getting nothing.""" + import h5py + import underworld3 as uw + + uw, model, mesh, swarm = _fresh_model_with_state(tmp_path) + path = str(tmp_path / "phase1.snap.h5") + uw.checkpoint.write_snapshot_skeleton(model, path) + + with h5py.File(path, "r") as f: + for name in ("meshes", "swarms", "python_state"): + assert "filled_by" in f[name].attrs + assert str(f[name].attrs["filled_by"]) == "" + + +# ----- Phase 2: mesh + mesh-variable bulk via #146 ----- + + +def _fresh_model_mesh_and_vars(): + import underworld3 as uw + + uw.reset_default_model() + uw.use_strict_units(False) + uw.use_nondimensional_scaling(False) + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + V = uw.discretisation.MeshVariable("V", mesh, 2, degree=2) + return uw, model, mesh, T, V + + +def test_write_snapshot_produces_wrapper_and_bulk_dir(tmp_path): + """The two artifacts the convention promises: wrapper file + + sibling .bulk/ directory containing PETSc HDF5 files.""" + import os + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + T.array[:, 0, 0] = 5.0 + V.array[:, 0, 0] = -3.0 + + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + bulk = str(tmp_path / "run.snap.bulk") + assert os.path.exists(path) + assert os.path.isdir(bulk) + # #146-format files in the bulk dir. + files = sorted(os.listdir(bulk)) + # At least: mesh file + one file per variable. + assert any(f.endswith(".mesh.00000.h5") for f in files) + assert any("T.00000.h5" in f for f in files) + assert any("V.00000.h5" in f for f in files) + + +def test_write_snapshot_populates_wrapper_layout(tmp_path): + """The wrapper carries the per-mesh + per-variable metadata that + makes 'what's in this snapshot?' 
answerable from h5py alone.""" + import h5py + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + T.array[:, 0, 0] = 1.0 + V.array[:, 0, 0] = 2.0 + V.array[:, 0, 1] = 3.0 + + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + with h5py.File(path, "r") as f: + assert f["meshes"].attrs["filled_by"] == "phase2" + # One mesh subgroup + mesh_names = list(f["meshes"].keys()) + assert len(mesh_names) == 1 + mg = f["meshes"][mesh_names[0]] + # Per-mesh attrs + assert mg.attrs["name"] == mesh.name + assert mg.attrs["mesh_file"].endswith(".mesh.00000.h5") + # Variables subgroup + var_names = sorted(mg["variables"].keys()) + assert var_names == ["T", "V"] + # Per-var attrs include shape info + external_file pointer. + v_attrs = mg["variables"]["V"].attrs + assert v_attrs["components"] == 2 + assert v_attrs["degree"] == 2 + assert v_attrs["external_file"].endswith("V.00000.h5") + + +def test_write_read_snapshot_bit_exact_roundtrip(tmp_path): + """The core phase-2 guarantee: write a snapshot, scribble all + variables, read snapshot back, all variables match write-time + values bit-for-bit (#146's PETSc DMPlex same-rank reload, just + delivered via the wrapper).""" + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + T.array[:, 0, 0] = 5.0 * T.coords[:, 0] - 2.0 + V.array[:, 0, 0] = 3.0 * V.coords[:, 0] + V.array[:, 0, 1] = 7.0 * V.coords[:, 1] + T_pre = np.asarray(T.array[...]).copy() + V_pre = np.asarray(V.array[...]).copy() + + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + # Scribble. + T.array[...] = -99.0 + V.array[...] 
= -99.0 + + model.load_state(path) + + assert np.array_equal(np.asarray(T.array[...]), T_pre), ( + f"T not bit-exact after read_snapshot — max|d|=" + f"{float(np.max(np.abs(np.asarray(T.array[...]) - T_pre))):.3e}" + ) + assert np.array_equal(np.asarray(V.array[...]), V_pre), ( + f"V not bit-exact after read_snapshot — max|d|=" + f"{float(np.max(np.abs(np.asarray(V.array[...]) - V_pre))):.3e}" + ) + + +def test_read_snapshot_rejects_missing_bulk_dir(tmp_path): + """If the user moves the wrapper without the bulk dir, read fails + with a clear pointer rather than an obscure h5py error.""" + import os + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + # Delete the bulk dir to simulate the move-the-wrapper-only mistake. + import shutil + shutil.rmtree(str(tmp_path / "run.snap.bulk")) + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + with pytest.raises(FileNotFoundError, match="bulk directory missing"): + model.load_state(path) + + +def test_read_snapshot_rejects_mismatched_mesh(tmp_path): + """If the target model's meshes don't match the snapshot's, raise + clearly — mesh-rebuild on read is v1.2 scope.""" + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + # Fresh model with a *different* mesh — write_snapshot's mesh.name + # won't match. + uw.reset_default_model() + model2 = uw.get_default_model() + other = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(2.0, 2.0), cellSize=1.0 / 3.0, + ) + # Don't reuse the original mesh's name — make the lookup miss. 
+ other.name = "definitely_a_different_mesh" + + with pytest.raises(ValueError, match="not registered on this model"): + model2.load_state(path) + + +# ----- Phase 3a: state-bearer (Snapshottable) serialisation ----- + + +def test_tracker_round_trips_through_disk_snapshot(tmp_path): + """ModelTracker is always auto-registered as a state-bearer, so + every snapshot must round-trip its time/step/dt and any + user-added managed quantities exactly.""" + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + model.tracker.time = 3.14 + model.tracker.step = 42 + model.tracker.dt = 0.05 + model.tracker.my_diagnostic = 99.0 + model.tracker.history_arr = np.array([1.0, 2.0, 3.0]) + + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + # Scribble everything tracker-side. + model.tracker.time = -1.0 + model.tracker.step = -1 + model.tracker.dt = -1.0 + model.tracker.my_diagnostic = -1.0 + model.tracker.history_arr = np.array([-1.0, -1.0, -1.0]) + + model.load_state(path) + + assert model.tracker.time == 3.14 + assert model.tracker.step == 42 + assert model.tracker.dt == 0.05 + assert model.tracker.my_diagnostic == 99.0 + assert np.array_equal( + np.asarray(model.tracker.history_arr), np.array([1.0, 2.0, 3.0]) + ) + + +def test_python_state_group_is_inspectable(tmp_path): + """An external h5py reader sees the per-bearer groups under + /python_state with class info + a managed dict for the tracker.""" + import h5py + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + model.tracker.time = 1.0 + model.tracker.step = 2 + model.tracker.my_q = 7.0 + + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + with h5py.File(path, "r") as f: + ps = f["python_state"] + assert ps.attrs["filled_by"] == "phase3a" + + tracker_keys = [k for k in ps.keys() if k.startswith("ModelTracker_")] + assert len(tracker_keys) == 1 + tg = ps[tracker_keys[0]] + assert tg.attrs["__bearer_class__"] == 
"ModelTracker" + assert tg.attrs["__state_class__"] == "TrackerState" + # TrackerState.managed is a dict, stored as a sub-group. + assert "managed" in tg and isinstance(tg["managed"], h5py.Group) + managed = tg["managed"] + # Pre-seeded conventions present. + assert "time" in managed.attrs + assert float(managed.attrs["time"]) == 1.0 + assert int(managed.attrs["step"]) == 2 + # User-added quantity present. + assert "my_q" in managed.attrs + assert float(managed.attrs["my_q"]) == 7.0 + + +def test_ddt_symbolic_state_round_trips_primary_fields(tmp_path): + """A Symbolic DDt has dt_history, history_initialised, + n_solves_completed, dt round-tripped via the generic dataclass + serialiser. psi_star (sympy) is documented as skipped — the + primary BDF-control fields are what matter for re-continuing.""" + import underworld3 as uw + import sympy + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + ddt = uw.systems.ddt.Symbolic(T.sym, order=2) + ddt._dt_history = [0.05, 0.03] + ddt._history_initialised = True + ddt._n_solves_completed = 2 + ddt._dt = 0.05 + + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + # Scribble the primary fields. + ddt._dt_history = [None, None] + ddt._history_initialised = False + ddt._n_solves_completed = 0 + ddt._dt = None + + model.load_state(path) + + assert ddt.state.dt_history == [0.05, 0.03] + assert ddt.state.history_initialised is True + assert ddt.state.n_solves_completed == 2 + assert ddt.state.dt == 0.05 + + +def test_read_snapshot_rejects_missing_state_bearer(tmp_path): + """If a state-bearer exists in the snapshot but not on the load- + target model, raise — same-rank/same-model contract.""" + import underworld3 as uw + + # Source model has a Symbolic DDt; snapshot it. + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + ddt = uw.systems.ddt.Symbolic(T.sym, order=2) + ddt._dt_history = [0.05, 0.05] + + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + # Target model has no DDt. 
+ uw, model2, mesh2, T2, V2 = _fresh_model_mesh_and_vars() + # Force name match so the mesh part loads. + mesh2.name = mesh.name + + with pytest.raises(ValueError, match="state-bearer .* not registered"): + model2.load_state(path) + + +# ----- Phase 3b: swarms in sidecars ----- + + +def _fresh_model_mesh_swarm(): + import underworld3 as uw + + uw.reset_default_model() + uw.use_strict_units(False) + uw.use_nondimensional_scaling(False) + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + swarm = uw.swarm.Swarm(mesh) + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + material.data[:, 0] = swarm._particle_coordinates.data[:, 0] + return uw, model, mesh, swarm, material + + +def test_swarm_sidecar_lands_in_bulk_dir(tmp_path): + """A registered swarm produces its own h5 sidecar in the bulk dir + with predictable name (per-rank); wrapper records the sidecar + pattern + rank count.""" + import os + import re + import h5py + import underworld3 as uw + + uw, model, mesh, swarm, material = _fresh_model_mesh_swarm() + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + bulk = str(tmp_path / "run.snap.bulk") + files = sorted(os.listdir(bulk)) + # Per-rank pattern: .swarm.rank{R:04d}of{S:04d}.h5 + swarm_sidecars = [ + f for f in files + if re.match(r".*\.swarm\.rank\d{4}of\d{4}\.h5$", f) + ] + assert len(swarm_sidecars) >= 1 # at least rank-0 file on single-rank + + with h5py.File(path, "r") as f: + sw = f["swarms"] + assert sw.attrs["filled_by"] == "phase3b+phase6" + assert "mpi_size_at_write" in sw.attrs + assert len(sw.keys()) == 1 + swarm_safe = list(sw.keys())[0] + sg = sw[swarm_safe] + # Wrapper records pattern (not a specific file) — readers + # fill in their own rank. 
+ assert "sidecar_pattern" in sg.attrs + assert "{rank" in str(sg.attrs["sidecar_pattern"]) + assert sg.attrs["mesh_name"] == mesh.name + assert int(sg.attrs["num_particles_global"]) > 0 + # User-added variables surface in /swarms/{name}/variables/ + assert "material" in sg["variables"] + + +def test_swarm_sidecar_is_inspectable(tmp_path): + """The swarm sidecar itself has h5-inspectable structure: top- + level attrs, /coordinates dataset, /variables/{name} datasets + with attrs. h5ls -v on the sidecar tells you what it holds.""" + import os + import h5py + import underworld3 as uw + + uw, model, mesh, swarm, material = _fresh_model_mesh_swarm() + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + bulk = str(tmp_path / "run.snap.bulk") + swarm_sidecar = [ + f for f in os.listdir(bulk) if ".swarm.rank" in f + ][0] + + with h5py.File(os.path.join(bulk, swarm_sidecar), "r") as f: + # File-level attrs identify the file without UW3 in the loop. + assert int(f.attrs["num_particles_local"]) > 0 + assert int(f.attrs["dim"]) == 2 + assert str(f.attrs["mesh_name"]) == mesh.name + # Bulk in standard h5 places + assert "coordinates" in f + assert f["coordinates"].shape[1] == 2 + # User var present with metadata + v = f["variables"]["material"] + assert int(v.attrs["num_components"]) == 1 + + +def test_swarm_round_trips_through_disk_snapshot(tmp_path): + """The whole swarm — particle coords + svar data — round-trips + exactly through write→scribble→read.""" + import os + import underworld3 as uw + + uw, model, mesh, swarm, material = _fresh_model_mesh_swarm() + coords_pre = swarm._particle_coordinates.data.copy() + material_pre = np.asarray(material.data).copy() + + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + # Scribble: corrupt coords directly via the PETSc field, scribble + # material via the standard array view. + coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape( + (-1, swarm.dim) + ) + coord_field[...] 
= -99.0 + swarm.dm.restoreField("DMSwarmPIC_coor") + material.data[...] = -99.0 + + model.load_state(path) + + assert np.array_equal(swarm._particle_coordinates.data, coords_pre), ( + "particle coords not bit-exact after disk-snapshot restore" + ) + assert np.array_equal(np.asarray(material.data), material_pre), ( + "swarm material not bit-exact after disk-snapshot restore" + ) + + +def test_swarm_restore_recovers_after_particle_count_change(tmp_path): + """Mirror of the in-memory rebuild-on-restore guarantee, but on + disk: snapshot a swarm, mutate its population (add particles), + restore — the original local population (count + coords + var + data) is fully reconstructed.""" + import underworld3 as uw + + uw, model, mesh, swarm, material = _fresh_model_mesh_swarm() + n_pre = swarm.dm.getLocalSize() + coords_pre = swarm._particle_coordinates.data.copy() + material_pre = np.asarray(material.data).copy() + + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + # Mutate the population — add two extra particles. Local count + # changes. + swarm.add_particles_with_coordinates( + np.array([[0.4, 0.4], [0.6, 0.6]]) + ) + assert swarm.dm.getLocalSize() != n_pre + + model.load_state(path) + + assert swarm.dm.getLocalSize() == n_pre, ( + "restore did not roll back to the captured particle count" + ) + assert np.array_equal(swarm._particle_coordinates.data, coords_pre) + assert np.array_equal(np.asarray(material.data), material_pre) + + +# ----- Phase 4: read_timestep is format-aware ----- + + +def test_read_timestep_reads_v1_1_snapshot_via_dispatch(tmp_path): + """The selective-read entry point users already know + (``var.read_timestep(...)``) reads a v1.1 snapshot wrapper too — + the format detection is hidden inside the call. 
Same-mesh read + is bit-exact because the KDTree query lands exactly on the + captured DOF coordinates.""" + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + T.array[:, 0, 0] = 5.0 * T.coords[:, 0] + 3.0 + T_pre = np.asarray(T.array[...]).copy() + + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + # Scribble T, then re-load via the legacy read_timestep API — + # but pointed at the v1.1 wrapper. Format dispatch handles it. + T.array[...] = -99.0 + T.read_timestep(path, "T", 0) + + assert np.allclose(np.asarray(T.array[...]), T_pre, atol=1e-12), ( + f"read_timestep against v1.1 snapshot not bit-exact: " + f"max|d| = {float(np.max(np.abs(np.asarray(T.array[...]) - T_pre))):.3e}" + ) + + +def test_read_timestep_legacy_path_unchanged(tmp_path): + """Belt-and-braces: pointing read_timestep at a legacy + write_timestep file still uses the legacy code path (not the v1.1 + bridge).""" + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + T.array[:, 0, 0] = 7.0 + T_pre = np.asarray(T.array[...]).copy() + + # Write via legacy write_timestep — different format, same content. + mesh.write_timestep( + "legacy", index=0, outputPath=str(tmp_path), meshVars=[T] + ) + + T.array[...] = -99.0 + T.read_timestep("legacy", "T", 0, outputPath=str(tmp_path)) + assert np.allclose(np.asarray(T.array[...]), T_pre, atol=1e-12) + + +def test_swarm_internals_not_in_sidecar(tmp_path): + """The PETSc-internal swarm variables (DMSwarmPIC_coor, DMSwarm_X0, + DMSwarm_remeshed) are filtered at capture — they're not in the + sidecar's /variables group. 
Same rule as in-memory capture.""" + import os + import h5py + import underworld3 as uw + + uw, model, mesh, swarm, material = _fresh_model_mesh_swarm() + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + bulk = str(tmp_path / "run.snap.bulk") + sidecar = [f for f in os.listdir(bulk) if ".swarm.rank" in f][0] + with h5py.File(os.path.join(bulk, sidecar), "r") as f: + var_names = set(f["variables"].keys()) + assert "material" in var_names + assert not any(n.startswith("DMSwarm") for n in var_names)