Skip to content
Merged
117 changes: 84 additions & 33 deletions docs/advanced/snapshot-restore.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ title: "State Snapshots & Restore"

## Overview

`Model.snapshot()` and `Model.restore()` are a *stash for timesteps* —
`Model.save_state()` and `Model.load_state()` are a *stash for timesteps* —
a quick "hold that thought, I might need to come back" mechanism for
time-stepping code. Take a snapshot, try a step, and if you don't like
the result, restore and try again. The system is put back exactly as it
Expand All @@ -23,11 +23,13 @@ Typical uses:
- **Multi-stage time integration** (RK-style) — restore to the start
of a step between stages.

This is intentionally *not* archival checkpointing. It is fast,
in-memory, and meant to be used freely within a run. For long-term,
on-disk restart files, use the existing `mesh.write_timestep()` /
`read_timestep()` path, which is unchanged and serves a different
purpose.
The same entry points serve two storage modes — in-memory for fast
intra-run stash, and on-disk for persistent restart / postprocessing.
You pick by giving (or not giving) a ``file=``. The existing
``mesh.write_timestep()`` / ``read_timestep()`` and
``mesh.write_checkpoint()`` / ``MeshVariable.read_checkpoint()``
paths are unchanged and continue to serve their existing roles (see
"Choosing between paths" at the bottom).

## The API

Expand All @@ -38,20 +40,37 @@ model = uw.get_default_model()

# ... set up mesh, variables, swarm, solvers, step a few times ...

token = model.snapshot() # capture everything, return a token
# In-memory: the "stash for timesteps" use case.
token = model.save_state() # capture everything, return a token

# ... take a speculative step you might regret ...

