From 66a2d748ef46d8b2fa2d963dfa2b9ae11b7659ef Mon Sep 17 00:00:00 2001 From: lmoresi Date: Wed, 20 May 2026 13:10:15 +1000 Subject: [PATCH 1/7] feat(snapshot v1.1, phase 1): metadata layer + skeleton + inspectability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit First slice of the on-disk snapshot format (v1.1). Establishes the file structure and the inspectability bar; no PETSc bulk yet (that is phase 2). Stacked on the in-memory snapshot toolkit (#195) and the model tracker (#196) so it can serialise both later. What lands: - src/underworld3/checkpoint/disk_snapshot.py - DISK_SNAPSHOT_SCHEMA_VERSION = 1 - write_snapshot_skeleton(model, path): writes /metadata attrs + empty stub groups /mesh /variables /swarms /python_state (the structure phases 2+ will fill in). - read_snapshot_metadata(path): reads /metadata back as a plain dict, decodes JSON-encoded list fields for convenience, validates schema version. - inspect_snapshot(path): human-readable summary suitable for print(...) at a notebook prompt. - src/underworld3/checkpoint/__init__.py: exports. - tests/test_0010_snapshot_disk_format.py (7, tier_a level_1): - top-level group structure matches the spec - h5py-readable /metadata attrs cover identity, schema, tracker conventions, geometry, MPI rank count, and inventories of meshes / swarms / state-bearer classes / variables — the proxy for "an external user running h5ls/h5dump sees useful info" - read/write roundtrip - rejection of non-snapshot files and wrong-schema files with clear errors (not obscure h5py noise) - inspect_snapshot includes the key facts - skeleton groups carry `filled_by` attrs so phases 2/3 readers and external inspectors can tell whether content is populated yet. Design notes encoded: - UW3-controlled rich-metadata wrapper around PETSc bulk; pure PETSc HDF5 dumps fail the inspectability bar so are rejected as the format. 
- List-typed metadata stored as JSON strings in scalar attrs so h5py / h5ls handle them cleanly; read API exposes them as plain Python lists alongside the *_json originals. - Swarm storage left as a phase-3 decision: the metadata wrapper is designed to support `@external_file` on /swarms/swarm_X/ when individual swarms grow too bulky for a single file. No commitment to inline vs split until phase 3 has real swarm sizes in hand. Stacked on feature/model-tracker; PRs to development after #195 and #196 land. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/checkpoint/__init__.py | 10 + src/underworld3/checkpoint/disk_snapshot.py | 265 ++++++++++++++++++++ tests/test_0010_snapshot_disk_format.py | 170 +++++++++++++ 3 files changed, 445 insertions(+) create mode 100644 src/underworld3/checkpoint/disk_snapshot.py create mode 100644 tests/test_0010_snapshot_disk_format.py diff --git a/src/underworld3/checkpoint/__init__.py b/src/underworld3/checkpoint/__init__.py index c601b1ff..31863784 100644 --- a/src/underworld3/checkpoint/__init__.py +++ b/src/underworld3/checkpoint/__init__.py @@ -27,6 +27,12 @@ ) from .state import Snapshottable, SnapshottableState from .tracker import ModelTracker, TrackerState +from .disk_snapshot import ( + DISK_SNAPSHOT_SCHEMA_VERSION, + inspect_snapshot, + read_snapshot_metadata, + write_snapshot_skeleton, +) __all__ = [ "CheckpointBackend", @@ -40,4 +46,8 @@ "SnapshottableState", "ModelTracker", "TrackerState", + "DISK_SNAPSHOT_SCHEMA_VERSION", + "inspect_snapshot", + "read_snapshot_metadata", + "write_snapshot_skeleton", ] diff --git a/src/underworld3/checkpoint/disk_snapshot.py b/src/underworld3/checkpoint/disk_snapshot.py new file mode 100644 index 00000000..2d178883 --- /dev/null +++ b/src/underworld3/checkpoint/disk_snapshot.py @@ -0,0 +1,265 @@ +"""On-disk snapshot format (v1.1) — metadata wrapper around PETSc bulk. 
+ +Design (see ``memory/project_snapshot_v1_1_disk_format.md`` and +``docs/developer/design/in_memory_checkpoint_design.md``): + + - Single self-contained HDF5 file (with the freedom to externalise + bulky swarm data into companion files later — phase 3). + - UW3-controlled rich metadata wrapper around PETSc-format bulk + data, so ``h5ls`` / ``h5dump`` show useful information about a + snapshot file without UW3 needing to be in the loop. + - Bulk data layers (PETSc DMPlex topology, sections, vectors) are + delegated to the primitives that landed in #146; this module owns + the *layout* and the *metadata*, not the binary serialisation of + fields. + +File structure (target — phases 2+ fill in the bulk under these +groups; phase 1 writes the metadata and empty stub groups): + + my_run.snap.h5/ + ├── /metadata (attrs: uw3_version, schema_version, + │ created_at, step, sim_time, dt, dim, + │ mesh_type, coordinate_system, + │ mpi_ranks_at_write, variables_summary, ...) + ├── /mesh (phase 2 — DMPlex topology + coords + labels) + ├── /variables (phase 2 — one subgroup per mesh-variable) + ├── /swarms (phase 3 — possibly @external_file refs) + └── /python_state (phase 3 — Snapshottable dataclasses as attrs) + +Phase 1 (this commit): the metadata layer and the skeleton group +structure, with an inspectability acceptance test that asserts an +external reader (h5py here) sees meaningful information without any +UW3 imports. +""" + +from __future__ import annotations + +import datetime +import json +from typing import Any, Optional + +import numpy as np + +import underworld3 as uw + + +DISK_SNAPSHOT_SCHEMA_VERSION = 1 + +# Top-level group names — fixed; renaming would be a schema-version bump. 
+_GROUP_METADATA = "metadata" +_GROUP_MESH = "mesh" +_GROUP_VARIABLES = "variables" +_GROUP_SWARMS = "swarms" +_GROUP_PYTHON_STATE = "python_state" + +_TOP_LEVEL_GROUPS = ( + _GROUP_METADATA, + _GROUP_MESH, + _GROUP_VARIABLES, + _GROUP_SWARMS, + _GROUP_PYTHON_STATE, +) + + +def _collect_metadata(model) -> dict: + """Build the metadata dict that gets written into ``/metadata`` attrs. + + Stable, h5-friendly types only: strings, ints, floats, lists of + strings (stored as JSON for compactness and to keep h5 attrs + scalar-typed where possible). No pickling, no repr of UW3 objects. + """ + now_iso = datetime.datetime.now(datetime.timezone.utc).isoformat( + timespec="seconds" + ) + + # Mesh-derived info (gracefully absent if no mesh registered). + meshes = list(model._meshes.values()) + first_mesh = meshes[0] if meshes else None + mesh_names = [m.name for m in meshes] + if first_mesh is not None: + dim = int(first_mesh.dim) + mesh_type = type(first_mesh).__name__ + coord_system = ( + first_mesh.CoordinateSystem.type + if hasattr(first_mesh, "CoordinateSystem") + else "unknown" + ) + else: + dim = -1 + mesh_type = "" + coord_system = "" + + # Swarm names (WeakValueDictionary, snapshot it). + swarm_names = [] + for s in list(model._swarms.values()): + name = getattr(s, "name", None) or f"swarm_{s.instance_number}" + swarm_names.append(name) + + # State-bearer class summary (just the class names — useful in + # h5ls without needing UW3 to interpret). + state_bearer_classes = sorted( + {type(o).__name__ for o in list(model._state_bearers)} + ) + + # Tracker conventions (the model-dwelling record). 
+ tracker = getattr(model, "tracker", None) + if tracker is not None: + sim_time = float(tracker.time) if tracker.time is not None else 0.0 + step = int(tracker.step) if tracker.step is not None else 0 + dt_val = tracker.dt + dt = float(dt_val) if dt_val is not None else float("nan") + else: + sim_time, step, dt = 0.0, 0, float("nan") + + # Per-variable summary across all registered meshes. + var_entries = [] + for m in meshes: + for var in m.vars.values(): + kind = "vector" if var.num_components > 1 else "scalar" + var_entries.append( + f"{m.name}.{var.clean_name} ({kind}, " + f"components={var.num_components}, degree={var.degree})" + ) + variables_summary = "; ".join(var_entries) if var_entries else "" + + md = { + # Versioning + "uw3_version": str(getattr(uw, "__version__", "0.0.0")), + "schema_version": int(DISK_SNAPSHOT_SCHEMA_VERSION), + "created_at": now_iso, + # Identity + "run_name": str(getattr(model, "name", "default")), + # Time / step (from the tracker — pre-seeded conventions) + "step": step, + "sim_time": sim_time, + "dt": dt, + # Geometry / topology + "dim": dim, + "mesh_type": mesh_type, + "coordinate_system": str(coord_system), + # MPI + "mpi_ranks_at_write": int(uw.mpi.size), + # Inventories — JSON for list-typed values so h5 attrs stay scalar. + "mesh_names_json": json.dumps(mesh_names), + "swarm_names_json": json.dumps(swarm_names), + "state_bearer_classes_json": json.dumps(state_bearer_classes), + "variables_summary": variables_summary, + } + return md + + +def _write_metadata_attrs(h5group, metadata: dict) -> None: + """Write a metadata dict as HDF5 attrs on a group. Plain types only.""" + for k, v in metadata.items(): + h5group.attrs[k] = v + + +def write_snapshot_skeleton(model, path: str) -> str: + """Phase 1: write the metadata + empty skeleton group structure. + + Returns the path written. 
Subsequent phases (2: mesh + meshvar + bulk; 3: swarms + python_state) populate the empty top-level + groups using PETSc primitives and dataclass serialisation + respectively. Writing is rank-0-only at this phase since no + collective PETSc operations are involved yet. + """ + import h5py + + with uw.selective_ranks(0) as should_execute: + if not should_execute: + uw.mpi.barrier() + return path + + metadata = _collect_metadata(model) + + with h5py.File(path, "w") as f: + md_group = f.create_group(_GROUP_METADATA) + _write_metadata_attrs(md_group, metadata) + + # Stub the other top-level groups so external readers can + # see the file's intended shape from day one — phases 2/3 + # populate them. + for name in ( + _GROUP_MESH, + _GROUP_VARIABLES, + _GROUP_SWARMS, + _GROUP_PYTHON_STATE, + ): + grp = f.create_group(name) + grp.attrs["filled_by"] = "" # set to "phase2" / "phase3" later + + uw.mpi.barrier() + return path + + +def read_snapshot_metadata(path: str) -> dict: + """Read the ``/metadata`` group's attrs back as a plain dict. + + Validates the schema version. Lists stored as ``*_json`` are + decoded back into Python lists for caller convenience but the + on-disk form stays JSON for h5-tool friendliness. + """ + import h5py + + with h5py.File(path, "r") as f: + if _GROUP_METADATA not in f: + raise ValueError( + f"{path}: not a UW3 snapshot file (no /{_GROUP_METADATA} group)" + ) + md_group = f[_GROUP_METADATA] + md = {} + for k in md_group.attrs.keys(): + v = md_group.attrs[k] + # h5py returns bytes for some string attrs; normalise to str. 
+ if isinstance(v, bytes): + v = v.decode() + elif isinstance(v, np.ndarray) and v.dtype.kind in ("S", "U"): + v = [x.decode() if isinstance(x, bytes) else str(x) for x in v] + md[k] = v + + schema = int(md.get("schema_version", -1)) + if schema != DISK_SNAPSHOT_SCHEMA_VERSION: + raise ValueError( + f"{path}: snapshot schema version {schema} does not match " + f"current {DISK_SNAPSHOT_SCHEMA_VERSION}; on-disk schema " + f"migration will land with phase 6 (not yet implemented)" + ) + + # Decode JSON-encoded list fields for caller convenience. + for key in list(md.keys()): + if key.endswith("_json"): + try: + decoded = json.loads(md[key]) + md[key[:-5]] = decoded # e.g. "mesh_names" alongside "mesh_names_json" + except (TypeError, ValueError, json.JSONDecodeError): + pass + + return md + + +def inspect_snapshot(path: str) -> str: + """Human-readable one-shot summary of a snapshot file's metadata. + + Useful as a Python-side equivalent to running ``h5ls`` on the + file; intended for ``print(uw.checkpoint.inspect_snapshot(path))`` + at a notebook prompt. 
+ """ + md = read_snapshot_metadata(path) + lines = [ + f"UW3 snapshot: {path}", + f" run_name : {md.get('run_name', '?')}", + f" created_at : {md.get('created_at', '?')}", + f" uw3_version : {md.get('uw3_version', '?')}", + f" schema_version : {md.get('schema_version', '?')}", + f" step / sim_time / dt : {md.get('step', '?')} / " + f"{md.get('sim_time', '?')} / {md.get('dt', '?')}", + f" dim / mesh_type : {md.get('dim', '?')} / " + f"{md.get('mesh_type', '?')}", + f" coordinate_system : {md.get('coordinate_system', '?')}", + f" mpi_ranks_at_write : {md.get('mpi_ranks_at_write', '?')}", + f" meshes : {md.get('mesh_names', [])}", + f" swarms : {md.get('swarm_names', [])}", + f" state_bearer_classes : {md.get('state_bearer_classes', [])}", + f" variables_summary : {md.get('variables_summary', '')}", + ] + return "\n".join(lines) diff --git a/tests/test_0010_snapshot_disk_format.py b/tests/test_0010_snapshot_disk_format.py new file mode 100644 index 00000000..a5aa1961 --- /dev/null +++ b/tests/test_0010_snapshot_disk_format.py @@ -0,0 +1,170 @@ +"""Phase 1 of the on-disk snapshot format (v1.1). + +These tests assert the *inspectability bar* — an external h5 reader +(here, h5py — but the assertions translate directly to ``h5ls`` +output) must see meaningful information about a UW3 snapshot file +without UW3 needing to be in the loop. They do not yet exercise any +PETSc bulk-data writes; that lands in phase 2. 
+""" + +import json + +import pytest +import numpy as np + +pytestmark = [pytest.mark.level_1, pytest.mark.tier_a] + + +def _fresh_model_with_state(tmp_path): + import underworld3 as uw + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + _ = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + _ = uw.discretisation.MeshVariable("V", mesh, 2, degree=2) + swarm = uw.swarm.Swarm(mesh) + swarm.populate(fill_param=2) + + model.tracker.time = 3.14 + model.tracker.step = 42 + model.tracker.dt = 0.05 + return uw, model, mesh, swarm + + +def test_skeleton_writes_expected_group_structure(tmp_path): + """The file an h5 tool would open has exactly the documented + top-level group structure — no surprises for external readers.""" + import h5py + + uw, model, mesh, swarm = _fresh_model_with_state(tmp_path) + path = str(tmp_path / "phase1.snap.h5") + uw.checkpoint.write_snapshot_skeleton(model, path) + + with h5py.File(path, "r") as f: + top = set(f.keys()) + assert top == {"metadata", "mesh", "variables", "swarms", "python_state"} + + +def test_metadata_is_inspectable_without_uw3(tmp_path): + """The /metadata attrs h5py reads back are useful — i.e. an + external reader sees the run identity, schema, step/time, geometry, + MPI rank count, and the inventory of meshes/swarms/variables. + + Concretely: an h5py user (the proxy for h5ls/h5dump here) can + answer 'what's in this file?' 
from /metadata alone.""" + import h5py + + uw, model, mesh, swarm = _fresh_model_with_state(tmp_path) + path = str(tmp_path / "phase1.snap.h5") + uw.checkpoint.write_snapshot_skeleton(model, path) + + with h5py.File(path, "r") as f: + md = f["metadata"].attrs + # Identity / versioning + assert int(md["schema_version"]) == 1 + assert isinstance(str(md["created_at"]), str) + assert str(md["run_name"]) != "" + # Tracker conventions surfaced as scalars + assert float(md["sim_time"]) == 3.14 + assert int(md["step"]) == 42 + assert float(md["dt"]) == 0.05 + # Geometry + assert int(md["dim"]) == 2 + assert str(md["mesh_type"]) != "" + # MPI + assert int(md["mpi_ranks_at_write"]) >= 1 + # Inventories (JSON-encoded list-typed values) + var_summary = str(md["variables_summary"]) + assert "T" in var_summary and "V" in var_summary + swarm_names = json.loads(str(md["swarm_names_json"])) + assert len(swarm_names) == 1 + state_classes = json.loads(str(md["state_bearer_classes_json"])) + assert "ModelTracker" in state_classes + + +def test_read_snapshot_metadata_roundtrip(tmp_path): + """write -> read returns the same content, with JSON-encoded list + fields conveniently decoded for the caller.""" + import underworld3 as uw + + uw, model, mesh, swarm = _fresh_model_with_state(tmp_path) + path = str(tmp_path / "phase1.snap.h5") + uw.checkpoint.write_snapshot_skeleton(model, path) + + md = uw.checkpoint.read_snapshot_metadata(path) + assert md["schema_version"] == 1 + assert md["sim_time"] == 3.14 + assert md["step"] == 42 + assert md["dim"] == 2 + # Convenience: JSON-encoded lists are also exposed as plain lists. 
+ assert isinstance(md["mesh_names"], list) and len(md["mesh_names"]) == 1 + assert isinstance(md["swarm_names"], list) + assert "ModelTracker" in md["state_bearer_classes"] + + +def test_read_rejects_non_snapshot_file(tmp_path): + """Pointing at an h5 file that isn't a UW3 snapshot raises + cleanly, not with an obscure h5py error.""" + import h5py + import underworld3 as uw + + path = str(tmp_path / "not-a-snapshot.h5") + with h5py.File(path, "w") as f: + f.create_dataset("random_dataset", data=np.zeros(3)) + + with pytest.raises(ValueError, match="not a UW3 snapshot"): + uw.checkpoint.read_snapshot_metadata(path) + + +def test_read_rejects_wrong_schema_version(tmp_path): + """A future-version snapshot we cannot interpret raises, with a + pointer to the (future) migration path.""" + import h5py + import underworld3 as uw + + path = str(tmp_path / "future.snap.h5") + with h5py.File(path, "w") as f: + md = f.create_group("metadata") + md.attrs["schema_version"] = 999 + + with pytest.raises(ValueError, match="schema version 999"): + uw.checkpoint.read_snapshot_metadata(path) + + +def test_inspect_snapshot_summary_includes_key_facts(tmp_path): + """The human-readable summary surface (intended for notebook + `print(...)` use) covers the same key facts external h5 inspection + would surface.""" + import underworld3 as uw + + uw, model, mesh, swarm = _fresh_model_with_state(tmp_path) + path = str(tmp_path / "phase1.snap.h5") + uw.checkpoint.write_snapshot_skeleton(model, path) + + summary = uw.checkpoint.inspect_snapshot(path) + assert "UW3 snapshot" in summary + assert "schema_version : 1" in summary + assert "sim_time" in summary and "3.14" in summary + assert "step" in summary and "42" in summary + assert "ModelTracker" in summary + + +def test_skeleton_groups_have_filled_by_marker(tmp_path): + """Empty top-level groups carry a `filled_by` attr so a phase-2/3 + reader knows whether their content is populated yet — and an + external inspector sees 'this group is empty 
because phase 2 + hasn't run' rather than getting nothing.""" + import h5py + import underworld3 as uw + + uw, model, mesh, swarm = _fresh_model_with_state(tmp_path) + path = str(tmp_path / "phase1.snap.h5") + uw.checkpoint.write_snapshot_skeleton(model, path) + + with h5py.File(path, "r") as f: + for name in ("mesh", "variables", "swarms", "python_state"): + assert "filled_by" in f[name].attrs + assert str(f[name].attrs["filled_by"]) == "" From bb0b0d3dd48c48ec11bebd8670350eeafabc9585 Mon Sep 17 00:00:00 2001 From: lmoresi Date: Wed, 20 May 2026 15:03:50 +1000 Subject: [PATCH 2/7] feat(snapshot v1.1, phase 2): mesh + meshvar bulk via #146 + bit-exact roundtrip MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Builds on phase 1's metadata wrapper to actually carry mesh + mesh-variable state to disk and read it back. Delegates the heavy lifting to #146's `Mesh.write_checkpoint` / `MeshVariable.read_checkpoint` PETSc-DMPlex primitives — phase 2's job is layout, dispatch, and tying the wrapper to the bulk data via a simple convention. Layout (final v1.1 shape): /path/to/run.snap.h5 wrapper (h5py-inspectable) /path/to/run.snap.bulk/ companion directory (one per snap) {mesh_safe}.mesh.00000.h5 {mesh_safe}.{var_clean}.00000.h5 Wrapper carries /meshes/{mesh_safe}/ with @name, @mesh_file, and /meshes/{mesh_safe}/variables/{var_safe}/ with @name, @components, @degree, @continuous, @external_file. The bulk-dir path is derived from the wrapper path by convention (`.h5` → `.bulk`), so no external_file attr is needed for the standard placement. Move them together; a clear FileNotFoundError fires if bulk is missing on read. Phase 1 layout refactor folded in: - /mesh (singular) → /meshes (plural) — supports multi-mesh natively. - /variables removed from the top level — now nests under each mesh as /meshes/{name}/variables/{var}, matching the in-memory snapshot's mesh→vars structure.
New API: - `write_snapshot(model, path)` — writes wrapper + bulk; covers every registered mesh and every allocated meshvar on each mesh. Lazy-allocated vars (_gvec is None) are skipped — same rule as the in-memory path. - `read_snapshot(model, path)` — loads var DOFs back into already-registered meshes by name. Mesh / variable mismatch raises a clear ValueError (mesh-rebuild on read is v1.2 scope). - `write_snapshot_skeleton` / `read_snapshot_metadata` / `inspect_snapshot` stay as phase-1 metadata-only entry points. Branch hygiene: merged origin/development (which now has #146) into this branch so the new code can actually call read_checkpoint. The merge was clean — #146 and the snapshot toolkit overlap only in `discretisation_mesh.py`, where they touch different methods, as the earlier analysis predicted. PR target will be development once #195/#196 land; the diff stays clean because the merged dev commits are already there. Tests (12 total, 5 new in phase 2, tier_a level_1): - write produces wrapper + bulk-dir with the expected file pattern - wrapper populated with the per-mesh + per-var metadata that makes inspectability self-sufficient - bit-exact write→scribble→read roundtrip on a 2D mesh with one scalar + one vector variable (np.array_equal, zero tolerance) - missing bulk-dir → clear FileNotFoundError - mismatched mesh on read → clear ValueError (not an obscure h5py trace) Regression: 64 tests pass (24 snapshot + 9 tracker + 12 disk-format + 19 core/regression). Phase 3 next: swarms (with the @external_file freedom kept open for bulky swarms) + /python_state for DDt + ModelTracker via dataclass-to-HDF5-attrs serialisation.
Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/checkpoint/__init__.py | 4 + src/underworld3/checkpoint/disk_snapshot.py | 195 +++++++++++++++++++- tests/test_0010_snapshot_disk_format.py | 156 +++++++++++++++- 3 files changed, 347 insertions(+), 8 deletions(-) diff --git a/src/underworld3/checkpoint/__init__.py b/src/underworld3/checkpoint/__init__.py index 31863784..ad17d5b9 100644 --- a/src/underworld3/checkpoint/__init__.py +++ b/src/underworld3/checkpoint/__init__.py @@ -30,7 +30,9 @@ from .disk_snapshot import ( DISK_SNAPSHOT_SCHEMA_VERSION, inspect_snapshot, + read_snapshot, read_snapshot_metadata, + write_snapshot, write_snapshot_skeleton, ) @@ -48,6 +50,8 @@ "TrackerState", "DISK_SNAPSHOT_SCHEMA_VERSION", "inspect_snapshot", + "read_snapshot", "read_snapshot_metadata", + "write_snapshot", "write_snapshot_skeleton", ] diff --git a/src/underworld3/checkpoint/disk_snapshot.py b/src/underworld3/checkpoint/disk_snapshot.py index 2d178883..f88c133a 100644 --- a/src/underworld3/checkpoint/disk_snapshot.py +++ b/src/underworld3/checkpoint/disk_snapshot.py @@ -36,6 +36,7 @@ import datetime import json +import os from typing import Any, Optional import numpy as np @@ -46,16 +47,17 @@ DISK_SNAPSHOT_SCHEMA_VERSION = 1 # Top-level group names — fixed; renaming would be a schema-version bump. +# Variables are NOT a top-level group; they nest under each mesh: +# /meshes/{name}/variables/{var}. Swarms similarly carry their own +# variables when phase 3 lands. _GROUP_METADATA = "metadata" -_GROUP_MESH = "mesh" -_GROUP_VARIABLES = "variables" +_GROUP_MESHES = "meshes" _GROUP_SWARMS = "swarms" _GROUP_PYTHON_STATE = "python_state" _TOP_LEVEL_GROUPS = ( _GROUP_METADATA, - _GROUP_MESH, - _GROUP_VARIABLES, + _GROUP_MESHES, _GROUP_SWARMS, _GROUP_PYTHON_STATE, ) @@ -180,8 +182,7 @@ def write_snapshot_skeleton(model, path: str) -> str: # see the file's intended shape from day one — phases 2/3 # populate them. 
for name in ( - _GROUP_MESH, - _GROUP_VARIABLES, + _GROUP_MESHES, _GROUP_SWARMS, _GROUP_PYTHON_STATE, ): @@ -237,6 +238,188 @@ def read_snapshot_metadata(path: str) -> dict: return md +# ----- Phase 2: mesh + meshvar bulk via #146's PETSc primitives ----- +# +# Layout convention: +# +# /path/to/run.snap.h5 wrapper (metadata, h5py-readable) +# /path/to/run.snap.bulk/ companion directory (one per snapshot) +# {mesh_safe}.mesh.00000.h5 mesh DM dump (PETSc HDF5) +# {mesh_safe}.{var_clean}.00000.h5 per-variable section + vec (PETSc HDF5) +# ... one set per (mesh, var) ... +# +# The bulk-dir path is derived from the wrapper path by convention, so a +# user opening just the wrapper file can find the bulk. They are a unit +# for portability — move them together. + + +def _bulk_dir_for(wrapper_path: str) -> str: + """Convention: wrapper at `run.snap.h5` ⇒ bulk at `run.snap.bulk/`.""" + base = wrapper_path[:-3] if wrapper_path.endswith(".h5") else wrapper_path + return base + ".bulk" + + +def _sanitise(name: str) -> str: + """Sanitise a mesh / variable name for use as a filename component. + + Replaces anything that isn't alphanumeric or in ``._-`` with ``_``. + Falls back to ``unnamed`` if the result is empty. The original name + is preserved in HDF5 group attrs as the ``@name`` field. + """ + safe = "".join(c if c.isalnum() or c in "._-" else "_" for c in name) + return safe or "unnamed" + + +def write_snapshot(model, path: str) -> str: + """Write a complete on-disk snapshot of the model's mesh + mesh-variable + state (phase 2 scope; swarms and python_state land in phase 3). + + Produces two artifacts: + + - ``path`` — the wrapper HDF5 file with rich metadata and the group + structure inspectable via ``h5ls``. + - ``_bulk_dir_for(path)`` — companion directory containing the + PETSc HDF5 files (mesh DM + per-variable section/vec) produced + by #146's :meth:`Mesh.write_checkpoint`. + + The two are a unit; move them together. Returns the wrapper path. 
+ """ + import h5py + + # Phase-1 layer: metadata + skeleton groups. + write_snapshot_skeleton(model, path) + bulk_dir = _bulk_dir_for(path) + + # rank-0 creates the bulk directory; collective ops below need it + # to exist on the rank doing the PETSc-HDF5 write (which is rank 0 + # in this single-file write — actually PETSc's HDF5 viewer is + # collective, so all ranks participate). + with uw.selective_ranks(0) as rank0: + if rank0: + os.makedirs(bulk_dir, exist_ok=True) + uw.mpi.barrier() + + # For each registered mesh, drive #146's write_checkpoint into the + # bulk directory. write_checkpoint is collective (PETSc HDF5 + # viewer), so all ranks must participate. + mesh_records: list[dict] = [] + for mesh in list(model._meshes.values()): + mesh_safe = _sanitise(mesh.name) + mesh_vars = list(mesh.vars.values()) + # Filter to allocated variables — same skip rule as the in-memory + # path: lazy-allocated vars with _gvec == None have no data. + mesh_vars = [v for v in mesh_vars if v._gvec is not None] + + mesh.write_checkpoint( + mesh_safe, + outputPath=bulk_dir, + meshVars=mesh_vars, + index=0, + ) + + mesh_records.append({ + "name": mesh.name, + "safe_name": mesh_safe, + "mesh_file": f"{mesh_safe}.mesh.00000.h5", + "vars": [ + { + "name": v.clean_name, + "components": int(v.num_components), + "degree": int(v.degree), + "continuous": bool(v.continuous), + # Per-variable file produced by Mesh.write_checkpoint + # at outputPath: "{base}.{var.clean_name}.{index:05}.h5". + "external_file": ( + f"{mesh_safe}.{v.clean_name}.00000.h5" + ), + } + for v in mesh_vars + ], + }) + + # Reopen the wrapper to populate /meshes with the per-mesh records + # and to mark the groups filled. 
+ with uw.selective_ranks(0) as rank0: + if rank0: + with h5py.File(path, "a") as f: + meshes_group = f[_GROUP_MESHES] + meshes_group.attrs["filled_by"] = "phase2" + meshes_group.attrs["bulk_dir"] = os.path.basename(bulk_dir) + + for rec in mesh_records: + g = meshes_group.create_group(rec["safe_name"]) + g.attrs["name"] = rec["name"] + g.attrs["mesh_file"] = rec["mesh_file"] + + vars_g = g.create_group("variables") + for var_rec in rec["vars"]: + v = vars_g.create_group(_sanitise(var_rec["name"])) + v.attrs["name"] = var_rec["name"] + v.attrs["components"] = var_rec["components"] + v.attrs["degree"] = var_rec["degree"] + v.attrs["continuous"] = var_rec["continuous"] + v.attrs["external_file"] = var_rec["external_file"] + + uw.mpi.barrier() + return path + + +def read_snapshot(model, path: str) -> None: + """Load mesh-variable DOFs from an on-disk snapshot into the model. + + The model must already have the same meshes (by name) and the + same variables (by ``clean_name``) registered — this is the + same-rank-count restart path that mirrors :func:`restore` for the + in-memory snapshot. Cross-run / rebuild-on-load is v1.2 scope. + + Bulk data is read via #146's :meth:`MeshVariable.read_checkpoint`; + no KDTree remapping (that's phase 4's compatibility layer in + ``read_timestep``). 
+ """ + import h5py + + md = read_snapshot_metadata(path) + bulk_dir = _bulk_dir_for(path) + if not os.path.isdir(bulk_dir): + raise FileNotFoundError( + f"snapshot bulk directory missing: {bulk_dir} (expected next " + f"to wrapper {path})" + ) + + # Build {original_name -> registered Mesh} for lookup + meshes_by_name = {m.name: m for m in model._meshes.values()} + + with h5py.File(path, "r") as f: + meshes_group = f[_GROUP_MESHES] + for mesh_safe in meshes_group.keys(): + g = meshes_group[mesh_safe] + mesh_name = str(g.attrs.get("name", mesh_safe)) + mesh = meshes_by_name.get(mesh_name) + if mesh is None: + raise ValueError( + f"snapshot at {path} contains mesh {mesh_name!r} which " + f"is not registered on this model " + f"(registered: {sorted(meshes_by_name.keys())})" + ) + + current_vars = {v.clean_name: v for v in mesh.vars.values()} + vars_g = g["variables"] + for var_safe in vars_g.keys(): + v_attrs = vars_g[var_safe].attrs + var_name = str(v_attrs["name"]) + external_file = str(v_attrs["external_file"]) + var = current_vars.get(var_name) + if var is None: + raise ValueError( + f"snapshot variable {var_name!r} not registered on " + f"mesh {mesh_name!r}" + ) + var.read_checkpoint( + os.path.join(bulk_dir, external_file), + data_name=var_name, + ) + + def inspect_snapshot(path: str) -> str: """Human-readable one-shot summary of a snapshot file's metadata. diff --git a/tests/test_0010_snapshot_disk_format.py b/tests/test_0010_snapshot_disk_format.py index a5aa1961..026b7e75 100644 --- a/tests/test_0010_snapshot_disk_format.py +++ b/tests/test_0010_snapshot_disk_format.py @@ -45,7 +45,9 @@ def test_skeleton_writes_expected_group_structure(tmp_path): with h5py.File(path, "r") as f: top = set(f.keys()) - assert top == {"metadata", "mesh", "variables", "swarms", "python_state"} + # Variables nest under each mesh (/meshes/{name}/variables/...) so + # they are not a top-level group. 
+ assert top == {"metadata", "meshes", "swarms", "python_state"} def test_metadata_is_inspectable_without_uw3(tmp_path): @@ -165,6 +167,156 @@ def test_skeleton_groups_have_filled_by_marker(tmp_path): uw.checkpoint.write_snapshot_skeleton(model, path) with h5py.File(path, "r") as f: - for name in ("mesh", "variables", "swarms", "python_state"): + for name in ("meshes", "swarms", "python_state"): assert "filled_by" in f[name].attrs assert str(f[name].attrs["filled_by"]) == "" + + +# ----- Phase 2: mesh + mesh-variable bulk via #146 ----- + + +def _fresh_model_mesh_and_vars(): + import underworld3 as uw + + uw.reset_default_model() + uw.use_strict_units(False) + uw.use_nondimensional_scaling(False) + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + V = uw.discretisation.MeshVariable("V", mesh, 2, degree=2) + return uw, model, mesh, T, V + + +def test_write_snapshot_produces_wrapper_and_bulk_dir(tmp_path): + """The two artifacts the convention promises: wrapper file + + sibling .bulk/ directory containing PETSc HDF5 files.""" + import os + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + T.array[:, 0, 0] = 5.0 + V.array[:, 0, 0] = -3.0 + + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, path) + + bulk = str(tmp_path / "run.snap.bulk") + assert os.path.exists(path) + assert os.path.isdir(bulk) + # #146-format files in the bulk dir. + files = sorted(os.listdir(bulk)) + # At least: mesh file + one file per variable. + assert any(f.endswith(".mesh.00000.h5") for f in files) + assert any("T.00000.h5" in f for f in files) + assert any("V.00000.h5" in f for f in files) + + +def test_write_snapshot_populates_wrapper_layout(tmp_path): + """The wrapper carries the per-mesh + per-variable metadata that + makes 'what's in this snapshot?' 
answerable from h5py alone.""" + import h5py + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + T.array[:, 0, 0] = 1.0 + V.array[:, 0, 0] = 2.0 + V.array[:, 0, 1] = 3.0 + + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, path) + + with h5py.File(path, "r") as f: + assert f["meshes"].attrs["filled_by"] == "phase2" + # One mesh subgroup + mesh_names = list(f["meshes"].keys()) + assert len(mesh_names) == 1 + mg = f["meshes"][mesh_names[0]] + # Per-mesh attrs + assert mg.attrs["name"] == mesh.name + assert mg.attrs["mesh_file"].endswith(".mesh.00000.h5") + # Variables subgroup + var_names = sorted(mg["variables"].keys()) + assert var_names == ["T", "V"] + # Per-var attrs include shape info + external_file pointer. + v_attrs = mg["variables"]["V"].attrs + assert v_attrs["components"] == 2 + assert v_attrs["degree"] == 2 + assert v_attrs["external_file"].endswith("V.00000.h5") + + +def test_write_read_snapshot_bit_exact_roundtrip(tmp_path): + """The core phase-2 guarantee: write a snapshot, scribble all + variables, read snapshot back, all variables match write-time + values bit-for-bit (#146's PETSc DMPlex same-rank reload, just + delivered via the wrapper).""" + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + T.array[:, 0, 0] = 5.0 * T.coords[:, 0] - 2.0 + V.array[:, 0, 0] = 3.0 * V.coords[:, 0] + V.array[:, 0, 1] = 7.0 * V.coords[:, 1] + T_pre = np.asarray(T.array[...]).copy() + V_pre = np.asarray(V.array[...]).copy() + + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, path) + + # Scribble. + T.array[...] = -99.0 + V.array[...] 
= -99.0 + + uw.checkpoint.read_snapshot(model, path) + + assert np.array_equal(np.asarray(T.array[...]), T_pre), ( + f"T not bit-exact after read_snapshot — max|d|=" + f"{float(np.max(np.abs(np.asarray(T.array[...]) - T_pre))):.3e}" + ) + assert np.array_equal(np.asarray(V.array[...]), V_pre), ( + f"V not bit-exact after read_snapshot — max|d|=" + f"{float(np.max(np.abs(np.asarray(V.array[...]) - V_pre))):.3e}" + ) + + +def test_read_snapshot_rejects_missing_bulk_dir(tmp_path): + """If the user moves the wrapper without the bulk dir, read fails + with a clear pointer rather than an obscure h5py error.""" + import os + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, path) + + # Delete the bulk dir to simulate the move-the-wrapper-only mistake. + import shutil + shutil.rmtree(str(tmp_path / "run.snap.bulk")) + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + with pytest.raises(FileNotFoundError, match="bulk directory missing"): + uw.checkpoint.read_snapshot(model, path) + + +def test_read_snapshot_rejects_mismatched_mesh(tmp_path): + """If the target model's meshes don't match the snapshot's, raise + clearly — mesh-rebuild on read is v1.2 scope.""" + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, path) + + # Fresh model with a *different* mesh — write_snapshot's mesh.name + # won't match. + uw.reset_default_model() + model2 = uw.get_default_model() + other = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(2.0, 2.0), cellSize=1.0 / 3.0, + ) + # Don't reuse the original mesh's name — make the lookup miss. 
+ other.name = "definitely_a_different_mesh" + + with pytest.raises(ValueError, match="not registered on this model"): + uw.checkpoint.read_snapshot(model2, path) From d3ff91d19320ff599c9725a089628be6d516fbce Mon Sep 17 00:00:00 2001 From: lmoresi Date: Wed, 20 May 2026 16:39:48 +1000 Subject: [PATCH 3/7] =?UTF-8?q?feat(snapshot=20v1.1,=20phase=203a):=20/pyt?= =?UTF-8?q?hon=5Fstate=20=E2=80=94=20state-bearer=20round-trip?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Serialises every registered Snapshottable's .state dataclass into a per-bearer group under /python_state, keyed by the same stable name the in-memory snapshot uses (f"{type(obj).__name__}_{obj.instance_number}"). ModelTracker (always auto-registered) and DDt state therefore now travel with the disk snapshot in addition to the mesh + meshvar bulk from phase 2. Generic field serialisation (no per-class code): - None -> attr "__none__" sentinel - bool/int/float/str-> scalar attr (preserves type via h5py) - numpy.ndarray -> dataset - list/tuple -> attr {name}__json (JSON, handles None) - dict -> subgroup, recursive (used by TrackerState.managed) - unhandleable -> attr {name}__skipped = "{reason}" — restore keeps the *current* live value rather than clobbering it with a placeholder, so a documented partial round-trip (e.g. DDtSymbolicState.psi_star which is sympy and would need srepr+sympify) doesn't break. Restore uses the live obj.state as a type template + dataclasses. replace(...): captured fields override; skipped fields keep their current value. ValueError on state-bearer-not-registered keeps the same-rank/same-model contract.
Tests (4 new, 16 total tier_a level_1): - tracker time/step/dt + user-added quantity (scalar + numpy array) round-trip exactly through disk - /python_state group is h5py-inspectable: __bearer_class__, __state_class__, instance_number; TrackerState.managed visible as a subgroup with each managed key as an attr (so h5ls shows 'time', 'step', 'dt', 'my_q' directly) - Symbolic DDt's primary BDF-control fields (dt_history, history_initialised, n_solves_completed, dt) round-trip; psi_star (sympy) is documented as skipped — restore keeps current value - mismatched state-bearer set on read raises clearly Phase 3b next: swarms in a per-swarm sidecar from day one (per Louis's "break out swarms" direction — bulk is always a swarm problem, so don't even try inline). Regression: 68 tests pass (24 in-memory + 9 tracker + 16 disk-format + 19 core/regression). Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/checkpoint/disk_snapshot.py | 193 ++++++++++++++++++++ tests/test_0010_snapshot_disk_format.py | 126 +++++++++++++ 2 files changed, 319 insertions(+) diff --git a/src/underworld3/checkpoint/disk_snapshot.py b/src/underworld3/checkpoint/disk_snapshot.py index f88c133a..60494b6a 100644 --- a/src/underworld3/checkpoint/disk_snapshot.py +++ b/src/underworld3/checkpoint/disk_snapshot.py @@ -34,6 +34,7 @@ from __future__ import annotations +import dataclasses import datetime import json import os @@ -360,6 +361,16 @@ def write_snapshot(model, path: str) -> str: v.attrs["continuous"] = var_rec["continuous"] v.attrs["external_file"] = var_rec["external_file"] + # Phase 3a: state-bearer dataclass serialisation. 
+ ps_group = f[_GROUP_PYTHON_STATE] + ps_group.attrs["filled_by"] = "phase3a" + for obj in list(model._state_bearers): + key = f"{type(obj).__name__}_{obj.instance_number}" + if key in ps_group: + continue # idempotent if write_snapshot called twice + bg = ps_group.create_group(key) + _write_state_bearer_to_group(bg, obj) + uw.mpi.barrier() return path @@ -419,6 +430,188 @@ def read_snapshot(model, path: str) -> None: data_name=var_name, ) + # Phase 3a: restore state-bearer dataclasses. + if _GROUP_PYTHON_STATE in f: + ps_group = f[_GROUP_PYTHON_STATE] + bearers_by_key = { + f"{type(o).__name__}_{o.instance_number}": o + for o in list(model._state_bearers) + } + for key in ps_group.keys(): + obj = bearers_by_key.get(key) + if obj is None: + raise ValueError( + f"snapshot at {path} contains state-bearer {key!r} " + f"that is not registered on this model" + ) + _read_state_bearer_into(ps_group[key], obj) + + +# ----- Phase 3a: state-bearer (Snapshottable) serialisation ---------------- +# +# Each registered state-bearer (DDt instances, ModelTracker, future +# helpers) exposes a `.state` property returning a SnapshottableState +# dataclass. We serialise each dataclass's fields into a per-bearer +# HDF5 group under /python_state, keyed by the same stable name the +# in-memory snapshot uses: f"{type(obj).__name__}_{obj.instance_number}". +# +# Serialisation is *generic over dataclass fields* — no per-class +# special code. Handled value types: None, bool, int, float, str, +# numpy.ndarray, list/tuple (JSON-encoded), dict (recursive subgroup). Other +# types (notably sympy expressions in DDtSymbolicState.psi_star) are +# marked with `{field}__skipped` and not round-tripped — documented as +# a v1.x limitation; consumers either use a non-Symbolic DDt flavor +# or accept the psi_star reset.
+ + +_NULL_SENTINEL = "__none__" +_TYPE_ATTR = "__bearer_class__" +_DATACLASS_ATTR = "__state_class__" + + +def _is_h5_attr_scalar(value: Any) -> bool: + return isinstance(value, (bool, int, float, str)) and not isinstance( + value, bool + ) or isinstance(value, (bool, str)) + + +def _serialise_field(h5group, name: str, value: Any) -> None: + """Write a Python value into an HDF5 group as attr/dataset/subgroup. + + The shape of the storage records the type: + - attr scalar (int/float/bool/str) for scalars + - attr `{name}` = '__none__' for None + - attr `{name}__json` for JSON-encodable lists / nested simple structures + - dataset `{name}` for numpy arrays + - subgroup `{name}` for dict values, recursing + - attr `{name}__skipped` = '{reason}' for anything else + """ + if value is None: + h5group.attrs[name] = _NULL_SENTINEL + return + if isinstance(value, (bool, int, float)): + h5group.attrs[name] = value + return + if isinstance(value, str): + h5group.attrs[name] = value + return + if isinstance(value, np.ndarray): + if name in h5group: + del h5group[name] + h5group.create_dataset(name, data=value) + return + if isinstance(value, dict): + if name in h5group: + del h5group[name] + sub = h5group.create_group(name) + for k, v in value.items(): + _serialise_field(sub, str(k), v) + return + if isinstance(value, (list, tuple)): + try: + h5group.attrs[name + "__json"] = json.dumps(list(value)) + return + except (TypeError, ValueError): + h5group.attrs[name + "__skipped"] = ( + f"unserialisable list (len={len(value)}, " + f"first-type={type(value[0]).__name__ if value else 'empty'})" + ) + return + h5group.attrs[name + "__skipped"] = ( + f"unserialisable type {type(value).__name__}" + ) + + +def _group_to_dict(h5group) -> dict: + """Read a subgroup back as a plain dict — symmetric to the dict + branch of :func:`_serialise_field`.
Recurses for nested groups.""" + import h5py + + out: dict = {} + for k in h5group.attrs.keys(): + if k.endswith("__skipped"): + continue + if k.endswith("__json"): + out[k[: -len("__json")]] = json.loads(h5group.attrs[k]) + continue + v = h5group.attrs[k] + if isinstance(v, str) and v == _NULL_SENTINEL: + out[k] = None + elif isinstance(v, np.generic): + out[k] = v.item() + else: + out[k] = v + for k in h5group.keys(): + item = h5group[k] + if isinstance(item, h5py.Group): + out[k] = _group_to_dict(item) + else: + out[k] = np.asarray(item[...]) + return out + + +def _deserialise_field(h5group, name: str, fallback: Any) -> Any: + """Inverse of :func:`_serialise_field`. Returns ``fallback`` if the + field was skipped at write time, so we don't clobber a sensible + default with a placeholder.""" + import h5py + + if name in h5group: + item = h5group[name] + if isinstance(item, h5py.Group): + return _group_to_dict(item) + return np.asarray(item[...]) # h5py.Dataset + + if name in h5group.attrs: + v = h5group.attrs[name] + if isinstance(v, str) and v == _NULL_SENTINEL: + return None + if isinstance(v, np.generic): + return v.item() + return v + + if (name + "__json") in h5group.attrs: + return json.loads(h5group.attrs[name + "__json"]) + + if (name + "__skipped") in h5group.attrs: + # Skipped at write time — keep the current value rather than + # clobber it with a placeholder. + return fallback + + return fallback + + +def _write_state_bearer_to_group(group, obj) -> None: + """Serialise a Snapshottable's .state into the given HDF5 group.""" + state = obj.state + group.attrs[_TYPE_ATTR] = type(obj).__name__ + group.attrs[_DATACLASS_ATTR] = type(state).__name__ + group.attrs["instance_number"] = int(obj.instance_number) + + for f in dataclasses.fields(state): + _serialise_field(group, f.name, getattr(state, f.name)) + + +def _read_state_bearer_into(group, obj) -> None: + """Restore a Snapshottable's .state from a group written by + :func:`_write_state_bearer_to_group`. 
Uses the live ``obj.state`` + as a type template — fields that were skipped at write time keep + their current value rather than being clobbered by a placeholder. + """ + current_state = obj.state + captured_class = str(group.attrs.get(_DATACLASS_ATTR, "")) + if captured_class and captured_class != type(current_state).__name__: + raise ValueError( + f"state-bearer class mismatch: snapshot expects " + f"{captured_class}, current is {type(current_state).__name__}" + ) + + overrides = {} + for f in dataclasses.fields(current_state): + new_val = _deserialise_field(group, f.name, getattr(current_state, f.name)) + overrides[f.name] = new_val + obj.state = dataclasses.replace(current_state, **overrides) + def inspect_snapshot(path: str) -> str: """Human-readable one-shot summary of a snapshot file's metadata. diff --git a/tests/test_0010_snapshot_disk_format.py b/tests/test_0010_snapshot_disk_format.py index 026b7e75..b8514b45 100644 --- a/tests/test_0010_snapshot_disk_format.py +++ b/tests/test_0010_snapshot_disk_format.py @@ -320,3 +320,129 @@ def test_read_snapshot_rejects_mismatched_mesh(tmp_path): with pytest.raises(ValueError, match="not registered on this model"): uw.checkpoint.read_snapshot(model2, path) + + +# ----- Phase 3a: state-bearer (Snapshottable) serialisation ----- + + +def test_tracker_round_trips_through_disk_snapshot(tmp_path): + """ModelTracker is always auto-registered as a state-bearer, so + every snapshot must round-trip its time/step/dt and any + user-added managed quantities exactly.""" + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + model.tracker.time = 3.14 + model.tracker.step = 42 + model.tracker.dt = 0.05 + model.tracker.my_diagnostic = 99.0 + model.tracker.history_arr = np.array([1.0, 2.0, 3.0]) + + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, path) + + # Scribble everything tracker-side. 
+ model.tracker.time = -1.0 + model.tracker.step = -1 + model.tracker.dt = -1.0 + model.tracker.my_diagnostic = -1.0 + model.tracker.history_arr = np.array([-1.0, -1.0, -1.0]) + + uw.checkpoint.read_snapshot(model, path) + + assert model.tracker.time == 3.14 + assert model.tracker.step == 42 + assert model.tracker.dt == 0.05 + assert model.tracker.my_diagnostic == 99.0 + assert np.array_equal( + np.asarray(model.tracker.history_arr), np.array([1.0, 2.0, 3.0]) + ) + + +def test_python_state_group_is_inspectable(tmp_path): + """An external h5py reader sees the per-bearer groups under + /python_state with class info + a managed dict for the tracker.""" + import h5py + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + model.tracker.time = 1.0 + model.tracker.step = 2 + model.tracker.my_q = 7.0 + + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, path) + + with h5py.File(path, "r") as f: + ps = f["python_state"] + assert ps.attrs["filled_by"] == "phase3a" + + tracker_keys = [k for k in ps.keys() if k.startswith("ModelTracker_")] + assert len(tracker_keys) == 1 + tg = ps[tracker_keys[0]] + assert tg.attrs["__bearer_class__"] == "ModelTracker" + assert tg.attrs["__state_class__"] == "TrackerState" + # TrackerState.managed is a dict, stored as a sub-group. + assert "managed" in tg and isinstance(tg["managed"], h5py.Group) + managed = tg["managed"] + # Pre-seeded conventions present. + assert "time" in managed.attrs + assert float(managed.attrs["time"]) == 1.0 + assert int(managed.attrs["step"]) == 2 + # User-added quantity present. + assert "my_q" in managed.attrs + assert float(managed.attrs["my_q"]) == 7.0 + + +def test_ddt_symbolic_state_round_trips_primary_fields(tmp_path): + """A Symbolic DDt has dt_history, history_initialised, + n_solves_completed, dt round-tripped via the generic dataclass + serialiser. 
psi_star (sympy) is documented as skipped — the + primary BDF-control fields are what matter for re-continuing.""" + import underworld3 as uw + import sympy + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + ddt = uw.systems.ddt.Symbolic(T.sym, order=2) + ddt._dt_history = [0.05, 0.03] + ddt._history_initialised = True + ddt._n_solves_completed = 2 + ddt._dt = 0.05 + + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, path) + + # Scribble the primary fields. + ddt._dt_history = [None, None] + ddt._history_initialised = False + ddt._n_solves_completed = 0 + ddt._dt = None + + uw.checkpoint.read_snapshot(model, path) + + assert ddt.state.dt_history == [0.05, 0.03] + assert ddt.state.history_initialised is True + assert ddt.state.n_solves_completed == 2 + assert ddt.state.dt == 0.05 + + +def test_read_snapshot_rejects_missing_state_bearer(tmp_path): + """If a state-bearer exists in the snapshot but not on the load- + target model, raise — same-rank/same-model contract.""" + import underworld3 as uw + + # Source model has a Symbolic DDt; snapshot it. + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + ddt = uw.systems.ddt.Symbolic(T.sym, order=2) + ddt._dt_history = [0.05, 0.05] + + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, path) + + # Target model has no DDt. + uw, model2, mesh2, T2, V2 = _fresh_model_mesh_and_vars() + # Force name match so the mesh part loads. 
+ mesh2.name = mesh.name + + with pytest.raises(ValueError, match="state-bearer .* not registered"): + uw.checkpoint.read_snapshot(model2, path) From 94f3b7027d9619fcf5fd7623b3f41f33bac6545f Mon Sep 17 00:00:00 2001 From: lmoresi Date: Wed, 20 May 2026 16:44:18 +1000 Subject: [PATCH 4/7] feat(snapshot v1.1, phase 3b): swarms in per-swarm sidecars MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per Louis's direction ("break out the swarm information into a separate file in the first instance — bulk is a problem with swarms, always"), swarms always go to their own h5py-direct sidecar from day one. No inline-vs-split toggle — sidecar is the only path. Layout: /path/to/run.snap.h5 wrapper /path/to/run.snap.bulk/{swarm_safe}.swarm.h5 swarm sidecar (one per swarm) Sidecar structure (h5py-native, no PETSc — swarms aren't DMPlex section/vec): @num_particles_local, @dim, @mesh_name, @population_generation /coordinates dataset, (n_local, dim) /variables/{var_clean_name} dataset, (n_local, num_components) @num_components, @dtype The sidecar's top-level @attrs and group structure mean `h5ls -v` on the sidecar alone tells you "this holds N particles in dim D on mesh M with these variables" — same inspectability bar as the wrapper. Wrapper /swarms/{swarm_safe}/ carries metadata + the @external_file pointer to the sidecar in the bulk dir. Restore mirrors the in-memory Swarm.apply_snapshot_payload exactly: clear local population via dm.removePoint loop, addNPoints at saved coords, write var data back. Same rebuild-on-restore semantics — the disk snapshot recovers from a particle-population mutation (added particles between snapshot and restore) just like the in-memory path does, proven by test_swarm_restore_recovers_after_particle_count_change. 
Tests (5 new, 21 total tier_a level_1): - swarm sidecar lands in bulk dir with predictable name; wrapper records external_file ref + mesh_name + var inventory - sidecar is self-inspectable via h5py (file-level attrs + /coordinates + /variables with per-var attrs) - whole swarm (coords + svar data) round-trips bit-exact through write → scribble → read - rebuild-on-restore parity with in-memory path: snapshot, mutate population, restore → exact local population recovered - PETSc-internal DMSwarm_* variables filtered at capture (same rule as in-memory) MPI: single-rank only in this phase. The current rank-0-only sidecar write only captures rank 0's local particles in a parallel run. Phase 6 will either use h5py-mpi parallel HDF5 or per-rank sidecars to match #195's parallel exact-reconstruction guarantee. 73 tests pass (24 in-memory + 9 tracker + 21 disk-format + 19 core/regression). Phase 4 next: format detection + dispatch in MeshVariable.read_timestep so it reads BOTH the legacy per-variable layout AND the new v1.1 sidecar format via the KDTree bridge. Closes the compatibility commitment from the design discussion. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/checkpoint/disk_snapshot.py | 202 ++++++++++++++++++++ tests/test_0010_snapshot_disk_format.py | 161 ++++++++++++++++ 2 files changed, 363 insertions(+) diff --git a/src/underworld3/checkpoint/disk_snapshot.py b/src/underworld3/checkpoint/disk_snapshot.py index 60494b6a..622611ba 100644 --- a/src/underworld3/checkpoint/disk_snapshot.py +++ b/src/underworld3/checkpoint/disk_snapshot.py @@ -371,6 +371,46 @@ def write_snapshot(model, path: str) -> str: bg = ps_group.create_group(key) _write_state_bearer_to_group(bg, obj) + # Phase 3b: swarms — one sidecar file per swarm in the bulk dir, + # referenced from /swarms/{swarm_safe}/ in the wrapper. 
+ swarm_records: list[dict] = [] + for swarm in list(model._swarms.values()): + swarm_safe = _swarm_safe_name(swarm) + sidecar_name = _swarm_sidecar_filename(swarm_safe) + sidecar_path = os.path.join(bulk_dir, sidecar_name) + # h5py-direct write (single-rank in this phase; MPI is phase 6). + with uw.selective_ranks(0) as rank0: + if rank0: + rec = _write_swarm_to_sidecar(swarm, sidecar_path) + rec["safe_name"] = swarm_safe + rec["external_file"] = sidecar_name + swarm_records.append(rec) + uw.mpi.barrier() + + if swarm_records: + with uw.selective_ranks(0) as rank0: + if rank0: + with h5py.File(path, "a") as f: + sw_group = f[_GROUP_SWARMS] + sw_group.attrs["filled_by"] = "phase3b" + for rec in swarm_records: + g = sw_group.create_group(rec["safe_name"]) + g.attrs["mesh_name"] = rec["mesh_name"] + g.attrs["num_particles_local"] = rec[ + "num_particles_local" + ] + g.attrs["population_generation"] = rec[ + "population_generation" + ] + g.attrs["external_file"] = rec["external_file"] + vars_g = g.create_group("variables") + for var_rec in rec["vars"]: + v = vars_g.create_group(_sanitise(var_rec["name"])) + v.attrs["name"] = var_rec["name"] + v.attrs["num_components"] = var_rec[ + "num_components" + ] + uw.mpi.barrier() return path @@ -446,6 +486,168 @@ def read_snapshot(model, path: str) -> None: ) _read_state_bearer_into(ps_group[key], obj) + # Phase 3b: restore swarms from sidecars. 
+ if _GROUP_SWARMS in f: + sw_group = f[_GROUP_SWARMS] + swarms_by_safe = { + _swarm_safe_name(s): s for s in list(model._swarms.values()) + } + for swarm_safe in sw_group.keys(): + g = sw_group[swarm_safe] + swarm = swarms_by_safe.get(swarm_safe) + if swarm is None: + raise ValueError( + f"snapshot at {path} contains swarm {swarm_safe!r} " + f"that is not registered on this model" + ) + external_file = str(g.attrs["external_file"]) + _read_swarm_from_sidecar( + swarm, os.path.join(bulk_dir, external_file) + ) + + +# ----- Phase 3b: swarm sidecars -------------------------------------------- +# +# Swarms always go to their own per-swarm sidecar file from day one, +# per Louis's "bulk is a problem with swarms, always" — no inline-vs- +# split toggle. The wrapper's /swarms/{swarm_safe}/ records metadata +# + an `@external_file` ref pointing at the sidecar in the bulk dir. +# +# The sidecar is h5py-direct (no PETSc). Swarms aren't DMPlex +# section/vec; they're per-particle numpy arrays. +# +# Single-rank now; MPI gets phase 6 treatment (per-rank sidecars or +# parallel-HDF5). + + +def _swarm_safe_name(swarm) -> str: + """Stable name for a swarm in the snapshot layout. Mirrors the + in-memory snapshot's `_snapshot_stable_name`.""" + raw = getattr(swarm, "name", None) or f"swarm_{swarm.instance_number}" + return _sanitise(raw) + + +def _swarm_sidecar_filename(swarm_safe: str) -> str: + return f"{swarm_safe}.swarm.h5" + + +def _write_swarm_to_sidecar(swarm, sidecar_path: str) -> dict: + """Write a swarm's local-particle state to a sidecar h5 file. + + Returns a record dict the caller uses to populate the wrapper's + /swarms/{name}/ group. 
+ """ + import h5py + + coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape( + (-1, swarm.dim) + ) + coords = np.asarray(coord_field).copy() + swarm.dm.restoreField("DMSwarmPIC_coor") + + var_records: list[dict] = [] + with h5py.File(sidecar_path, "w") as f: + # File-level metadata — h5ls -v on the sidecar tells you what + # it is without needing UW3 (matches the wrapper's bar). + f.attrs["num_particles_local"] = int(coords.shape[0]) + f.attrs["dim"] = int(swarm.dim) + f.attrs["mesh_name"] = str(swarm.mesh.name) + f.attrs["population_generation"] = int(swarm._population_generation) + + f.create_dataset("coordinates", data=coords) + + vars_g = f.create_group("variables") + for var in list(swarm._vars.values()): + # Filter PETSc-internal variables — same rule as the + # in-memory swarm capture. + if var.name.startswith("DMSwarm"): + continue + data = np.asarray(var.data).copy() + d = vars_g.create_dataset(var.clean_name, data=data) + d.attrs["num_components"] = int(var.num_components) + d.attrs["dtype"] = str(data.dtype) + var_records.append({ + "name": var.clean_name, + "num_components": int(var.num_components), + }) + + return { + "num_particles_local": int(coords.shape[0]), + "mesh_name": str(swarm.mesh.name), + "population_generation": int(swarm._population_generation), + "vars": var_records, + } + + +def _read_swarm_from_sidecar(swarm, sidecar_path: str) -> None: + """Restore swarm state from a sidecar file. 
Mirrors + :meth:`Swarm.apply_snapshot_payload` exactly — clear local + particles, re-add at saved coords, write var data back.""" + import h5py + + with h5py.File(sidecar_path, "r") as f: + saved_coords = np.asarray(f["coordinates"][...]) + captured_mesh_name = str(f.attrs.get("mesh_name", "")) + var_data: dict[str, np.ndarray] = {} + if "variables" in f: + for name in f["variables"].keys(): + var_data[name] = np.asarray(f["variables"][name][...]) + + if captured_mesh_name and captured_mesh_name != swarm.mesh.name: + raise ValueError( + f"sidecar at {sidecar_path}: parent mesh was " + f"{captured_mesh_name!r}, target swarm is on {swarm.mesh.name!r}" + ) + + # Clear local population. removePoint is O(1) per call (last point), + # so this is O(N) total — same approach as Swarm.apply_snapshot_payload. + while swarm.dm.getLocalSize() > 0: + swarm.dm.removePoint() + + n_saved = int(saved_coords.shape[0]) + if n_saved > 0: + swarm.dm.finalizeFieldRegister() + swarm.dm.addNPoints(npoints=n_saved) + + coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape( + (-1, swarm.dim) + ) + coord_field[...] = saved_coords + swarm.dm.restoreField("DMSwarmPIC_coor") + + rank_field = swarm.dm.getField("DMSwarm_rank") + rank_field[...] = uw.mpi.rank + swarm.dm.restoreField("DMSwarm_rank") + + # Invalidate canonical-data caches so subsequent var.data reads + # re-resolve to the rebuilt PETSc fields. + if hasattr(swarm._particle_coordinates, "_canonical_data"): + swarm._particle_coordinates._canonical_data = None + for var in swarm._vars.values(): + if hasattr(var, "_canonical_data"): + var._canonical_data = None + + # Restore counted as a population change (matches in-memory path). + swarm._population_generation += 1 + + # Write per-variable captured data back into the freshly-rebuilt + # swarm. 
+ current_vars = {v.clean_name: v for v in swarm._vars.values()} + for var_name, saved in var_data.items(): + var = current_vars.get(var_name) + if var is None: + raise ValueError( + f"sidecar variable {var_name!r} is not present on the " + f"target swarm; restore requires the same variable set" + ) + current = np.asarray(var.data) + if current.shape != saved.shape: + raise ValueError( + f"swarm variable {var_name!r} shape mismatch on restore: " + f"sidecar {saved.shape} vs current {current.shape}" + ) + current[...] = saved + # ----- Phase 3a: state-bearer (Snapshottable) serialisation ---------------- # diff --git a/tests/test_0010_snapshot_disk_format.py b/tests/test_0010_snapshot_disk_format.py index b8514b45..89a28686 100644 --- a/tests/test_0010_snapshot_disk_format.py +++ b/tests/test_0010_snapshot_disk_format.py @@ -446,3 +446,164 @@ def test_read_snapshot_rejects_missing_state_bearer(tmp_path): with pytest.raises(ValueError, match="state-bearer .* not registered"): uw.checkpoint.read_snapshot(model2, path) + + +# ----- Phase 3b: swarms in sidecars ----- + + +def _fresh_model_mesh_swarm(): + import underworld3 as uw + + uw.reset_default_model() + uw.use_strict_units(False) + uw.use_nondimensional_scaling(False) + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + swarm = uw.swarm.Swarm(mesh) + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + material.data[:, 0] = swarm._particle_coordinates.data[:, 0] + return uw, model, mesh, swarm, material + + +def test_swarm_sidecar_lands_in_bulk_dir(tmp_path): + """A registered swarm produces its own h5 sidecar in the bulk dir + with predictable name; wrapper carries the external_file ref.""" + import os + import h5py + import underworld3 as uw + + uw, model, mesh, swarm, material = _fresh_model_mesh_swarm() + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, 
path) + + bulk = str(tmp_path / "run.snap.bulk") + files = sorted(os.listdir(bulk)) + swarm_sidecars = [f for f in files if f.endswith(".swarm.h5")] + assert len(swarm_sidecars) == 1 + + with h5py.File(path, "r") as f: + sw = f["swarms"] + assert sw.attrs["filled_by"] == "phase3b" + assert len(sw.keys()) == 1 + swarm_safe = list(sw.keys())[0] + sg = sw[swarm_safe] + assert sg.attrs["external_file"] == swarm_sidecars[0] + assert sg.attrs["mesh_name"] == mesh.name + # User-added variables surface in /swarms/{name}/variables/ + assert "material" in sg["variables"] + + +def test_swarm_sidecar_is_inspectable(tmp_path): + """The swarm sidecar itself has h5-inspectable structure: top- + level attrs, /coordinates dataset, /variables/{name} datasets + with attrs. h5ls -v on the sidecar tells you what it holds.""" + import os + import h5py + import underworld3 as uw + + uw, model, mesh, swarm, material = _fresh_model_mesh_swarm() + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, path) + + bulk = str(tmp_path / "run.snap.bulk") + swarm_sidecar = [ + f for f in os.listdir(bulk) if f.endswith(".swarm.h5") + ][0] + + with h5py.File(os.path.join(bulk, swarm_sidecar), "r") as f: + # File-level attrs identify the file without UW3 in the loop. 
+ assert int(f.attrs["num_particles_local"]) > 0 + assert int(f.attrs["dim"]) == 2 + assert str(f.attrs["mesh_name"]) == mesh.name + # Bulk in standard h5 places + assert "coordinates" in f + assert f["coordinates"].shape[1] == 2 + # User var present with metadata + v = f["variables"]["material"] + assert int(v.attrs["num_components"]) == 1 + + +def test_swarm_round_trips_through_disk_snapshot(tmp_path): + """The whole swarm — particle coords + svar data — round-trips + exactly through write→scribble→read.""" + import os + import underworld3 as uw + + uw, model, mesh, swarm, material = _fresh_model_mesh_swarm() + coords_pre = swarm._particle_coordinates.data.copy() + material_pre = np.asarray(material.data).copy() + + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, path) + + # Scribble: corrupt coords directly via the PETSc field, scribble + # material via the standard array view. + coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape( + (-1, swarm.dim) + ) + coord_field[...] = -99.0 + swarm.dm.restoreField("DMSwarmPIC_coor") + material.data[...] 
= -99.0 + + uw.checkpoint.read_snapshot(model, path) + + assert np.array_equal(swarm._particle_coordinates.data, coords_pre), ( + "particle coords not bit-exact after disk-snapshot restore" + ) + assert np.array_equal(np.asarray(material.data), material_pre), ( + "swarm material not bit-exact after disk-snapshot restore" + ) + + +def test_swarm_restore_recovers_after_particle_count_change(tmp_path): + """Mirror of the in-memory rebuild-on-restore guarantee, but on + disk: snapshot a swarm, mutate its population (add particles), + restore — the original local population (count + coords + var + data) is fully reconstructed.""" + import underworld3 as uw + + uw, model, mesh, swarm, material = _fresh_model_mesh_swarm() + n_pre = swarm.dm.getLocalSize() + coords_pre = swarm._particle_coordinates.data.copy() + material_pre = np.asarray(material.data).copy() + + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, path) + + # Mutate the population — add two extra particles. Local count + # changes. + swarm.add_particles_with_coordinates( + np.array([[0.4, 0.4], [0.6, 0.6]]) + ) + assert swarm.dm.getLocalSize() != n_pre + + uw.checkpoint.read_snapshot(model, path) + + assert swarm.dm.getLocalSize() == n_pre, ( + "restore did not roll back to the captured particle count" + ) + assert np.array_equal(swarm._particle_coordinates.data, coords_pre) + assert np.array_equal(np.asarray(material.data), material_pre) + + +def test_swarm_internals_not_in_sidecar(tmp_path): + """The PETSc-internal swarm variables (DMSwarmPIC_coor, DMSwarm_X0, + DMSwarm_remeshed) are filtered at capture — they're not in the + sidecar's /variables group. 
Same rule as in-memory capture.""" + import os + import h5py + import underworld3 as uw + + uw, model, mesh, swarm, material = _fresh_model_mesh_swarm() + path = str(tmp_path / "run.snap.h5") + uw.checkpoint.write_snapshot(model, path) + + bulk = str(tmp_path / "run.snap.bulk") + sidecar = [f for f in os.listdir(bulk) if f.endswith(".swarm.h5")][0] + with h5py.File(os.path.join(bulk, sidecar), "r") as f: + var_names = set(f["variables"].keys()) + assert "material" in var_names + assert not any(n.startswith("DMSwarm") for n in var_names) From 3acc65dea2ad4ddf3ea6f0a3cc656adde9f299c0 Mon Sep 17 00:00:00 2001 From: lmoresi Date: Wed, 20 May 2026 18:43:12 +1000 Subject: [PATCH 5/7] feat(snapshot, phase 5): unified Model.save_state / load_state MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Single user-facing entry point for all snapshot use cases. Same methods serve in-memory ephemeral stash and on-disk persistent snapshot — the dispatch is mechanical, the user has one API to learn: token = model.save_state() # in-memory, returns Snapshot model.load_state(token) # restore from token model.save_state(file="step42.snap.h5") # on-disk, returns path model.load_state("step42.snap.h5") # restore from disk # (also: load_state(file=…)) load_state dispatches on argument type — Snapshot → in-memory restore; str/PathLike → disk restore. Type-mismatched source raises TypeError with a clear message. Renames replace the prior Model.snapshot() / Model.restore() pair from #195. Pre-merge, no public users to migrate; getting the user-facing API right now means there is never a disparate version shipped. uw.checkpoint.{snapshot,restore,write_snapshot,read_snapshot, read_snapshot_metadata,inspect_snapshot,write_snapshot_skeleton} stay as power-user / lower-level entry points that save_state / load_state delegate to. 
Files updated (mechanical renames, except the doc rewrite): - src/underworld3/model.py: save_state / load_state methods replace snapshot / restore; load_state accepts positional Snapshot or str/os.PathLike, with TypeError on anything else. - tests/test_0007_snapshot_inmemory.py — 23 callers renamed; obsolete test_snapshot_path_is_v1_1_scope deleted (v1.1 has landed). - tests/test_0008_snapshot_realsolver.py — 3 tests renamed. - tests/test_0009_model_tracker.py — 9 tests renamed. - tests/test_0010_snapshot_disk_format.py — 21 tests: replace uw.checkpoint.write_snapshot / read_snapshot with model.save_state / model.load_state at user-style call sites; keep write_snapshot_skeleton + read_snapshot_metadata where the test is specifically exercising the lower-level entry points. - tests/parallel/ptest_0007_snapshot_inmemory.py — np-1/3/4 ptest. - tests/run_snapshot_backstepping_{demo,spatial}.py — demo scripts. - docs/advanced/snapshot-restore.md — rewritten API section to show both modes; added "On-disk file layout" section and a "Choosing between paths" comparison table covering write_timestep, write_checkpoint, and save_state. Limitations section updated to reflect that on-disk is now real (was "in-memory only"). Regression: 75 single-rank tests pass (was 76 — minus the deleted obsolete v1.1-scope test); MPI ptest at -np 4 still PASS with the parallel exact-reconstruction guarantee. Docs build clean with no snapshot-related warnings; the new layout + choosing-between-paths sections render. Phase 4 (read_timestep format-aware dispatch for backward compat) becomes a nice-to-have at this point — save_state / load_state is the recommended surface, write_timestep / read_timestep keep their existing role unchanged. Phase 6 (parallel HDF5 / per-rank sidecars for on-disk MPI) is the remaining correctness item. 
Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- docs/advanced/snapshot-restore.md | 117 +++++++++++++----- src/underworld3/model.py | 81 +++++++++--- .../parallel/ptest_0007_snapshot_inmemory.py | 10 +- tests/run_snapshot_backstepping_demo.py | 10 +- tests/run_snapshot_backstepping_spatial.py | 12 +- tests/test_0007_snapshot_inmemory.py | 75 +++++------ tests/test_0008_snapshot_realsolver.py | 18 +-- tests/test_0009_model_tracker.py | 28 ++--- tests/test_0010_snapshot_disk_format.py | 44 +++---- 9 files changed, 240 insertions(+), 155 deletions(-) diff --git a/docs/advanced/snapshot-restore.md b/docs/advanced/snapshot-restore.md index 2acaf276..997090e6 100644 --- a/docs/advanced/snapshot-restore.md +++ b/docs/advanced/snapshot-restore.md @@ -6,7 +6,7 @@ title: "State Snapshots & Restore" ## Overview -`Model.snapshot()` and `Model.restore()` are a *stash for timesteps* — +`Model.save_state()` and `Model.load_state()` are a *stash for timesteps* — a quick "hold that thought, I might need to come back" mechanism for time-stepping code. Take a snapshot, try a step, and if you don't like the result, restore and try again. The system is put back exactly as it @@ -23,11 +23,13 @@ Typical uses: - **Multi-stage time integration** (RK-style) — restore to the start of a step between stages. -This is intentionally *not* archival checkpointing. It is fast, -in-memory, and meant to be used freely within a run. For long-term, -on-disk restart files, use the existing `mesh.write_timestep()` / -`read_timestep()` path, which is unchanged and serves a different -purpose. +The same entry points serve two storage modes — in-memory for fast +intra-run stash, and on-disk for persistent restart / postprocessing. +You pick by giving (or not giving) a ``file=``. 
The existing +``mesh.write_timestep()`` / ``read_timestep()`` and +``mesh.write_checkpoint()`` / ``MeshVariable.read_checkpoint()`` +paths are unchanged and continue to serve their existing roles (see +"Choosing between paths" at the bottom). ## The API @@ -38,20 +40,37 @@ model = uw.get_default_model() # ... set up mesh, variables, swarm, solvers, step a few times ... -token = model.snapshot() # capture everything, return a token +# In-memory: the "stash for timesteps" use case. +token = model.save_state() # capture everything, return a token # ... take a speculative step you might regret ... -model.restore(token) # put everything back exactly +model.load_state(token) # put everything back exactly ``` -`snapshot()` returns a plain in-memory token. You can hold several at -once and restore any of them. `restore()` returns the model to the -exact state at the moment that token was captured. +To persist on disk instead, pass a path. The same call captures the +same state — only the storage layer differs: + +```python +# On-disk: persistent snapshot for restart, postprocessing, +# bisection studies, or transferring a run to another machine. +model.save_state(file="step42.snap.h5") +# ... later, or in a fresh process with the same model set up ... +model.load_state("step42.snap.h5") +``` + +``save_state()`` returns a ``Snapshot`` token when called +without ``file``; you can hold several at once and restore any of +them. With ``file=...`` it writes a self-contained on-disk snapshot +and returns the path. + +``load_state()`` takes either a token or a path string — it figures +out which from the argument type, so the same call works for both +storage modes. ## What is captured -You do not enumerate anything — `snapshot()` captures the full state +You do not enumerate anything — `save_state()` captures the full state of the model automatically: - mesh coordinates, @@ -69,13 +88,13 @@ and restore — that is exactly the situation restore exists for.
A subtle trap in time-stepping scripts: your loop counter and simulation time usually live in plain Python variables, and -`restore()` has no way to know about them. +`load_state()` has no way to know about them. ```python model_time = 0.0 -token = model.snapshot() +token = model.save_state() model_time = 5.0 # advance -model.restore(token) +model.load_state(token) # model_time is still 5.0 — restore cannot reach a local variable ``` @@ -87,12 +106,12 @@ restored. model.tracker.time = 0.0 model.tracker.step = 0 -token = model.snapshot() +token = model.save_state() model.tracker.time = 5.0 model.tracker.step = 100 -model.restore(token) +model.load_state(token) model.tracker.time # 0.0 — reverted automatically model.tracker.step # 0 — reverted automatically @@ -109,7 +128,7 @@ model.tracker.energy_history = np.zeros(3) These now travel with every snapshot and revert on every restore — no extra code, no special handling in your solvers. Using the tracker is optional; solvers do not depend on it. It is simply the place to keep -the things you want `restore()` to manage. +the things you want `load_state()` to manage. ```{note} Reserved name `state` is reserved on the tracker (it is the snapshot mechanism's own @@ -139,7 +158,7 @@ cfl_limit = mesh.get_min_radius() dt = 0.5 while model.tracker.time < t_end: - token = model.snapshot() + token = model.save_state() coords_before = swarm._particle_coordinates.data.copy() # Speculative step at the current Δt. @@ -153,7 +172,7 @@ while model.tracker.time < t_end: if moved > cfl_limit: # Too big — discard and retry with a smaller Δt. - model.restore(token) + model.load_state(token) dt *= 0.5 continue @@ -180,24 +199,56 @@ attempt starts from precisely where the failed one began. ``` ```{warning} Limitations -- **In-memory only.** Snapshots live in process memory and are not - written to disk; they do not survive the process exiting. They are - also a full copy of model state — holding many large snapshots at - once costs memory. 
-- **Same rank count.** A snapshot taken on *N* MPI ranks is restored - on *N* ranks. Changing the rank count is not supported by this - mechanism (use the `write_timestep` restart path for that). +- **In-memory tokens are intra-run only.** Tokens returned by + ``save_state()`` (no ``file``) live in process memory and do not + survive the process exiting. They are also a full copy of model + state — holding many large tokens at once costs memory. To persist + across runs, use ``save_state(file=…)`` instead. +- **Same rank count.** A snapshot written on *N* MPI ranks must be + read on *N* ranks. Cross-rank-count restart is not supported by + this mechanism (use the ``mesh.write_timestep`` path for that). - **No mesh adaptation across a snapshot.** If the mesh is adapted - between snapshot and restore, restore refuses with a clear error + between save and load, ``load_state`` refuses with a clear error rather than corrupting state. - **Recovery vs. a never-snapshotted run** is bit-exact for the - *discarded-step* guarantee above. Continuing after a restore that - ran a real solver may differ from a run that never snapshotted by a - small amount within solver tolerance — restore resyncs solver - fields rather than reproducing their exact internal buffers. This - does not affect the correctness of backtracking. + *discarded-step* guarantee above. Continuing after a load_state + that ran a real solver may differ from a run that never + snapshotted by a small amount within solver tolerance — load_state + resyncs solver fields rather than reproducing their exact internal + buffers. This does not affect the correctness of backtracking. ``` +## On-disk file layout (when ``save_state(file=…)`` is used) + +A disk snapshot is **two artifacts**: a small wrapper HDF5 file plus +a sibling ``.bulk/`` directory holding the bulk data. They are a +unit — move them together. 
The wrapper is rich in metadata and +inspectable with standard h5 tools without UW3 in the loop: + +```text +my_run.snap.h5 (~tens of KB; metadata, group structure) +my_run.snap.bulk/ (per-mesh + per-swarm sidecars) + {mesh}.mesh.00000.h5 + {mesh}.{var}.00000.h5 (one per mesh-variable) + {swarm}.swarm.h5 (one per swarm) +``` + +A quick ``h5ls -v my_run.snap.h5/metadata`` shows you the run name, +schema version, simulation time, step, dimensions, MPI rank count at +write, and inventories of meshes / swarms / variables / state-bearer +classes. For a Python-side summary use +``uw.checkpoint.inspect_snapshot(path)``. + +## Choosing between paths + +| Need | Use | +|---|---| +| Backtrack a few steps inside a running script (RK staging, adaptive Δt, predictor–corrector probes) | ``save_state()`` → token | +| Persist whole-model state across runs (crash recovery, bisection studies, full restart) | ``save_state(file=…)`` / ``load_state(path)`` | +| Restart from a previous run on a *different* rank count or remap onto a *different* resolution | ``mesh.write_timestep()`` / ``MeshVariable.read_timestep()`` (KDTree remap) | +| Efficient same-rank restart writing only specific variables for postprocessing | ``mesh.write_checkpoint()`` / ``MeshVariable.read_checkpoint()`` (PETSc DMPlex per-variable) | +| Visualisation for ParaView (XDMF + per-step HDF5) | ``mesh.write_timestep(create_xdmf=True)`` | + ## Related - [Parallel-Safe Scripting](parallel-computing.md) — MPI patterns; diff --git a/src/underworld3/model.py b/src/underworld3/model.py index 547e690a..f2b485f3 100644 --- a/src/underworld3/model.py +++ b/src/underworld3/model.py @@ -610,35 +610,78 @@ def _register_state_bearer(self, obj) -> None: """ self._state_bearers.add(obj) - def snapshot(self, *, path: Optional[str] = None): - """Capture a unitary in-memory snapshot of the model's state. + def save_state(self, *, file: Optional[str] = None): + """Save the model's current state — memory or disk, one entry point.
- v1 covers mesh coordinates and mesh-variable DOFs across every - registered mesh. Subsequent PRs extend coverage to swarms and - solver-internal Python state. + Without ``file``, captures an in-memory :class:`Snapshot` + token suitable for in-run backtracking ("stash for + timesteps"). The token is plain Python / numpy — fast to + produce, fast to restore, does not survive the process. - Pass ``path=...`` to write to an HDF5 file once the on-disk - backend lands (v1.1); v1 raises ``NotImplementedError``. + With ``file=``, writes a persistent on-disk snapshot at + that path (plus a sibling ``.bulk/`` directory holding the + bulk PETSc + swarm sidecars). Survives the process; suitable + for restart, postprocessing, transferring runs. - See ``docs/developer/design/in_memory_checkpoint_design.md`` - for the full design. + Either way the captured state is the full model: all + registered meshes and mesh-variables, all swarms with + per-particle data, all solver-internal state-bearers + (:class:`ModelTracker`, ``DDt`` instances, anything else + exposing the ``Snapshottable`` contract). + + Parameters + ---------- + file : str, optional + Path to write a disk snapshot to. If omitted, an in-memory + token is returned instead. + + Returns + ------- + Snapshot + When called without ``file`` — pass to + :meth:`load_state` to restore. + str + When ``file`` is given — the path the snapshot was + written to (same as ``file``). """ from underworld3.checkpoint import snapshot as _snapshot + from underworld3.checkpoint import write_snapshot as _write_snapshot - return _snapshot(self, path=path) + if file is None: + return _snapshot(self) + return _write_snapshot(self, file) - def restore(self, snap) -> None: - """Restore the model from a :class:`underworld3.checkpoint.Snapshot`. + def load_state(self, source) -> None: + """Restore the model from a previously saved state — memory or disk. 
- Within-process restore: ``snap`` must have been produced by - :meth:`snapshot` on this same ``Model`` instance. Raises - :class:`underworld3.checkpoint.SnapshotInvalidatedError` if - the mesh has been adapted, or a captured mesh / variable is - no longer registered. + ``source`` is either: + + - a :class:`Snapshot` token returned by an earlier + ``save_state()`` call (in-memory restore — bit-exact), + - a path string to a disk-snapshot file (disk restore — + same-rank, same-model contract; mesh-rebuild on read is + v1.2 scope). + + Raises + ------ + TypeError + ``source`` is neither a :class:`Snapshot` nor a string. + :class:`SnapshotInvalidatedError` + The captured state no longer matches what is registered + (mesh adapted, state-bearer missing, ...). """ + from underworld3.checkpoint import Snapshot from underworld3.checkpoint import restore as _restore - - return _restore(self, snap) + from underworld3.checkpoint import read_snapshot as _read_snapshot + + if isinstance(source, Snapshot): + return _restore(self, source) + if isinstance(source, (str, os.PathLike)): + return _read_snapshot(self, str(source)) + raise TypeError( + f"load_state expects a Snapshot token or a path string, " + f"got {type(source).__name__}" + ) def define_parameter(self, name: str, ptype=None, **kwargs): """ diff --git a/tests/parallel/ptest_0007_snapshot_inmemory.py b/tests/parallel/ptest_0007_snapshot_inmemory.py index d533d046..ffe3185e 100644 --- a/tests/parallel/ptest_0007_snapshot_inmemory.py +++ b/tests/parallel/ptest_0007_snapshot_inmemory.py @@ -112,12 +112,12 @@ def main(): pre_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) pre_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) - snap = model.snapshot() + snap = model.save_state() # --- Property 1 + 2: a migration-inducing step, then restore --- step(uw, V_fn, T, swarm, ddt, 0.3) # bigger dt -> more migration mid_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) - model.restore(snap) + 
model.load_state(snap) post = global_sorted_particles(swarm, gid, material) post_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) post_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) @@ -126,15 +126,15 @@ def main(): ddt_ok = pre_ddt == post_ddt # --- Property 3: bit-identical continuation across a stash --- - snap2 = model.snapshot() + snap2 = model.save_state() for _ in range(4): step(uw, V_fn, T, swarm, ddt, 0.1) ctrl = global_sorted_particles(swarm, gid, material) ctrl_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) - model.restore(snap2) + model.load_state(snap2) step(uw, V_fn, T, swarm, ddt, 0.5) # the regretted step - model.restore(snap2) + model.load_state(snap2) for _ in range(4): step(uw, V_fn, T, swarm, ddt, 0.1) stash = global_sorted_particles(swarm, gid, material) diff --git a/tests/run_snapshot_backstepping_demo.py b/tests/run_snapshot_backstepping_demo.py index 5a014548..051fcd26 100644 --- a/tests/run_snapshot_backstepping_demo.py +++ b/tests/run_snapshot_backstepping_demo.py @@ -7,7 +7,7 @@ - timestep forward at small Δt for a while (CFL well under 1), - take a snapshot, - try one too-large Δt (CFL spikes far above 1), - - detect the bad step, call ``model.restore(snap)``, + - detect the bad step, call ``model.load_state(snap)``, - replay the same time interval with many small steps (CFL stays small), - continue past the speculative end-time. 
@@ -82,7 +82,7 @@ def take_step(dt: float): cfl_kept.append(disp / cfl_threshold) t_snap = t - snap = model.snapshot() + snap = model.save_state() # --- Phase 2: speculative big step --- disp_bad = take_step(candidate_dt) @@ -90,7 +90,7 @@ def take_step(dt: float): cfl_bad = disp_bad / cfl_threshold # --- CFL violated → restore --- - model.restore(snap) + model.load_state(snap) # --- Phase 3: substep replay --- times_recovered = [] @@ -171,7 +171,7 @@ def take_step(dt: float): ax.text( 0.5 * (t_snap + t_bad_end), 0.4 * cfl_bad, - "model.restore(snap)", + "model.load_state(snap)", ha="center", va="center", fontsize=10, color="0.35", style="italic", bbox=dict(facecolor="white", edgecolor="0.7", boxstyle="round,pad=0.25"), @@ -200,7 +200,7 @@ def take_step(dt: float): ax.set_xlabel("simulation time t") ax.set_ylabel("CFL ratio = max per-step displacement / cell radius") ax.set_title( - "Adaptive-Δt back-stepping • model.snapshot() / model.restore()", + "Adaptive-Δt back-stepping • model.save_state() / model.load_state()", pad=22, ) ax.legend(loc="upper right", frameon=False) diff --git a/tests/run_snapshot_backstepping_spatial.py b/tests/run_snapshot_backstepping_spatial.py index fb155747..fe06061f 100644 --- a/tests/run_snapshot_backstepping_spatial.py +++ b/tests/run_snapshot_backstepping_spatial.py @@ -5,14 +5,14 @@ spatial panels at the four moments that matter: [initial state (snapshot taken here)] [after speculative bad step] - [after model.restore(snap)] [after substep recovery to same t] + [after model.load_state(snap)] [after substep recovery to same t] Each panel shows the swarm particles coloured by their carried material value (initial radial position), with the domain boundary drawn as context. The diagonal pairs tell two stories: - top-left vs. bottom-left should be **visually identical**. That's - the proof that model.restore(snap) put the captured state back + the proof that model.load_state(snap) put the captured state back exactly. 
If the figure ever stops showing two identical panels in that diagonal, the snapshot mechanism has broken. @@ -69,7 +69,7 @@ def main(out_path: str = "snapshot_backstepping_spatial.png"): # Take the snapshot — this is the state that bottom-left will # have to match after restore. - snap = model.snapshot() + snap = model.save_state() # --- Speculative big step --- swarm.advection(V_fn, delta_t=candidate_dt, step_limit=False) @@ -79,8 +79,8 @@ def main(out_path: str = "snapshot_backstepping_spatial.png"): ) cfl_bad = max_disp_bad / cfl_threshold - # --- model.restore(snap) --- - model.restore(snap) + # --- model.load_state(snap) --- + model.load_state(snap) state_after_restore = _capture(swarm, material) # --- Substep recovery to the same target time --- @@ -121,7 +121,7 @@ def main(out_path: str = "snapshot_backstepping_spatial.png"): ( axes[1, 0], state_after_restore, - "After model.restore(snap)", + "After model.load_state(snap)", "t = 0.00 • visually identical to top-left", ), ( diff --git a/tests/test_0007_snapshot_inmemory.py b/tests/test_0007_snapshot_inmemory.py index 1af2af6f..d52fee33 100644 --- a/tests/test_0007_snapshot_inmemory.py +++ b/tests/test_0007_snapshot_inmemory.py @@ -24,12 +24,12 @@ def test_meshvariable_in_memory_roundtrip(): T.array[:, 0, 0] = T.coords[:, 0] + 2.0 * T.coords[:, 1] pre_array = np.asarray(T.array[...]).copy() - snap = model.snapshot() + snap = model.save_state() T.array[...] = -42.0 assert not np.allclose(np.asarray(T.array[...]), pre_array), "scribble didn't take" - model.restore(snap) + model.load_state(snap) assert np.allclose(np.asarray(T.array[...]), pre_array, atol=0.0, rtol=0.0), ( "MeshVariable.array is not bit-equivalent after restore" @@ -49,12 +49,12 @@ def test_multiple_meshvariables_roundtrip(): T_pre = np.asarray(T.array[...]).copy() V_pre = np.asarray(V.array[...]).copy() - snap = model.snapshot() + snap = model.save_state() T.array[...] = 0.0 V.array[...] 
= 0.0 - model.restore(snap) + model.load_state(snap) assert np.allclose(np.asarray(T.array[...]), T_pre) assert np.allclose(np.asarray(V.array[...]), V_pre) @@ -66,7 +66,7 @@ def test_snapshot_is_independent_of_subsequent_writes(): T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) T.array[:, 0, 0] = 5.0 - snap = model.snapshot() + snap = model.save_state() T.array[...] = -1.0 # The backend still holds the captured value, not the post-write value. @@ -87,14 +87,14 @@ def test_mesh_version_invalidates_restore(): T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) T.array[:, 0, 0] = 1.0 - snap = model.snapshot() + snap = model.save_state() # Simulate a mesh-mutation event (e.g. adapt(), or any deformation # routed through the high-level callback that bumps _mesh_version). mesh._mesh_version += 1 with pytest.raises(SnapshotInvalidatedError, match="_mesh_version"): - model.restore(snap) + model.load_state(snap) def test_restore_rejects_non_snapshot(): @@ -103,16 +103,7 @@ def test_restore_rejects_non_snapshot(): _ = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) with pytest.raises(TypeError): - model.restore({"not": "a snapshot"}) - - -def test_snapshot_path_is_v1_1_scope(): - """Passing path= raises NotImplementedError until the on-disk backend lands.""" - uw, model, mesh = _fresh_model_and_mesh() - _ = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) - - with pytest.raises(NotImplementedError): - model.snapshot(path="/tmp/should_not_be_written.h5") + model.load_state({"not": "a snapshot"}) # ----- Swarm coverage (rebuild-on-restore semantics) ----- @@ -150,14 +141,14 @@ def test_swarm_no_change_roundtrip(): coords_pre = _swarm_coords(swarm) material_pre = np.asarray(material.data).copy() - snap = model.snapshot() + snap = model.save_state() coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape((-1, swarm.dim)) coord_field[...] = -99.0 swarm.dm.restoreField("DMSwarmPIC_coor") material.data[...] 
= -99.0 - model.restore(snap) + model.load_state(snap) assert np.allclose(_swarm_coords(swarm), coords_pre) assert np.allclose(np.asarray(material.data), material_pre) @@ -172,14 +163,14 @@ def test_swarm_restore_after_migrate(): material_pre = np.asarray(material.data).copy() pop_gen_pre = swarm._population_generation - snap = model.snapshot() + snap = model.save_state() # Mutate: migrate() will bump the counter regardless of whether # particles actually moved. Restore must succeed anyway. swarm.migrate(remove_sent_points=True) assert swarm._population_generation > pop_gen_pre, "migrate didn't bump counter" - model.restore(snap) + model.load_state(snap) assert np.allclose(_swarm_coords(swarm), coords_pre), ( "restore did not recover particle coords across a migrate event" @@ -197,14 +188,14 @@ def test_swarm_restore_after_add_particles(): material_pre = np.asarray(material.data).copy() npre = swarm.dm.getLocalSize() - snap = model.snapshot() + snap = model.save_state() swarm.add_particles_with_coordinates( np.array([[0.5, 0.5], [0.25, 0.75]]) ) assert swarm.dm.getLocalSize() != npre, "add_particles didn't grow swarm" - model.restore(snap) + model.load_state(snap) assert swarm.dm.getLocalSize() == npre, ( "restore did not roll back to the captured particle count" @@ -218,7 +209,7 @@ def test_swarm_population_generation_is_informational_not_a_gate(): uw, model, mesh, swarm, _ = _fresh_model_mesh_and_swarm() gen_at_capture = swarm._population_generation - snap = model.snapshot() + snap = model.save_state() swarm.migrate(remove_sent_points=True) swarm.add_particles_with_coordinates(np.array([[0.5, 0.5]])) @@ -226,7 +217,7 @@ def test_swarm_population_generation_is_informational_not_a_gate(): assert gen_during > gen_at_capture # Restore is expected to *succeed*, not raise. - model.restore(snap) + model.load_state(snap) # And the counter has moved on from where it was at capture, # because restore itself counts as a population change. 
@@ -237,7 +228,7 @@ def test_swarm_internal_variables_are_not_captured(): """Internal DMSwarm_* variables stay out of the snapshot key list.""" uw, model, mesh, swarm, material = _fresh_model_mesh_and_swarm() - snap = model.snapshot() + snap = model.save_state() keys = snap.backend.list_vectors() swarm_name = swarm._snapshot_stable_name() svar_keys = [k for k in keys if k.startswith(f"swarm:{swarm_name}:var:")] @@ -299,14 +290,14 @@ def test_symbolic_ddt_roundtrip_recovers_state(): assert state_pre.n_solves_completed == 2 assert state_pre.dt_history == [0.2, 0.1] - snap = model.snapshot() + snap = model.save_state() # Mutate: take another solve, dt_history changes. ddt.update_pre_solve(dt=0.5) ddt.update_post_solve(dt=0.5) assert ddt.state.dt_history == [0.5, 0.2] - model.restore(snap) + model.load_state(snap) # Primary state is back to captured. state_post = ddt.state @@ -344,7 +335,7 @@ def test_symbolic_ddt_snapshot_is_deep_copy(): ddt.update_pre_solve(dt=0.1) ddt.update_post_solve(dt=0.1) - snap = model.snapshot() + snap = model.save_state() # Find the DDt's captured state by type — state_bearers is # unordered and now also contains the model tracker. 
from underworld3.systems.ddt import DDtSymbolicState @@ -394,14 +385,14 @@ def test_eulerian_ddt_roundtrip(): ddt._dt = 0.2 state_pre = ddt.state - snap = model.snapshot() + snap = model.save_state() ddt._dt_history = [0.99, 0.99] ddt._history_initialised = False ddt._n_solves_completed = 0 ddt._dt = None - model.restore(snap) + model.load_state(snap) assert ddt.state.dt_history == state_pre.dt_history assert ddt.state.history_initialised == state_pre.history_initialised @@ -437,11 +428,11 @@ def test_semilagrangian_ddt_roundtrip(): ddt._dt = 0.3 state_pre = ddt.state - snap = model.snapshot() + snap = model.save_state() ddt._dt_history = [None, None] ddt._history_initialised = False ddt._n_solves_completed = 0 - model.restore(snap) + model.load_state(snap) assert ddt.state.dt_history == state_pre.dt_history assert ddt.state.history_initialised is True @@ -474,11 +465,11 @@ def test_lagrangian_ddt_roundtrip(): ddt._dt = 0.2 state_pre = ddt.state - snap = model.snapshot() + snap = model.save_state() ddt._dt_history = [None, None] ddt._history_initialised = False ddt._n_solves_completed = 0 - model.restore(snap) + model.load_state(snap) assert ddt.state.dt_history == state_pre.dt_history assert ddt.state.history_initialised is True @@ -549,7 +540,7 @@ def test_backstepping_cfl_recovery_end_to_end(): # Take the snapshot *before* the speculative step. Everything that # will be touched gets captured. - snap = model.snapshot() + snap = model.save_state() # Speculative step at the candidate Δt. Bigger than the user # thinks is safe — they'll check after and back-step if it isn't. @@ -569,7 +560,7 @@ def test_backstepping_cfl_recovery_end_to_end(): # Back-step. Everything captured is brought back to the snapshot # point — swarm positions, the material variable carried with the # swarm, and the DDt's BDF history. 
- model.restore(snap) + model.load_state(snap) assert np.allclose(swarm._particle_coordinates.data, coords_initial), ( "particle positions did not roll back after restore" @@ -724,7 +715,7 @@ def test_continuation_deterministic_after_restore(): for _ in range(3): _step(uw, V_fn, T, swarm, ddt, 0.05) - snap = model.snapshot() + snap = model.save_state() # Branch A: K steps straight from S. for _ in range(5): @@ -732,7 +723,7 @@ def test_continuation_deterministic_after_restore(): state_A = _capture_full_state(T, swarm, material, ddt) # Branch B: restore S, then the identical K steps. - model.restore(snap) + model.load_state(snap) for _ in range(5): _step(uw, V_fn, T, swarm, ddt, 0.05) state_B = _capture_full_state(T, swarm, material, ddt) @@ -756,7 +747,7 @@ def test_continuation_bit_identical_across_stash_and_recover(): for _ in range(3): _step(uw, V_fn, T, swarm, ddt, 0.05) - snap = model.snapshot() + snap = model.save_state() # Control: K good steps from S. for _ in range(5): @@ -766,9 +757,9 @@ def test_continuation_bit_identical_across_stash_and_recover(): # Stash scenario: back to S, take a deliberately disruptive step # (10x Δt — large advection, big T jump, DDt history shift), then # discard it via restore and run the intended K good steps. 
- model.restore(snap) + model.load_state(snap) _step(uw, V_fn, T, swarm, ddt, 0.5) # the regretted step - model.restore(snap) + model.load_state(snap) for _ in range(5): _step(uw, V_fn, T, swarm, ddt, 0.05) stash = _capture_full_state(T, swarm, material, ddt) diff --git a/tests/test_0008_snapshot_realsolver.py b/tests/test_0008_snapshot_realsolver.py index 33afe4a3..d37f15de 100644 --- a/tests/test_0008_snapshot_realsolver.py +++ b/tests/test_0008_snapshot_realsolver.py @@ -105,14 +105,14 @@ def test_realsolver_restore_recovers_solution_field(): adv_diff.solve(timestep=1.0e-3) pre_T = _capture(T) - snap = model.snapshot() + snap = model.save_state() adv_diff.solve(timestep=5.0) # absurd Δt: converges, over-diffused assert not np.allclose(_capture(T), pre_T, atol=1e-8), ( "the regretted solve was not actually disruptive" ) - model.restore(snap) + model.load_state(snap) assert np.array_equal(_capture(T), pre_T), ( "solution field not exactly recovered after restore" ) @@ -132,18 +132,18 @@ def test_realsolver_regretted_step_leaves_no_trace(): for _ in range(3): adv_diff.solve(timestep=1.0e-3) - snap = model.snapshot() + snap = model.save_state() # B: restore, then K good solves. - model.restore(snap) + model.load_state(snap) for _ in range(4): adv_diff.solve(timestep=1.0e-3) B = _capture(T) # C: restore, a regretted solve, restore, the same K good solves. - model.restore(snap) + model.load_state(snap) adv_diff.solve(timestep=5.0) - model.restore(snap) + model.load_state(snap) for _ in range(4): adv_diff.solve(timestep=1.0e-3) C = _capture(T) @@ -165,7 +165,7 @@ def test_realsolver_continuation_within_solver_tolerance(): for _ in range(3): adv_diff.solve(timestep=1.0e-3) - snap = model.snapshot() + snap = model.save_state() # Control: never snapshotted/restored — straight K solves. for _ in range(4): @@ -173,9 +173,9 @@ def test_realsolver_continuation_within_solver_tolerance(): ctrl = _capture(T) # Stash path: restore, regretted solve, restore, same K solves. 
- model.restore(snap) + model.load_state(snap) adv_diff.solve(timestep=5.0) - model.restore(snap) + model.load_state(snap) for _ in range(4): adv_diff.solve(timestep=1.0e-3) stash = _capture(T) diff --git a/tests/test_0009_model_tracker.py b/tests/test_0009_model_tracker.py index 35403654..3f3f12ac 100644 --- a/tests/test_0009_model_tracker.py +++ b/tests/test_0009_model_tracker.py @@ -34,13 +34,13 @@ def test_tracker_builtins_revert_on_restore(): model.tracker.step = 7 model.tracker.dt = 0.05 - snap = model.snapshot() + snap = model.save_state() model.tracker.time = 99.0 model.tracker.step = 999 model.tracker.dt = 1.0 - model.restore(snap) + model.load_state(snap) assert model.tracker.time == 3.14 assert model.tracker.step == 7 @@ -53,9 +53,9 @@ def test_tracker_user_quantity_reverts(): uw, model = _fresh_model() model.tracker.my_diagnostic = 42.0 - snap = model.snapshot() + snap = model.save_state() model.tracker.my_diagnostic = -1.0 - model.restore(snap) + model.load_state(snap) assert model.tracker.my_diagnostic == 42.0 @@ -67,11 +67,11 @@ def test_tracker_numpy_quantity_reverts_by_value(): arr = np.array([1.0, 2.0, 3.0]) model.tracker.history = arr - snap = model.snapshot() + snap = model.save_state() model.tracker.history[:] = -9.0 # in-place mutation assert np.allclose(model.tracker.history, -9.0) - model.restore(snap) + model.load_state(snap) assert np.allclose(model.tracker.history, [1.0, 2.0, 3.0]) @@ -81,11 +81,11 @@ def test_tracker_quantity_added_after_snapshot_is_dropped_on_restore(): uw, model = _fresh_model() model.tracker.a = 1.0 - snap = model.snapshot() + snap = model.save_state() model.tracker.b = 2.0 # created after snapshot assert "b" in model.tracker - model.restore(snap) + model.load_state(snap) assert "a" in model.tracker assert "b" not in model.tracker @@ -99,13 +99,13 @@ def test_tracker_is_what_makes_state_revertible(): loose_time = 0.0 model.tracker.time = 0.0 - snap = model.snapshot() + snap = model.save_state() # Advance both the 
loose variable and the tracked one. loose_time = 5.0 model.tracker.time = 5.0 - model.restore(snap) + model.load_state(snap) # The loose variable is untouched by restore (the language can't # know about it); the tracked one rolled back. @@ -122,10 +122,10 @@ def test_tracker_state_roundtrip_is_bit_identical(): model.tracker.payload = np.arange(5).astype(float) state_pre = model.tracker.state - snap = model.snapshot() + snap = model.save_state() model.tracker.time = 12345.0 model.tracker.payload[:] = 0.0 - model.restore(snap) + model.load_state(snap) state_post = model.tracker.state assert state_post.managed["time"] == state_pre.managed["time"] @@ -158,7 +158,7 @@ def do_step(dt): for _ in range(3): do_step(0.05) - snap = model.snapshot() + snap = model.save_state() t_snap, s_snap = model.tracker.time, model.tracker.step # Regretted big step. @@ -166,7 +166,7 @@ def do_step(dt): assert model.tracker.step == s_snap + 1 assert model.tracker.time != t_snap - model.restore(snap) + model.load_state(snap) assert model.tracker.time == t_snap assert model.tracker.step == s_snap diff --git a/tests/test_0010_snapshot_disk_format.py b/tests/test_0010_snapshot_disk_format.py index 89a28686..0d6eb826 100644 --- a/tests/test_0010_snapshot_disk_format.py +++ b/tests/test_0010_snapshot_disk_format.py @@ -201,7 +201,7 @@ def test_write_snapshot_produces_wrapper_and_bulk_dir(tmp_path): V.array[:, 0, 0] = -3.0 path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) bulk = str(tmp_path / "run.snap.bulk") assert os.path.exists(path) @@ -226,7 +226,7 @@ def test_write_snapshot_populates_wrapper_layout(tmp_path): V.array[:, 0, 1] = 3.0 path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) with h5py.File(path, "r") as f: assert f["meshes"].attrs["filled_by"] == "phase2" @@ -262,13 +262,13 @@ def test_write_read_snapshot_bit_exact_roundtrip(tmp_path): V_pre = 
np.asarray(V.array[...]).copy() path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) # Scribble. T.array[...] = -99.0 V.array[...] = -99.0 - uw.checkpoint.read_snapshot(model, path) + model.load_state(path) assert np.array_equal(np.asarray(T.array[...]), T_pre), ( f"T not bit-exact after read_snapshot — max|d|=" @@ -288,7 +288,7 @@ def test_read_snapshot_rejects_missing_bulk_dir(tmp_path): uw, model, mesh, T, V = _fresh_model_mesh_and_vars() path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) # Delete the bulk dir to simulate the move-the-wrapper-only mistake. import shutil @@ -296,7 +296,7 @@ def test_read_snapshot_rejects_missing_bulk_dir(tmp_path): uw, model, mesh, T, V = _fresh_model_mesh_and_vars() with pytest.raises(FileNotFoundError, match="bulk directory missing"): - uw.checkpoint.read_snapshot(model, path) + model.load_state(path) def test_read_snapshot_rejects_mismatched_mesh(tmp_path): @@ -306,7 +306,7 @@ def test_read_snapshot_rejects_mismatched_mesh(tmp_path): uw, model, mesh, T, V = _fresh_model_mesh_and_vars() path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) # Fresh model with a *different* mesh — write_snapshot's mesh.name # won't match. @@ -319,7 +319,7 @@ def test_read_snapshot_rejects_mismatched_mesh(tmp_path): other.name = "definitely_a_different_mesh" with pytest.raises(ValueError, match="not registered on this model"): - uw.checkpoint.read_snapshot(model2, path) + model2.load_state(path) # ----- Phase 3a: state-bearer (Snapshottable) serialisation ----- @@ -339,7 +339,7 @@ def test_tracker_round_trips_through_disk_snapshot(tmp_path): model.tracker.history_arr = np.array([1.0, 2.0, 3.0]) path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) # Scribble everything tracker-side. 
model.tracker.time = -1.0 @@ -348,7 +348,7 @@ def test_tracker_round_trips_through_disk_snapshot(tmp_path): model.tracker.my_diagnostic = -1.0 model.tracker.history_arr = np.array([-1.0, -1.0, -1.0]) - uw.checkpoint.read_snapshot(model, path) + model.load_state(path) assert model.tracker.time == 3.14 assert model.tracker.step == 42 @@ -371,7 +371,7 @@ def test_python_state_group_is_inspectable(tmp_path): model.tracker.my_q = 7.0 path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) with h5py.File(path, "r") as f: ps = f["python_state"] @@ -410,7 +410,7 @@ def test_ddt_symbolic_state_round_trips_primary_fields(tmp_path): ddt._dt = 0.05 path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) # Scribble the primary fields. ddt._dt_history = [None, None] @@ -418,7 +418,7 @@ def test_ddt_symbolic_state_round_trips_primary_fields(tmp_path): ddt._n_solves_completed = 0 ddt._dt = None - uw.checkpoint.read_snapshot(model, path) + model.load_state(path) assert ddt.state.dt_history == [0.05, 0.03] assert ddt.state.history_initialised is True @@ -437,7 +437,7 @@ def test_read_snapshot_rejects_missing_state_bearer(tmp_path): ddt._dt_history = [0.05, 0.05] path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) # Target model has no DDt. 
uw, model2, mesh2, T2, V2 = _fresh_model_mesh_and_vars() @@ -445,7 +445,7 @@ def test_read_snapshot_rejects_missing_state_bearer(tmp_path): mesh2.name = mesh.name with pytest.raises(ValueError, match="state-bearer .* not registered"): - uw.checkpoint.read_snapshot(model2, path) + model2.load_state(path) # ----- Phase 3b: swarms in sidecars ----- @@ -477,7 +477,7 @@ def test_swarm_sidecar_lands_in_bulk_dir(tmp_path): uw, model, mesh, swarm, material = _fresh_model_mesh_swarm() path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) bulk = str(tmp_path / "run.snap.bulk") files = sorted(os.listdir(bulk)) @@ -506,7 +506,7 @@ def test_swarm_sidecar_is_inspectable(tmp_path): uw, model, mesh, swarm, material = _fresh_model_mesh_swarm() path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) bulk = str(tmp_path / "run.snap.bulk") swarm_sidecar = [ @@ -537,7 +537,7 @@ def test_swarm_round_trips_through_disk_snapshot(tmp_path): material_pre = np.asarray(material.data).copy() path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) # Scribble: corrupt coords directly via the PETSc field, scribble # material via the standard array view. @@ -548,7 +548,7 @@ def test_swarm_round_trips_through_disk_snapshot(tmp_path): swarm.dm.restoreField("DMSwarmPIC_coor") material.data[...] = -99.0 - uw.checkpoint.read_snapshot(model, path) + model.load_state(path) assert np.array_equal(swarm._particle_coordinates.data, coords_pre), ( "particle coords not bit-exact after disk-snapshot restore" @@ -571,7 +571,7 @@ def test_swarm_restore_recovers_after_particle_count_change(tmp_path): material_pre = np.asarray(material.data).copy() path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) # Mutate the population — add two extra particles. Local count # changes. 
@@ -580,7 +580,7 @@ def test_swarm_restore_recovers_after_particle_count_change(tmp_path): ) assert swarm.dm.getLocalSize() != n_pre - uw.checkpoint.read_snapshot(model, path) + model.load_state(path) assert swarm.dm.getLocalSize() == n_pre, ( "restore did not roll back to the captured particle count" @@ -599,7 +599,7 @@ def test_swarm_internals_not_in_sidecar(tmp_path): uw, model, mesh, swarm, material = _fresh_model_mesh_swarm() path = str(tmp_path / "run.snap.h5") - uw.checkpoint.write_snapshot(model, path) + model.save_state(file=path) bulk = str(tmp_path / "run.snap.bulk") sidecar = [f for f in os.listdir(bulk) if f.endswith(".swarm.h5")][0] From 52d5277eca01467236b54fe3baa381d0b404193d Mon Sep 17 00:00:00 2001 From: lmoresi Date: Wed, 20 May 2026 20:12:46 +1000 Subject: [PATCH 6/7] feat(snapshot, phase 4): MeshVariable.read_timestep is format-aware MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The selective-read entry point users already know (``var.read_timestep(...)``) now reads BOTH the legacy ``write_timestep`` per-variable HDF5 files AND v1.1 snapshot wrappers — same call, format detection is hidden inside the function. No user code has to learn a second API for the new format; existing scripts with ``var.read_timestep(...)`` calls keep working transparently against new files. This is the compat commitment from the design discussion: "the clean interface lies beneath the surface for this case" — meaning the format dispatch is hidden, not that read_timestep itself is hidden. read_timestep serves a different use case than save_state/load_state (selective per-variable, cross-resolution remap via KDTree, visualisation-style reads); both stay user-facing. Implementation: - ``uw.checkpoint.is_snapshot_wrapper(path)``: cheap format detector — checks for top-level /metadata + /meshes groups. 
- ``uw.checkpoint.extract_var_via_bridge(wrapper_path, var_name)``: given a v1.1 wrapper + variable name, returns (coords, values) numpy arrays — exactly what the legacy file's h5 read produces. Mechanism: load source mesh from .mesh.h5 sidecar, rebuild source variable with matching degree/components, load DOFs via #146's MeshVariable.read_checkpoint, read out var.coords and var.array. - MeshVariable.read_timestep: before its rank-0 (coord, value) read, dispatches on the file's format. v1.1 → bridge. Legacy → existing per-variable h5 read. Everything after — the source-swarm + query-swarm KDTree-routing machinery — is reused unchanged. Tests (2 new, 23 total in test_0010, 77 across the snapshot suite): - read_timestep against a v1.1 snapshot wrapper round-trips a variable on the same mesh to within atol=1e-12 (KDTree query lands on captured DOF coords) - read_timestep against a legacy write_timestep file still uses the legacy code path (belt-and-braces no-regression check) Phase 6 (parallel on-disk MPI) remains as the production-readiness gate for the disk path.
Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/checkpoint/__init__.py | 4 + src/underworld3/checkpoint/disk_snapshot.py | 75 +++++++++++++++++++ .../discretisation_mesh_variables.py | 53 ++++++++++--- tests/test_0010_snapshot_disk_format.py | 49 ++++++++++++ 4 files changed, 172 insertions(+), 9 deletions(-) diff --git a/src/underworld3/checkpoint/__init__.py b/src/underworld3/checkpoint/__init__.py index ad17d5b9..eada5e59 100644 --- a/src/underworld3/checkpoint/__init__.py +++ b/src/underworld3/checkpoint/__init__.py @@ -29,7 +29,9 @@ from .tracker import ModelTracker, TrackerState from .disk_snapshot import ( DISK_SNAPSHOT_SCHEMA_VERSION, + extract_var_via_bridge, inspect_snapshot, + is_snapshot_wrapper, read_snapshot, read_snapshot_metadata, write_snapshot, @@ -49,7 +51,9 @@ "ModelTracker", "TrackerState", "DISK_SNAPSHOT_SCHEMA_VERSION", + "extract_var_via_bridge", "inspect_snapshot", + "is_snapshot_wrapper", "read_snapshot", "read_snapshot_metadata", "write_snapshot", diff --git a/src/underworld3/checkpoint/disk_snapshot.py b/src/underworld3/checkpoint/disk_snapshot.py index 622611ba..1b4954db 100644 --- a/src/underworld3/checkpoint/disk_snapshot.py +++ b/src/underworld3/checkpoint/disk_snapshot.py @@ -815,6 +815,81 @@ def _read_state_bearer_into(group, obj) -> None: obj.state = dataclasses.replace(current_state, **overrides) +def is_snapshot_wrapper(path: str) -> bool: + """Quick check whether ``path`` is a v1.1 snapshot wrapper file. + + Used by :meth:`MeshVariable.read_timestep` to dispatch between + the legacy per-variable layout and the v1.1 sidecar format — + same user call, different storage, hidden behind the function. 
+ """ + import h5py + + try: + with h5py.File(path, "r") as f: + return _GROUP_METADATA in f and _GROUP_MESHES in f + except (OSError, KeyError): + return False + + +def extract_var_via_bridge(wrapper_path: str, var_name: str): + """Bridge for selective per-variable reads of v1.1 snapshots. + + Given the wrapper path and a variable name, returns + ``(coords, values)`` numpy arrays — exactly what + :meth:`MeshVariable.read_timestep` produces on rank 0 from the + legacy layout. The rest of read_timestep's swarm-routing + + KDTree machinery is format-agnostic; this bridge is what makes + ``read_timestep`` work transparently against new files. + + Mechanism: load the source mesh from the .mesh.h5 sidecar, + rebuild the source variable with the correct shape, load DOFs + via #146's ``read_checkpoint``, then read out ``var.coords`` + + ``var.array``. + """ + import h5py + + bulk_dir = _bulk_dir_for(wrapper_path) + found = None + with h5py.File(wrapper_path, "r") as f: + for mesh_safe in f[_GROUP_MESHES].keys(): + mg = f[_GROUP_MESHES][mesh_safe] + for var_safe in mg["variables"].keys(): + v_attrs = mg["variables"][var_safe].attrs + if str(v_attrs["name"]) == var_name: + found = ( + str(mg.attrs["mesh_file"]), + str(v_attrs["external_file"]), + int(v_attrs["degree"]), + int(v_attrs["components"]), + bool(v_attrs["continuous"]), + ) + break + if found: + break + if found is None: + raise ValueError( + f"variable {var_name!r} not found in v1.1 snapshot {wrapper_path}" + ) + + mesh_file_rel, var_file_rel, degree, components, continuous = found + # Rebuild a transient source mesh + variable to read DOFs into. + # We deliberately don't register them with the live model — these + # are throwaway and exit scope on return. 
+ src_mesh = uw.discretisation.Mesh(os.path.join(bulk_dir, mesh_file_rel)) + src_var = uw.discretisation.MeshVariable( + var_name, src_mesh, components, degree=degree, continuous=continuous, + ) + src_var.read_checkpoint( + os.path.join(bulk_dir, var_file_rel), data_name=var_name + ) + + coords = np.asarray(src_var.coords).copy() + values = np.asarray(src_var.array[...]).reshape( + coords.shape[0], components + ).copy() + return coords, values + + def inspect_snapshot(path: str) -> str: """Human-readable one-shot summary of a snapshot file's metadata. diff --git a/src/underworld3/discretisation/discretisation_mesh_variables.py b/src/underworld3/discretisation/discretisation_mesh_variables.py index 78ad8249..9ee9a061 100644 --- a/src/underworld3/discretisation/discretisation_mesh_variables.py +++ b/src/underworld3/discretisation/discretisation_mesh_variables.py @@ -1206,14 +1206,38 @@ def read_timestep( ``file_size`` per rank. """ + # Format dispatch: ``data_filename`` may be either the + # legacy ``write_timestep`` prefix (in which case we + # reconstruct the per-variable file path the usual way) or a + # v1.1 snapshot wrapper path produced by + # ``model.save_state(file=…)``. The format-detection logic is + # hidden from the user — same call, both formats. 
+ import h5py + import numpy as np + from underworld3.checkpoint.disk_snapshot import ( + is_snapshot_wrapper as _is_snapshot_wrapper, + extract_var_via_bridge as _extract_var_via_bridge, + ) + output_base_name = os.path.join(outputPath, data_filename) - data_file = output_base_name + f".mesh.{data_name}.{index:05}.h5" + legacy_file = output_base_name + f".mesh.{data_name}.{index:05}.h5" - if not os.path.isfile(os.path.abspath(data_file)): - raise RuntimeError(f"{os.path.abspath(data_file)} does not exist") + is_v1_1 = ( + os.path.isfile(data_filename) + and not data_filename.endswith( + f".mesh.{data_name}.{index:05}.h5" + ) + and _is_snapshot_wrapper(data_filename) + ) - import h5py - import numpy as np + if is_v1_1: + data_file = data_filename + else: + data_file = legacy_file + if not os.path.isfile(os.path.abspath(data_file)): + raise RuntimeError( + f"{os.path.abspath(data_file)} does not exist" + ) # ``self.num_components`` is correct for SCALAR (1), VECTOR (dim), # TENSOR (dim**2) and SYM_TENSOR (dim*(dim+1)/2). 
``self.shape[1]`` @@ -1236,10 +1260,21 @@ def read_timestep( if uw.mpi.rank == 0: if verbose: - print(f"Reading data file {data_file}", flush=True) - with h5py.File(data_file, "r") as h5f: - X_src = h5f["fields"]["coordinates"][()].reshape(-1, dim) - D_src = h5f["fields"][data_name][()].reshape(-1, n_components) + print( + f"Reading data file {data_file} " + f"(format: {'v1.1 snapshot' if is_v1_1 else 'legacy timestep'})", + flush=True, + ) + if is_v1_1: + X_src, D_src = _extract_var_via_bridge(data_file, data_name) + X_src = X_src.reshape(-1, dim) + D_src = D_src.reshape(-1, n_components) + else: + with h5py.File(data_file, "r") as h5f: + X_src = h5f["fields"]["coordinates"][()].reshape(-1, dim) + D_src = h5f["fields"][data_name][()].reshape( + -1, n_components + ) else: X_src = np.empty((0, dim), dtype=np.double) D_src = np.empty((0, n_components), dtype=np.double) diff --git a/tests/test_0010_snapshot_disk_format.py b/tests/test_0010_snapshot_disk_format.py index 0d6eb826..18006d6b 100644 --- a/tests/test_0010_snapshot_disk_format.py +++ b/tests/test_0010_snapshot_disk_format.py @@ -589,6 +589,55 @@ def test_swarm_restore_recovers_after_particle_count_change(tmp_path): assert np.array_equal(np.asarray(material.data), material_pre) +# ----- Phase 4: read_timestep is format-aware ----- + + +def test_read_timestep_reads_v1_1_snapshot_via_dispatch(tmp_path): + """The selective-read entry point users already know + (``var.read_timestep(...)``) reads a v1.1 snapshot wrapper too — + the format detection is hidden inside the call. 
Same-mesh read + is bit-exact because the KDTree query lands exactly on the + captured DOF coordinates.""" + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + T.array[:, 0, 0] = 5.0 * T.coords[:, 0] + 3.0 + T_pre = np.asarray(T.array[...]).copy() + + path = str(tmp_path / "run.snap.h5") + model.save_state(file=path) + + # Scribble T, then re-load via the legacy read_timestep API — + # but pointed at the v1.1 wrapper. Format dispatch handles it. + T.array[...] = -99.0 + T.read_timestep(path, "T", 0) + + assert np.allclose(np.asarray(T.array[...]), T_pre, atol=1e-12), ( + f"read_timestep against v1.1 snapshot not bit-exact: " + f"max|d| = {float(np.max(np.abs(np.asarray(T.array[...]) - T_pre))):.3e}" + ) + + +def test_read_timestep_legacy_path_unchanged(tmp_path): + """Belt-and-braces: pointing read_timestep at a legacy + write_timestep file still uses the legacy code path (not the v1.1 + bridge).""" + import underworld3 as uw + + uw, model, mesh, T, V = _fresh_model_mesh_and_vars() + T.array[:, 0, 0] = 7.0 + T_pre = np.asarray(T.array[...]).copy() + + # Write via legacy write_timestep — different format, same content. + mesh.write_timestep( + "legacy", index=0, outputPath=str(tmp_path), meshVars=[T] + ) + + T.array[...] 
= -99.0 + T.read_timestep("legacy", "T", 0, outputPath=str(tmp_path)) + assert np.allclose(np.asarray(T.array[...]), T_pre, atol=1e-12) + + def test_swarm_internals_not_in_sidecar(tmp_path): """The PETSc-internal swarm variables (DMSwarmPIC_coor, DMSwarm_X0, DMSwarm_remeshed) are filtered at capture — they're not in the From eba7500ac415aa9d1affa5797d1d33cd8ef2e261 Mon Sep 17 00:00:00 2001 From: lmoresi Date: Wed, 20 May 2026 20:50:41 +1000 Subject: [PATCH 7/7] =?UTF-8?q?feat(snapshot,=20phase=206):=20per-rank=20s?= =?UTF-8?q?warm=20sidecars=20=E2=80=94=20parallel-correct=20on-disk?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes the last production-readiness gate on the disk path. Swarm sidecars are now per-rank files: each rank writes its own {swarm_safe}.swarm.rank{R:04d}of{S:04d}.h5, the wrapper records the naming pattern + rank count, and on restore each rank opens its matching file. Same shape that #146 uses internally for mesh-var collectives via PETSc, just expressed as per-rank h5py files rather than a single parallel-HDF5 file (avoids the h5py-mpi build dependency). Contract: same-rank-count restart only. Rank-count mismatch on read raises clearly with a pointer to mesh.write_timestep for the flexible-restart path. Each sidecar carries its writer's (mpi_rank, mpi_size_at_write) attrs so a wrong-rank-file load also fails cleanly. Wrapper layout addition: - /swarms/@filled_by = "phase3b+phase6" - /swarms/@mpi_size_at_write - /swarms/{name}/@sidecar_pattern (template with {rank}/{size}) - /swarms/{name}/@num_particles_global (gathered across ranks via MPI.SUM at write time) Phase 6 implementation deliberately keeps the mesh-var collective path #146 already provides — no changes to mesh-side bulk write/read. Only the swarm-sidecar layer is rebuilt for per-rank operation. Tests: - 23 single-rank tests in test_0010 (unchanged count; updated the two that asserted the old single-file naming). 
- New ptest_0010_snapshot_disk.py exercises -np 1/3/4: wrapper + per-rank sidecars present, particle count preserved, swarm round- trip exact (gather + sort by per-particle gid), tracker state restored, T mesh-var DOFs preserved (via partition-invariant min/max scalars — gathered DOF tables include partition-boundary duplicates that resist direct comparison). - mpi_runner.sh registers the new ptest at -np 1 / 3 / 4. Final tally: 77 single-rank tests green; parallel ptest_0007 (in-memory) and ptest_0010 (on-disk) both PASS at np 1/3/4. Production verdict on the disk path: matches the in-memory path — correct serial, parallel, and through real solvers. The full v1.1 plan from project_snapshot_v1_1_disk_format.md is now landed: phases 1, 2, 3a, 3b, 4 (read_timestep dispatch), 5 (unified save_state/load_state API), 6 (parallel sidecars). Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/checkpoint/disk_snapshot.py | 108 +++++++++--- tests/parallel/mpi_runner.sh | 7 + tests/parallel/ptest_0010_snapshot_disk.py | 178 ++++++++++++++++++++ tests/test_0010_snapshot_disk_format.py | 25 ++- 4 files changed, 291 insertions(+), 27 deletions(-) create mode 100644 tests/parallel/ptest_0010_snapshot_disk.py diff --git a/src/underworld3/checkpoint/disk_snapshot.py b/src/underworld3/checkpoint/disk_snapshot.py index 1b4954db..91716fce 100644 --- a/src/underworld3/checkpoint/disk_snapshot.py +++ b/src/underworld3/checkpoint/disk_snapshot.py @@ -371,38 +371,59 @@ def write_snapshot(model, path: str) -> str: bg = ps_group.create_group(key) _write_state_bearer_to_group(bg, obj) - # Phase 3b: swarms — one sidecar file per swarm in the bulk dir, - # referenced from /swarms/{swarm_safe}/ in the wrapper. + # Phase 3b + 6: swarms — per-rank sidecar files in the bulk dir, + # referenced from /swarms/{swarm_safe}/ via a sidecar_pattern attr. 
+ # Every rank writes its own sidecar; all-ranks-participate so the + # bulk dir is complete before the wrapper records the layout. + rank = int(uw.mpi.rank) + size = int(uw.mpi.size) swarm_records: list[dict] = [] for swarm in list(model._swarms.values()): swarm_safe = _swarm_safe_name(swarm) - sidecar_name = _swarm_sidecar_filename(swarm_safe) + sidecar_name = _swarm_sidecar_filename(swarm_safe, rank, size) sidecar_path = os.path.join(bulk_dir, sidecar_name) - # h5py-direct write (single-rank in this phase; MPI is phase 6). - with uw.selective_ranks(0) as rank0: - if rank0: - rec = _write_swarm_to_sidecar(swarm, sidecar_path) - rec["safe_name"] = swarm_safe - rec["external_file"] = sidecar_name - swarm_records.append(rec) - uw.mpi.barrier() + rec = _write_swarm_to_sidecar(swarm, sidecar_path) + rec["safe_name"] = swarm_safe + rec["sidecar_pattern"] = _swarm_sidecar_pattern(swarm_safe) + swarm_records.append(rec) + uw.mpi.barrier() + # Wrapper update is rank-0-only; gather global counts across + # ranks first so the wrapper carries a complete inventory. 
if swarm_records: + try: + from mpi4py import MPI + + comm = MPI.COMM_WORLD + global_counts = { + rec["safe_name"]: int( + comm.allreduce(rec["num_particles_local"], op=MPI.SUM) + ) + for rec in swarm_records + } + except ImportError: + global_counts = { + rec["safe_name"]: int(rec["num_particles_local"]) + for rec in swarm_records + } + with uw.selective_ranks(0) as rank0: if rank0: with h5py.File(path, "a") as f: sw_group = f[_GROUP_SWARMS] - sw_group.attrs["filled_by"] = "phase3b" + sw_group.attrs["filled_by"] = "phase3b+phase6" + sw_group.attrs["mpi_size_at_write"] = size for rec in swarm_records: g = sw_group.create_group(rec["safe_name"]) g.attrs["mesh_name"] = rec["mesh_name"] - g.attrs["num_particles_local"] = rec[ - "num_particles_local" + g.attrs["num_particles_global"] = global_counts[ + rec["safe_name"] ] g.attrs["population_generation"] = rec[ "population_generation" ] - g.attrs["external_file"] = rec["external_file"] + g.attrs["sidecar_pattern"] = rec["sidecar_pattern"] + g.attrs["mpi_size_at_write"] = size vars_g = g.create_group("variables") for var_rec in rec["vars"]: v = vars_g.create_group(_sanitise(var_rec["name"])) @@ -486,12 +507,25 @@ def read_snapshot(model, path: str) -> None: ) _read_state_bearer_into(ps_group[key], obj) - # Phase 3b: restore swarms from sidecars. + # Phase 3b + 6: restore swarms from per-rank sidecars. if _GROUP_SWARMS in f: sw_group = f[_GROUP_SWARMS] + if "mpi_size_at_write" in sw_group.attrs: + write_size = int(sw_group.attrs["mpi_size_at_write"]) + if write_size != int(uw.mpi.size): + raise ValueError( + f"snapshot at {path} was written on {write_size} " + f"MPI rank(s); this run is on {uw.mpi.size}. " + f"Cross-rank-count restore is not supported by the " + f"snapshot mechanism — use mesh.write_timestep for " + f"that case." 
+ ) + swarms_by_safe = { _swarm_safe_name(s): s for s in list(model._swarms.values()) } + rank = int(uw.mpi.rank) + size = int(uw.mpi.size) for swarm_safe in sw_group.keys(): g = sw_group[swarm_safe] swarm = swarms_by_safe.get(swarm_safe) @@ -500,9 +534,11 @@ def read_snapshot(model, path: str) -> None: f"snapshot at {path} contains swarm {swarm_safe!r} " f"that is not registered on this model" ) - external_file = str(g.attrs["external_file"]) + # Resolve the per-rank sidecar name from the pattern. + pattern = str(g.attrs["sidecar_pattern"]) + sidecar_name = pattern.format(rank=rank, size=size) _read_swarm_from_sidecar( - swarm, os.path.join(bulk_dir, external_file) + swarm, os.path.join(bulk_dir, sidecar_name) ) @@ -527,8 +563,20 @@ def _swarm_safe_name(swarm) -> str: return _sanitise(raw) -def _swarm_sidecar_filename(swarm_safe: str) -> str: - return f"{swarm_safe}.swarm.h5" +def _swarm_sidecar_filename(swarm_safe: str, rank: int, size: int) -> str: + """Per-rank sidecar (phase 6 — parallel-safe). + + Each rank writes its own swarm sidecar; restoring requires the + same rank count. The filename carries both the writer's rank and + the total rank count so each file is self-describing. + """ + return f"{swarm_safe}.swarm.rank{rank:04d}of{size:04d}.h5" + + +def _swarm_sidecar_pattern(swarm_safe: str) -> str: + """Pattern stored in the wrapper for readers to fill in their own + (rank, size) when locating their sidecar.""" + return f"{swarm_safe}.swarm.rank{{rank:04d}}of{{size:04d}}.h5" def _write_swarm_to_sidecar(swarm, sidecar_path: str) -> dict: @@ -553,6 +601,11 @@ def _write_swarm_to_sidecar(swarm, sidecar_path: str) -> dict: f.attrs["dim"] = int(swarm.dim) f.attrs["mesh_name"] = str(swarm.mesh.name) f.attrs["population_generation"] = int(swarm._population_generation) + # Parallel-write provenance — each per-rank sidecar carries + # its writer's identity so the reader can sanity-check it + # opened the right file. 
+ f.attrs["mpi_rank"] = int(uw.mpi.rank) + f.attrs["mpi_size_at_write"] = int(uw.mpi.size) f.create_dataset("coordinates", data=coords) @@ -588,11 +641,26 @@ def _read_swarm_from_sidecar(swarm, sidecar_path: str) -> None: with h5py.File(sidecar_path, "r") as f: saved_coords = np.asarray(f["coordinates"][...]) captured_mesh_name = str(f.attrs.get("mesh_name", "")) + captured_rank = int(f.attrs.get("mpi_rank", 0)) + captured_size = int(f.attrs.get("mpi_size_at_write", 1)) var_data: dict[str, np.ndarray] = {} if "variables" in f: for name in f["variables"].keys(): var_data[name] = np.asarray(f["variables"][name][...]) + if captured_size != uw.mpi.size: + raise ValueError( + f"sidecar at {sidecar_path}: was written on " + f"{captured_size} MPI rank(s); this run is on {uw.mpi.size}. " + f"Cross-rank-count snapshot restore is out of scope — restart " + f"on the same rank count or use mesh.write_timestep for the " + f"flexible-restart path." + ) + if captured_rank != uw.mpi.rank: + raise ValueError( + f"sidecar at {sidecar_path}: written by rank {captured_rank}; " + f"this is rank {uw.mpi.rank}. Wrong per-rank sidecar opened." 
+ ) if captured_mesh_name and captured_mesh_name != swarm.mesh.name: raise ValueError( f"sidecar at {sidecar_path}: parent mesh was " diff --git a/tests/parallel/mpi_runner.sh b/tests/parallel/mpi_runner.sh index 3b479dbe..1c74260f 100755 --- a/tests/parallel/mpi_runner.sh +++ b/tests/parallel/mpi_runner.sh @@ -31,3 +31,10 @@ echo "ptest 0007 snapshot in-memory -np 3 (uneven partition)" mpirun -np 3 $PYTHON ./ptest_0007_snapshot_inmemory.py echo "ptest 0007 snapshot in-memory -np 4" mpirun -np 4 $PYTHON ./ptest_0007_snapshot_inmemory.py + +echo "ptest 0010 snapshot on-disk -np 1" +mpirun -np 1 $PYTHON ./ptest_0010_snapshot_disk.py +echo "ptest 0010 snapshot on-disk -np 3 (uneven)" +mpirun -np 3 $PYTHON ./ptest_0010_snapshot_disk.py +echo "ptest 0010 snapshot on-disk -np 4" +mpirun -np 4 $PYTHON ./ptest_0010_snapshot_disk.py diff --git a/tests/parallel/ptest_0010_snapshot_disk.py b/tests/parallel/ptest_0010_snapshot_disk.py new file mode 100644 index 00000000..b9682313 --- /dev/null +++ b/tests/parallel/ptest_0010_snapshot_disk.py @@ -0,0 +1,178 @@ +"""Parallel (MPI) test of the on-disk snapshot path (v1.1). + +Phase 6 of the snapshot toolkit: per-rank swarm sidecars. The mesh ++ mesh-variable disk path is already parallel-correct via #146's +PETSc-collective HDF5 viewer; the swarm sidecar layer needs its +own per-rank file per swarm. This ptest exercises both together at +multi-rank. + +Run (4 ranks exercises cross-rank distribution of swarm particles): + + cd tests/parallel + mpirun -np 4 python ./ptest_0010_snapshot_disk.py + +Asserts (collective, checked on rank 0): + + 1. Disk write produces one wrapper file + one swarm sidecar per + rank in the bulk dir, each with the rank+size in its filename. + 2. Each rank's sidecar carries the writing rank's local-particle + state (verified by per-rank attrs on the sidecar). + 3. 
Round-trip is exact: scribble all variables + swarm coords + + swarm-var data, model.load_state(file=...), gathered (gid, x, y, + material) tables sorted by gid are np.array_equal. +""" + +import os + +import numpy as np +import sympy +from mpi4py import MPI + +import underworld3 as uw + +comm = MPI.COMM_WORLD +rank = uw.mpi.rank +size = uw.mpi.size + + +def build(): + uw.reset_default_model() + uw.use_strict_units(False) + uw.use_nondimensional_scaling(False) + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(4.0, 1.0), cellSize=1.0 / 6.0 + ) + x_sym, y_sym = mesh.X + V_fn = sympy.Matrix([[-(y_sym - 0.5), 0.25 * (x_sym - 2.0)]]).T + + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + T.array[:, 0, 0] = mesh.X.coords[:, 0] - mesh.X.coords[:, 1] + + swarm = uw.swarm.Swarm(mesh) + gid = swarm.add_variable("gid", 1, dtype=float) + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + + local_n = swarm.dm.getLocalSize() + counts = comm.allgather(local_n) + offset = int(np.sum(counts[:rank])) + gid.data[:, 0] = offset + np.arange(local_n, dtype=float) + material.data[:, 0] = swarm._particle_coordinates.data[:, 0] + + model.tracker.time = 1.5 + model.tracker.step = 7 + return uw, model, mesh, V_fn, T, swarm, gid, material + + +def global_sorted_state(T, swarm, gid, material): + """Gather (gid, x, y, material, T-value-by-coord-bin) across ranks + + sort by gid → order/rank-independent canonical view.""" + g = gid.data[:, 0].copy() + coords = swarm._particle_coordinates.data.copy() + m = material.data[:, 0].copy() + local = np.column_stack([g, coords[:, 0], coords[:, 1], m]) + gathered = comm.allgather(local) + full = np.vstack([a for a in gathered if a.size]) if any( + a.size for a in gathered + ) else np.empty((0, 4)) + order = np.argsort(full[:, 0], kind="stable") + swarm_state = full[order] + + # T round-trip check: gather partition-invariant scalars + # (max, 
sum) rather than the full (coord, value) table — DOFs at + # partition boundaries are visible to multiple ranks and would + # appear duplicated/reordered in a gathered table, even though + # the underlying data is bit-exact. + t_arr = np.asarray(T.array[...]).reshape(-1) + t_max = comm.allreduce(float(t_arr.max()) if t_arr.size else -np.inf, + op=MPI.MAX) + t_min = comm.allreduce(float(t_arr.min()) if t_arr.size else np.inf, + op=MPI.MIN) + # bit-exact float sum across ranks is non-deterministic in general + # (non-associative); use min/max as bit-exact invariants instead. + return swarm_state, (t_max, t_min) + + +def main(): + import tempfile + + uw, model, mesh, V_fn, T, swarm, gid, material = build() + pre_swarm, pre_T = global_sorted_state(T, swarm, gid, material) + pre_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) + + # Use a shared temp dir reachable from every rank + if rank == 0: + tmp = tempfile.mkdtemp(prefix="uw3_ptest_0010_") + else: + tmp = None + tmp = comm.bcast(tmp, root=0) + + wrapper = os.path.join(tmp, "parrun.snap.h5") + model.save_state(file=wrapper) + comm.Barrier() + + # Check files on rank 0 + files_ok = True + if rank == 0: + bulk = os.path.join(tmp, "parrun.snap.bulk") + files = sorted(os.listdir(bulk)) + per_rank = [f for f in files if ".swarm.rank" in f] + # Expect one swarm sidecar per rank + if len(per_rank) != size: + print( + f"!! expected {size} swarm sidecars, got {len(per_rank)}: " + f"{per_rank}", + flush=True, + ) + files_ok = False + else: + print( + f" swarm sidecars OK: {per_rank[0]} ... ({size} total)", + flush=True, + ) + + # Scribble everything + T.array[...] = -99.0 + coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape((-1, swarm.dim)) + coord_field[...] = -99.0 + swarm.dm.restoreField("DMSwarmPIC_coor") + material.data[...] 
= -99.0 + model.tracker.time = -1.0 + model.tracker.step = -1 + + model.load_state(wrapper) + post_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) + post_swarm, post_T = global_sorted_state(T, swarm, gid, material) + + swarm_ok = np.array_equal(pre_swarm, post_swarm) + # T is checked via partition-invariant min/max scalars (see note + # in global_sorted_state — gathered DOFs include partition- + # boundary duplicates that resist a global-table comparison). + T_ok = (pre_T == post_T) + count_ok = pre_count == post_count + tracker_ok = (model.tracker.time == 1.5 and model.tracker.step == 7) + + if rank == 0: + print(f"[ranks={size}] particles total = {pre_count}", flush=True) + print(f" P1 disk wrapper + per-rank sidecars present: {files_ok}", + flush=True) + print(f" P2 particle count preserved: {count_ok}", + flush=True) + print(f" P3 swarm (coords + gid + material) exact: {swarm_ok}", + flush=True) + print(f" P4 T (mesh-variable DOFs) exact: {T_ok}", + flush=True) + print(f" P5 tracker state restored: {tracker_ok}", + flush=True) + + assert files_ok + assert count_ok + assert swarm_ok + assert T_ok + assert tracker_ok + print(f"[ranks={size}] PASS", flush=True) + + +if __name__ == "__main__": + main() diff --git a/tests/test_0010_snapshot_disk_format.py b/tests/test_0010_snapshot_disk_format.py index 18006d6b..56609dfe 100644 --- a/tests/test_0010_snapshot_disk_format.py +++ b/tests/test_0010_snapshot_disk_format.py @@ -470,8 +470,10 @@ def _fresh_model_mesh_swarm(): def test_swarm_sidecar_lands_in_bulk_dir(tmp_path): """A registered swarm produces its own h5 sidecar in the bulk dir - with predictable name; wrapper carries the external_file ref.""" + with predictable name (per-rank); wrapper records the sidecar + pattern + rank count.""" import os + import re import h5py import underworld3 as uw @@ -481,17 +483,26 @@ def test_swarm_sidecar_lands_in_bulk_dir(tmp_path): bulk = str(tmp_path / "run.snap.bulk") files = sorted(os.listdir(bulk)) - 
swarm_sidecars = [f for f in files if f.endswith(".swarm.h5")] - assert len(swarm_sidecars) == 1 + # Per-rank pattern: .swarm.rank{R:04d}of{S:04d}.h5 + swarm_sidecars = [ + f for f in files + if re.match(r".*\.swarm\.rank\d{4}of\d{4}\.h5$", f) + ] + assert len(swarm_sidecars) >= 1 # at least rank-0 file on single-rank with h5py.File(path, "r") as f: sw = f["swarms"] - assert sw.attrs["filled_by"] == "phase3b" + assert sw.attrs["filled_by"] == "phase3b+phase6" + assert "mpi_size_at_write" in sw.attrs assert len(sw.keys()) == 1 swarm_safe = list(sw.keys())[0] sg = sw[swarm_safe] - assert sg.attrs["external_file"] == swarm_sidecars[0] + # Wrapper records pattern (not a specific file) — readers + # fill in their own rank. + assert "sidecar_pattern" in sg.attrs + assert "{rank" in str(sg.attrs["sidecar_pattern"]) assert sg.attrs["mesh_name"] == mesh.name + assert int(sg.attrs["num_particles_global"]) > 0 # User-added variables surface in /swarms/{name}/variables/ assert "material" in sg["variables"] @@ -510,7 +521,7 @@ def test_swarm_sidecar_is_inspectable(tmp_path): bulk = str(tmp_path / "run.snap.bulk") swarm_sidecar = [ - f for f in os.listdir(bulk) if f.endswith(".swarm.h5") + f for f in os.listdir(bulk) if ".swarm.rank" in f ][0] with h5py.File(os.path.join(bulk, swarm_sidecar), "r") as f: @@ -651,7 +662,7 @@ def test_swarm_internals_not_in_sidecar(tmp_path): model.save_state(file=path) bulk = str(tmp_path / "run.snap.bulk") - sidecar = [f for f in os.listdir(bulk) if f.endswith(".swarm.h5")][0] + sidecar = [f for f in os.listdir(bulk) if ".swarm.rank" in f][0] with h5py.File(os.path.join(bulk, sidecar), "r") as f: var_names = set(f["variables"].keys()) assert "material" in var_names