model.restore(token) # put everything back exactly
model.load_state(token) # put everything back exactly
```

`snapshot()` returns a plain in-memory token. You can hold several at
once and restore any of them. `restore()` returns the model to the
exact state at the moment that token was captured.
To persist on disk instead, pass a path. The same call captures the
same state — only the storage layer differs:

```python
# On-disk: persistent snapshot for restart, postprocessing,
# bisection studies, or transferring a run to another machine.
model.save_state(file="step42.snap.h5")
# ... later, or in a fresh process with the same model set up ...
model.load_state("step42.snap.h5")
```

``save_state()`` returns a :class:`Snapshot` token when called
without ``file``; you can hold several at once and restore any of
them. With ``file=...`` it writes a self-contained on-disk snapshot
and returns the path.

``load_state()`` takes either a token or a path string — it figures
out which from the argument type, so the same call works for both
storage modes.

## What is captured

You do not enumerate anything — `snapshot()` captures the full state
You do not enumerate anything — `save_state()` captures the full state
of the model automatically:

- mesh coordinates,
Expand All @@ -69,13 +88,13 @@ and restore — that is exactly the situation restore exists for.

A subtle trap in time-stepping scripts: your loop counter and
simulation time usually live in plain Python variables, and
`restore()` has no way to know about them.
`load_state()` has no way to know about them.

```python
model_time = 0.0
token = model.snapshot()
token = model.save_state()
model_time = 5.0 # advance
model.restore(token)
model.load_state(token)
# model_time is still 5.0 — restore cannot reach a local variable
```

Expand All @@ -87,12 +106,12 @@ restored.
model.tracker.time = 0.0
model.tracker.step = 0

token = model.snapshot()
token = model.save_state()

model.tracker.time = 5.0
model.tracker.step = 100

model.restore(token)
model.load_state(token)

model.tracker.time # 0.0 — reverted automatically
model.tracker.step # 0 — reverted automatically
Expand All @@ -109,7 +128,7 @@ model.tracker.energy_history = np.zeros(3)
These now travel with every snapshot and revert on every restore — no
extra code, no special handling in your solvers. Using the tracker is
optional; solvers do not depend on it. It is simply the place to keep
the things you want `restore()` to manage.
the things you want `load_state()` to manage.

```{note} Reserved name
`state` is reserved on the tracker (it is the snapshot mechanism's own
Expand Down Expand Up @@ -139,7 +158,7 @@ cfl_limit = mesh.get_min_radius()
dt = 0.5

while model.tracker.time < t_end:
token = model.snapshot()
token = model.save_state()
coords_before = swarm._particle_coordinates.data.copy()

# Speculative step at the current Δt.
Expand All @@ -153,7 +172,7 @@ while model.tracker.time < t_end:

if moved > cfl_limit:
# Too big — discard and retry with a smaller Δt.
model.restore(token)
model.load_state(token)
dt *= 0.5
continue

Expand All @@ -180,24 +199,56 @@ attempt starts from precisely where the failed one began.
```

```{warning} Limitations
- **In-memory only.** Snapshots live in process memory and are not
written to disk; they do not survive the process exiting. They are
also a full copy of model state — holding many large snapshots at
once costs memory.
- **Same rank count.** A snapshot taken on *N* MPI ranks is restored
on *N* ranks. Changing the rank count is not supported by this
mechanism (use the `write_timestep` restart path for that).
- **In-memory tokens are intra-run only.** Tokens returned by
``save_state()`` (no ``file``) live in process memory and do not
survive the process exiting. They are also a full copy of model
state — holding many large tokens at once costs memory. To persist
across runs, use ``save_state(file=…)`` instead.
- **Same rank count.** A snapshot written on *N* MPI ranks must be
read on *N* ranks. Cross-rank-count restart is not supported by
this mechanism (use the ``mesh.write_timestep`` path for that).
- **No mesh adaptation across a snapshot.** If the mesh is adapted
between snapshot and restore, restore refuses with a clear error
between save and load, ``load_state`` refuses with a clear error
rather than corrupting state.
- **Recovery vs. a never-snapshotted run** is bit-exact for the
*discarded-step* guarantee above. Continuing after a restore that
ran a real solver may differ from a run that never snapshotted by a
small amount within solver tolerance — restore resyncs solver
fields rather than reproducing their exact internal buffers. This
does not affect the correctness of backtracking.
*discarded-step* guarantee above. Continuing after a load_state
that ran a real solver may differ from a run that never
snapshotted by a small amount within solver tolerance — load_state
resyncs solver fields rather than reproducing their exact internal
buffers. This does not affect the correctness of backtracking.
```

## On-disk file layout (when ``save_state(file=…)`` is used)

A disk snapshot is **two artifacts**: a small wrapper HDF5 file plus
a sibling ``.bulk/`` directory holding the bulk data. They are a
unit — move them together. The wrapper is rich in metadata and
inspectable with standard h5 tools without UW3 in the loop:

```text
my_run.snap.h5 (~tens of KB; metadata, group structure)
my_run.snap.bulk/ (per-mesh + per-swarm sidecars)
{mesh}.mesh.00000.h5
{mesh}.{var}.00000.h5 (one per mesh-variable)
{swarm}.swarm.h5 (one per swarm)
```
Comment on lines +228 to +234

A quick ``h5ls -v my_run.snap.h5/metadata`` shows you the run name,
schema version, simulation time, step, dimensions, MPI rank count at
write, and inventories of meshes / swarms / variables / state-bearer
classes. For a Python-side summary use
``uw.checkpoint.inspect_snapshot(path)``.

## Choosing between paths

| Need | Use |
|---|---|
| Backtrack a few steps inside a running script (RK staging, adaptive Δt, predictor–corrector probes) | ``save_state()`` → token |
| Persist whole-model state across runs (crash recovery, bisection studies, full restart) | ``save_state(file=…)`` / ``load_state(file=…)`` |
| Restart from a previous run on a *different* rank count or remap onto a *different* resolution | ``mesh.write_timestep()`` / ``MeshVariable.read_timestep()`` (KDTree remap) |
| Efficient same-rank restart writing only specific variables for postprocessing | ``mesh.write_checkpoint()`` / ``MeshVariable.read_checkpoint()`` (PETSc DMPlex per-variable) |
| Visualisation for ParaView (XDMF + per-step HDF5) | ``mesh.write_timestep(create_xdmf=True)`` |

## Related

- [Parallel-Safe Scripting](parallel-computing.md) — MPI patterns;
Expand Down
18 changes: 18 additions & 0 deletions src/underworld3/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@
)
from .state import Snapshottable, SnapshottableState
from .tracker import ModelTracker, TrackerState
from .disk_snapshot import (
DISK_SNAPSHOT_SCHEMA_VERSION,
extract_var_via_bridge,
inspect_snapshot,
is_snapshot_wrapper,
read_snapshot,
read_snapshot_metadata,
write_snapshot,
write_snapshot_skeleton,
)

__all__ = [
"CheckpointBackend",
Expand All @@ -40,4 +50,12 @@
"SnapshottableState",
"ModelTracker",
"TrackerState",
"DISK_SNAPSHOT_SCHEMA_VERSION",
"extract_var_via_bridge",
"inspect_snapshot",
"is_snapshot_wrapper",
"read_snapshot",
"read_snapshot_metadata",
"write_snapshot",
"write_snapshot_skeleton",
]
Loading
Loading