From 8946131c65593bab0023c8e0f9aab33d0f6ef116 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 20 May 2026 16:51:41 +0200 Subject: [PATCH] Add Stage 1 checkpoint reuse boundary --- changelog.d/1074.added | 1 + docs/engineering/skills/README.md | 2 + docs/engineering/stages/build_datasets.md | 87 +++ modal_app/data_build.py | 677 ++++++++++++++++-- .../build_datasets/__init__.py | 33 + .../build_datasets/checkpoints.py | 255 +++++++ .../build_datasets/contracts.py | 2 + .../build_datasets/coordinator.py | 139 +++- policyengine_us_data/build_datasets/rerun.py | 292 ++++++++ .../build_datasets/results.py | 26 +- policyengine_us_data/build_datasets/status.py | 4 +- .../stage_contracts/dataset_build.py | 2 + tests/unit/test_build_dataset_checkpoints.py | 182 +++++ tests/unit/test_build_dataset_rerun.py | 214 ++++++ tests/unit/test_build_dataset_status_store.py | 71 +- .../unit/test_dataset_build_stage_contract.py | 26 + tests/unit/test_modal_data_build.py | 371 +++++++++- 17 files changed, 2309 insertions(+), 75 deletions(-) create mode 100644 changelog.d/1074.added create mode 100644 docs/engineering/stages/build_datasets.md create mode 100644 policyengine_us_data/build_datasets/checkpoints.py create mode 100644 policyengine_us_data/build_datasets/rerun.py create mode 100644 tests/unit/test_build_dataset_checkpoints.py create mode 100644 tests/unit/test_build_dataset_rerun.py diff --git a/changelog.d/1074.added b/changelog.d/1074.added new file mode 100644 index 000000000..d1801445d --- /dev/null +++ b/changelog.d/1074.added @@ -0,0 +1 @@ +Added Stage 1 checkpoint adapter and rerun reuse planning boundaries. diff --git a/docs/engineering/skills/README.md b/docs/engineering/skills/README.md index 080210f08..4fc8dbfd7 100644 --- a/docs/engineering/skills/README.md +++ b/docs/engineering/skills/README.md @@ -29,6 +29,8 @@ pipeline path. Current stage guides: +- `build_datasets.md`: Stage 1 build-dataset identity, checkpoint reuse, + conditional running, and contract metadata guidance. - `build_outputs.md`: Stage 4 output-build library boundaries and test expectations. - `release_promotion.md`: Stage 5 release candidate identity, validation-report diff --git a/docs/engineering/stages/build_datasets.md b/docs/engineering/stages/build_datasets.md new file mode 100644 index 000000000..218caf6d4 --- /dev/null +++ b/docs/engineering/stages/build_datasets.md @@ -0,0 +1,87 @@ +# Stage 1: Build Datasets + +Stage 1 builds the public dataset artifacts consumed by later pipeline stages. +Its public status boundary is organized around the `1a_` through `1f_` +substeps, while the transitional Modal runtime still executes several +command-backed units inside some of those public substeps. + +## Rerun And Reuse Model + +Checkpoint reuse has two gates: + +- The semantic gate compares current `Stage1IdentityMaterial` with a persisted + identity from the checkpoint-scoped Stage 1 reuse manifest. +- The physical gate verifies that every expected checkpoint output exists and is + non-empty before a unit is restored. + +The physical checkpoint layout remains `/checkpoints/{branch}/{commit_sha}`. +The Stage 1 reuse manifest is adapter state in that same scope. Missing, +malformed, or unreadable manifest content must fail closed to recompute; it must +not authorize reuse by itself. + +Keep reuse explanations in the existing Stage 1 contract metadata under +`dataset_build_output.json -> metadata.stage_1_status.reuse_reasoning`. +That metadata should distinguish semantic identity results from physical +checkpoint availability, including missing prior identity, identity mismatch, +identity match, missing checkpoint output, empty checkpoint output, and restored +checkpoint output. + +## Identity Granularity + +`substep_id` is the public reporting group. It is not always the right durable +manifest lookup key, because transitional Stage 1 substeps can contain multiple +independently runnable command or script units. For example, raw-data download +and uprating both report through `1a_raw_data_download`, while the base dataset +substep can run several dataset builders. + +When persisting or looking up reuse identities for a command-backed unit, +`identity_key` is the stable execution identity key within the checkpoint scope. +It includes the public `substep_id` plus enough stable execution material to +distinguish the command or script and its expected reusable outputs. Keep +`substep_id` on the record for public status grouping. + +Do not key multiple manifest records only by `substep_id` unless the record +represents an intentionally aggregated identity for the whole public substep. +Otherwise, later units in the same substep can overwrite earlier units and make +future reruns recompute despite valid checkpoints. + +## Conditional Running + +Unit-level conditional running is the compatibility path while Stage 1 is still +command-backed: + +1. Build current identity material for the runnable unit. +2. Compare it with the previous manifest identity for that unit's identity key. +3. Consult physical checkpoints only when the semantic decision is `reuse`. +4. Restore and skip only that unit when both gates pass. +5. Recompute the unit and update the manifest only after successful output + restoration or successful checkpoint save. + +Public substep status should be aggregated from its unit results. A public +substep is fully `reused` only when every required unit in that substep was +reused. If any unit recomputes successfully, report the substep as completed +with reuse reasoning that explains the mixed path. + +Stage-level conditional running is the same idea one level higher. Stage 1 may +skip all builder execution only when every required unit for the requested run +flags has a matching semantic identity and valid physical checkpoint outputs. +Until the canonical Stage 1 coordinator owns whole-stage planning, do not infer +stage-level reuse from a single substep or unit record. + +## Documentation Expectations + +When changing Stage 1 identity material, checkpoint reuse decisions, artifact +outputs, substep aggregation, or contract metadata, keep the durable +documentation surface synchronized: + +- Update this guide when the Stage 1 rerun or checkpoint model changes. +- Update `docs/pipeline_map.yaml` and regenerate generated pipeline docs when + the stage graph, artifact names, or pipeline-node metadata change. +- Keep `dataset_build_output.json` metadata documentation aligned with the + status and reuse reasoning actually emitted by the Modal adapter. +- Put PR-specific migration rationale in the PR description, not in durable + docs or docstrings. + +Tests for Stage 1 reuse changes should cover missing and malformed manifests, +semantic mismatch, physical checkpoint miss or empty output, same-public-substep +units with distinct identity keys, and contract metadata explaining both gates. diff --git a/modal_app/data_build.py b/modal_app/data_build.py index f9c68b74c..ac839cee9 100644 --- a/modal_app/data_build.py +++ b/modal_app/data_build.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import functools import os import shutil import subprocess import sys import threading -from collections.abc import Mapping +from collections.abc import Callable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from datetime import datetime, timezone @@ -22,6 +24,7 @@ from modal_app.images import cpu_image as image # noqa: E402 from policyengine_us_data.__version__ import __version__ as DATA_PACKAGE_VERSION # noqa: E402 from policyengine_us_data.build_datasets import ( # noqa: E402 + CheckpointStore, CommandRunner, DatasetCommand, DatasetCommandError, @@ -30,6 +33,10 @@ DatasetBuildOutputContractBuilder, PipelineArtifactStager, Stage1Coordinator, + Stage1IdentityMaterial, + Stage1RerunPlanner, + Stage1ReuseDecision, + Stage1ReuseManifestRecord, Stage1StatusRecorder, stage_1_artifact_specs, stage_1_script_outputs, @@ -186,30 +193,20 @@ def get_current_commit() -> str: def get_checkpoint_path(branch: str, output_file: str) -> Path: """Get the checkpoint path for an output file, scoped by branch and commit.""" - commit = get_current_commit() - return Path(VOLUME_MOUNT) / branch / commit / Path(output_file).name + return _checkpoint_store(branch).checkpoint_path(output_file) def is_checkpointed(branch: str, output_file: str) -> bool: """Check if output file exists in checkpoint volume and is valid.""" - checkpoint_path = get_checkpoint_path(branch, output_file) - if checkpoint_path.exists(): - # Verify file is not empty/corrupted - if checkpoint_path.stat().st_size > 0: - return True - return False + return _checkpoint_store(branch).decision_for(output_file).action == "reuse" def restore_from_checkpoint(branch: str, output_file: str) -> bool: """Restore output file from checkpoint volume if it exists.""" - checkpoint_path = get_checkpoint_path(branch, output_file) - if checkpoint_path.exists() and checkpoint_path.stat().st_size > 0: - local_path = Path(output_file) - local_path.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(checkpoint_path, local_path) + restored = _checkpoint_store(branch).restore_output(output_file) + if restored: print(f"Restored from checkpoint: {output_file}") - return True - return False + return restored def save_checkpoint( @@ -218,25 +215,35 @@ def save_checkpoint( volume: modal.Volume, ) -> None: """Save output file to checkpoint volume.""" - local_path = Path(output_file) - if local_path.exists() and local_path.stat().st_size > 0: - checkpoint_path = get_checkpoint_path(branch, output_file) - checkpoint_path.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(local_path, checkpoint_path) - with _volume_lock: - volume.commit() + saved = _checkpoint_store(branch, volume).save_output(output_file) + if saved: print(f"Checkpointed: {output_file}") def cleanup_checkpoints(branch: str, volume: modal.Volume) -> None: """Delete checkpoints for this branch after successful completion.""" - branch_dir = Path(VOLUME_MOUNT) / branch - if branch_dir.exists(): - shutil.rmtree(branch_dir) - volume.commit() + cleaned = _checkpoint_store(branch, volume).cleanup_branch() + if cleaned: print(f"Cleaned up checkpoints for branch: {branch}") +def _checkpoint_store( + branch: str, + volume: modal.Volume | None = None, +) -> CheckpointStore: + def commit_volume() -> None: + if volume is not None: + with _volume_lock: + volume.commit() + + return CheckpointStore( + root=Path(VOLUME_MOUNT), + branch=branch, + commit_sha=get_current_commit(), + commit=commit_volume if volume is not None else None, + ) + + def run_script_logged( cmd: list, log_file: IO, @@ -369,6 +376,9 @@ def run_script_with_checkpoint( log_file: IO = None, checkpoint_stats: CheckpointStats | None = None, command_results: list[DatasetCommandResult] | None = None, + checkpoint_store: CheckpointStore | None = None, + reuse_decision: Stage1ReuseDecision | None = None, + identity_material: Stage1IdentityMaterial | None = None, ) -> str: """Run script if output not checkpointed, then checkpoint result. @@ -388,24 +398,54 @@ def run_script_with_checkpoint( if isinstance(output_files, str): output_files = [output_files] expected_count = len(output_files) + checkpoint_store = checkpoint_store or _checkpoint_store(branch, volume) + identity_material = identity_material or _script_identity_material( + script_path=script_path, + output_files=output_files, + branch=branch, + ) + reuse_decision = reuse_decision or _reuse_decision_for_material( + identity_material, + checkpoint_store=checkpoint_store, + run_id=(env or {}).get(RUN_ID_ENV), + ) + if reuse_decision.action == "blocked": + raise RuntimeError( + "Stage 1 checkpoint reuse is blocked for " + f"{script_path}: {reuse_decision.reason}" + ) # Check if ALL outputs are checkpointed - all_checkpointed = all(is_checkpointed(branch, f) for f in output_files) + checkpoint_decisions = ( + checkpoint_store.decisions_for(output_files) + if reuse_decision.action == "reuse" + else () + ) + all_checkpointed = reuse_decision.action == "reuse" and all( + decision.action == "reuse" for decision in checkpoint_decisions + ) if all_checkpointed: # Restore all files from checkpoint + checkpoint_store.restore_all_outputs(output_files) for output_file in output_files: - restore_from_checkpoint(branch, output_file) + print(f"Restored from checkpoint: {output_file}") print(f"Skipping {script_path} (restored from checkpoint)") if checkpoint_stats is not None: checkpoint_stats.record( expected_outputs=expected_count, valid_reused_outputs=expected_count, ) + _record_reuse_manifest( + checkpoint_store=checkpoint_store, + identity_material=identity_material, + reuse_decision=reuse_decision, + checkpoint_decisions=checkpoint_decisions, + ) return script_path missing_or_invalid = sum( - 1 for output_file in output_files if not is_checkpointed(branch, output_file) + 1 for decision in checkpoint_decisions if decision.action != "reuse" ) # Run the script @@ -419,7 +459,16 @@ def run_script_with_checkpoint( # Checkpoint all outputs for output_file in output_files: - save_checkpoint(branch, output_file, volume) + saved = checkpoint_store.save_output(output_file) + if saved: + print(f"Checkpointed: {output_file}") + saved_decisions = checkpoint_store.decisions_for(output_files) + _record_reuse_manifest( + checkpoint_store=checkpoint_store, + identity_material=identity_material, + reuse_decision=reuse_decision, + checkpoint_decisions=saved_decisions, + ) if checkpoint_stats is not None: checkpoint_stats.record( expected_outputs=expected_count, @@ -445,6 +494,446 @@ def _stage_base_artifact_paths(artifacts_dir: Path) -> tuple[Path, ...]: return tuple(paths) +def _stage_1_status_metadata(coordinator: Stage1Coordinator) -> dict[str, Any]: + substep_results = [ + result.to_dict() for result in getattr(coordinator, "results", ()) + ] + status_events = [ + event.to_dict() for event in getattr(coordinator, "status_events", ()) + ] + error_records = [ + error.to_dict() for error in getattr(coordinator, "error_records", ()) + ] + return { + "substep_results": substep_results, + "status_events": status_events, + "error_records": error_records, + "reuse_reasoning": _reuse_reasoning(substep_results), + } + + +def _reuse_reasoning( + substep_results: list[Mapping[str, Any]], +) -> list[dict[str, Any]]: + return [ + { + "substep_id": str(result.get("substep_id", "")), + "title": result.get("title"), + "status": result.get("status"), + "outcome_reason": _reuse_outcome_reason(result), + "reuse_decision": result.get("reuse_decision"), + "identity_decisions": _identity_reuse_decisions( + result.get("reuse_decision") + ), + "checkpoint_summary": _checkpoint_summary( + result.get("checkpoint_decisions", ()) + ), + } + for result in substep_results + if result.get("reuse_decision") or result.get("checkpoint_decisions") + ] + + +def _reuse_outcome_reason(result: Mapping[str, Any]) -> str: + status = result.get("status") + reuse_decision = result.get("reuse_decision") + action = _reuse_action(reuse_decision) + semantic_reason = _reuse_reason(reuse_decision) + reason_counts = _reuse_reason_counts(reuse_decision) + checkpoint_summary = _checkpoint_summary(result.get("checkpoint_decisions", ())) + + if status == "reused": + return ( + "Prior semantic identity matched and every expected checkpoint " + "output was present and non-empty." + ) + if action == "blocked": + return f"Semantic reuse blocked: {semantic_reason or 'unspecified'}." + if checkpoint_summary["missing"] or checkpoint_summary["empty"]: + return ( + "Recomputed because at least one expected checkpoint output was " + "missing or empty." + ) + if semantic_reason == "no_previous_identity": + return ( + "Recomputed because no prior semantic identity manifest record " + "existed for this Stage 1 execution unit." + ) + if semantic_reason == "identity_mismatch": + return ( + "Recomputed because persisted semantic identity did not match the " + "current Stage 1 execution-unit identity." + ) + if reason_counts: + return ( + "Recomputed with per-identity semantic reuse reasons: " + f"{_format_reason_counts(reason_counts)}." + ) + if action == "recompute": + return f"Recomputed because semantic reuse returned {semantic_reason}." + return "Recomputed because no reusable checkpoint decision was available." + + +def _reuse_action(reuse_decision: object) -> str | None: + if not isinstance(reuse_decision, Mapping): + return None + if isinstance(reuse_decision.get("action"), str): + return str(reuse_decision["action"]) + nested = reuse_decision.get("decisions") + if isinstance(nested, list): + actions = { + item.get("action") + for item in nested + if isinstance(item, Mapping) and isinstance(item.get("action"), str) + } + if len(actions) == 1: + return str(next(iter(actions))) + if "blocked" in actions: + return "blocked" + if "recompute" in actions: + return "recompute" + return None + + +def _reuse_reason(reuse_decision: object) -> str | None: + if not isinstance(reuse_decision, Mapping): + return None + if isinstance(reuse_decision.get("reason"), str): + return str(reuse_decision["reason"]) + nested = reuse_decision.get("decisions") + if isinstance(nested, list): + reasons = [ + str(item["reason"]) + for item in nested + if isinstance(item, Mapping) and isinstance(item.get("reason"), str) + ] + if reasons: + return ", ".join(sorted(set(reasons))) + return None + + +def _identity_reuse_decisions(reuse_decision: object) -> list[dict[str, Any]]: + decisions = _reuse_decision_records(reuse_decision) + identity_fields = ( + "substep_id", + "identity_key", + "action", + "reason", + "identity_fingerprint", + "artifact_namespace", + "run_id", + "rerun_id", + ) + return [ + { + field: _metadata_scalar(decision[field]) + for field in identity_fields + if field in decision + } + for decision in decisions + ] + + +def _reuse_reason_counts(reuse_decision: object) -> dict[str, int]: + counts: dict[str, int] = {} + for decision in _reuse_decision_records(reuse_decision): + reason = decision.get("reason") + if isinstance(reason, str) and reason: + counts[reason] = counts.get(reason, 0) + 1 + return counts + + +def _reuse_decision_records(reuse_decision: object) -> list[Mapping[str, Any]]: + if not isinstance(reuse_decision, Mapping): + return [] + nested = reuse_decision.get("decisions") + if isinstance(nested, list): + return [item for item in nested if isinstance(item, Mapping)] + return [reuse_decision] + + +def _format_reason_counts(counts: Mapping[str, int]) -> str: + return ", ".join(f"{reason}={counts[reason]}" for reason in sorted(counts)) + + +def _metadata_scalar(value: Any) -> Any: + if isinstance(value, str | int | float | bool) or value is None: + return value + return str(value) + + +def _checkpoint_summary(checkpoint_decisions: object) -> dict[str, int]: + summary = {"total": 0, "reusable": 0, "missing": 0, "empty": 0, "other": 0} + if not isinstance(checkpoint_decisions, (list, tuple)): + return summary + for decision in checkpoint_decisions: + if not isinstance(decision, Mapping): + continue + summary["total"] += 1 + action = decision.get("action") + reason = decision.get("reason") + if action == "reuse": + summary["reusable"] += 1 + elif reason == "missing": + summary["missing"] += 1 + elif reason == "empty": + summary["empty"] += 1 + else: + summary["other"] += 1 + return summary + + +def _identity_material_for_substep( + *, + substep_id: str, + identity_key: str, + inputs: Mapping[str, Any], + output_files: list[str], + branch: str, + artifact_specs: list[Mapping[str, Any]], +) -> Stage1IdentityMaterial: + return Stage1IdentityMaterial( + substep_id=substep_id, + identity_key=identity_key, + inputs=inputs, + parameters={"branch": branch, "outputs": output_files}, + artifact_specs=artifact_specs, + code_sha=get_current_commit(), + schema_version="stage-1-rerun-v1", + upstream_contract_fingerprints=(), + randomness={"checkpoint_scope": "branch_commit"}, + ) + + +def _stage_1_identity_key( + *, + substep_id: str, + execution_id: str, + output_files: Sequence[str], +) -> str: + output_names = ",".join( + sorted(Path(output_file).name for output_file in output_files) + ) + return f"{substep_id}:{execution_id}:{output_names}" + + +def _reuse_decision_for_material( + material: Stage1IdentityMaterial, + *, + checkpoint_store: CheckpointStore, + run_id: str | None = None, + rerun_id: str | None = None, +) -> Stage1ReuseDecision: + manifest = checkpoint_store.load_reuse_manifest() + planner = Stage1RerunPlanner(previous_identities=manifest.previous_identities()) + return planner.decide( + material, + run_id=run_id or os.environ.get(RUN_ID_ENV, "unknown"), + rerun_id=rerun_id, + ) + + +def _artifact_spec_identity( + *, + script_path: str | None = None, + output_files: list[str] | None = None, +) -> list[Mapping[str, Any]]: + output_names = {Path(output_file).name for output_file in output_files or ()} + return [ + { + "filename": spec.filename, + "logical_name": spec.logical_name, + "storage_path": spec.storage_path, + "substage_id": spec.substage_id, + } + for spec in stage_1_artifact_specs() + if ( + (script_path is not None and spec.script_path == script_path) + or spec.filename in output_names + ) + ] + + +def _raw_data_command_names() -> tuple[str, ...]: + return ( + "policyengine_us_data/storage/download_prerequisites.py", + "make database", + ) + + +def _raw_data_outputs() -> tuple[str, ...]: + return ("policyengine_us_data/storage/calibration/policy_data.db",) + + +def _raw_data_identity_material(*, branch: str) -> Stage1IdentityMaterial: + output_files = list(_raw_data_outputs()) + return _identity_material_for_substep( + substep_id="1a_raw_data_download", + identity_key=_stage_1_identity_key( + substep_id="1a_raw_data_download", + execution_id="command_group:raw_data_download", + output_files=output_files, + ), + inputs={"commands": list(_raw_data_command_names())}, + output_files=output_files, + branch=branch, + artifact_specs=_artifact_spec_identity(output_files=output_files), + ) + + +def _raw_data_reuse_decision( + *, + branch: str, + checkpoint_store: CheckpointStore, + run_id: str | None = None, + rerun_id: str | None = None, +) -> Stage1ReuseDecision: + material = _raw_data_identity_material(branch=branch) + return _reuse_decision_for_material( + material, + checkpoint_store=checkpoint_store, + run_id=run_id, + rerun_id=rerun_id, + ) + + +def run_command_group_with_checkpoint( + *, + substep_id: str, + output_files: tuple[str, ...], + branch: str, + volume: modal.Volume, + action: Callable[[], None], + checkpoint_stats: CheckpointStats | None = None, + checkpoint_store: CheckpointStore | None = None, + reuse_decision: Stage1ReuseDecision | None = None, + identity_material: Stage1IdentityMaterial | None = None, +) -> str: + """Run a non-script command group behind the checkpoint/reuse gate.""" + + expected_count = len(output_files) + checkpoint_store = checkpoint_store or _checkpoint_store(branch, volume) + identity_material = identity_material or _identity_material_for_substep( + substep_id=substep_id, + identity_key=_stage_1_identity_key( + substep_id=substep_id, + execution_id="command_group", + output_files=output_files, + ), + inputs={}, + output_files=list(output_files), + branch=branch, + artifact_specs=_artifact_spec_identity(output_files=list(output_files)), + ) + reuse_decision = reuse_decision or _reuse_decision_for_material( + identity_material, + checkpoint_store=checkpoint_store, + ) + if reuse_decision.action == "blocked": + raise RuntimeError( + "Stage 1 checkpoint reuse is blocked for " + f"{substep_id}: {reuse_decision.reason}" + ) + + checkpoint_decisions = ( + checkpoint_store.decisions_for(output_files) + if reuse_decision.action == "reuse" + else () + ) + all_checkpointed = reuse_decision.action == "reuse" and all( + decision.action == "reuse" for decision in checkpoint_decisions + ) + if all_checkpointed: + checkpoint_store.restore_all_outputs(output_files) + for output_file in output_files: + print(f"Restored from checkpoint: {output_file}") + print(f"Skipping {substep_id} (restored from checkpoint)") + if checkpoint_stats is not None: + checkpoint_stats.record( + expected_outputs=expected_count, + valid_reused_outputs=expected_count, + ) + _record_reuse_manifest( + checkpoint_store=checkpoint_store, + identity_material=identity_material, + reuse_decision=reuse_decision, + checkpoint_decisions=checkpoint_decisions, + ) + return substep_id + + missing_or_invalid = sum( + 1 for decision in checkpoint_decisions if decision.action != "reuse" + ) + action() + for output_file in output_files: + saved = checkpoint_store.save_output(output_file) + if saved: + print(f"Checkpointed: {output_file}") + saved_decisions = checkpoint_store.decisions_for(output_files) + _record_reuse_manifest( + checkpoint_store=checkpoint_store, + identity_material=identity_material, + reuse_decision=reuse_decision, + checkpoint_decisions=saved_decisions, + ) + if checkpoint_stats is not None: + checkpoint_stats.record( + expected_outputs=expected_count, + recomputed_outputs=expected_count, + invalid_outputs=missing_or_invalid, + ) + return substep_id + + +def _record_reuse_manifest( + *, + checkpoint_store: CheckpointStore, + identity_material: Stage1IdentityMaterial, + reuse_decision: Stage1ReuseDecision, + checkpoint_decisions: tuple, +) -> None: + if not checkpoint_decisions or any( + decision.action != "reuse" for decision in checkpoint_decisions + ): + return + checkpoint_store.record_reuse_manifest( + Stage1ReuseManifestRecord( + substep_id=identity_material.substep_id, + identity_key=identity_material.identity_key, + identity_fingerprint=identity_material.fingerprint(), + identity_material=identity_material.to_dict(), + reuse_decision=reuse_decision.to_dict(), + checkpoint_summary=_checkpoint_summary( + tuple(decision.to_dict() for decision in checkpoint_decisions) + ), + ) + ) + + +def _script_identity_material( + *, + script_path: str, + output_files: list[str], + branch: str, +) -> Stage1IdentityMaterial: + substep_id = stage_1_substep_id_for_script(script_path) + return _identity_material_for_substep( + substep_id=substep_id, + identity_key=_stage_1_identity_key( + substep_id=substep_id, + execution_id=f"script:{script_path}", + output_files=output_files, + ), + inputs={"script_path": script_path}, + output_files=output_files, + branch=branch, + artifact_specs=_artifact_spec_identity( + script_path=script_path, + output_files=output_files, + ), + ) + + def _run_checkpointed_substep( *, coordinator: Stage1Coordinator | None, @@ -457,6 +946,35 @@ def _run_checkpointed_substep( checkpoint_stats: CheckpointStats | None = None, ) -> str: command_results: list[DatasetCommandResult] = [] + output_list = output_files if isinstance(output_files, list) else [output_files] + if coordinator is None: + return run_script_with_checkpoint( + script_path, + output_files, + branch, + volume, + env=env, + log_file=log_file, + checkpoint_stats=checkpoint_stats, + command_results=command_results, + ) + + checkpoint_store = _checkpoint_store(branch, volume) + identity_material = _script_identity_material( + script_path=script_path, + output_files=output_list, + branch=branch, + ) + reuse_decision = _reuse_decision_for_material( + identity_material, + checkpoint_store=checkpoint_store, + run_id=env.get(RUN_ID_ENV), + ) + checkpoint_decisions = ( + checkpoint_store.decisions_for(output_list) + if reuse_decision.action == "reuse" + else () + ) def action() -> str: return run_script_with_checkpoint( @@ -468,10 +986,11 @@ def action() -> str: log_file=log_file, checkpoint_stats=checkpoint_stats, command_results=command_results, + checkpoint_store=checkpoint_store, + reuse_decision=reuse_decision, + identity_material=identity_material, ) - if coordinator is None: - return action() substep_id = stage_1_substep_id_for_script(script_path) return coordinator.run_substep( substep_id, @@ -480,6 +999,10 @@ def action() -> str: command_names=(script_path,), command_results=command_results, artifact_paths=_output_paths(output_files), + reuse_decision=reuse_decision.to_dict(), + checkpoint_decisions=tuple( + decision.to_dict() for decision in checkpoint_decisions + ), aggregate=True, ) @@ -584,6 +1107,7 @@ def write_dataset_build_contract( package_version: str = DATA_PACKAGE_VERSION, branch: str = "unknown", diagnostics: tuple = (), + stage_1_status_metadata: Mapping[str, Any] | None = None, ) -> StageContract: """Write the Stage 1 semantic handoff contract next to copied artifacts.""" context = DatasetBuildContext( @@ -603,6 +1127,7 @@ def write_dataset_build_contract( skip_enhanced_cps=skip_enhanced_cps, skip_stage_5=skip_stage_5, diagnostics=diagnostics, + stage_1_status_metadata=stage_1_status_metadata, ) @@ -692,14 +1217,11 @@ def build_datasets( os.chdir("/root/policyengine-us-data") # Clean stale checkpoints from other commits - branch_dir = Path(VOLUME_MOUNT) / branch - if branch_dir.exists(): - current_commit = get_current_commit() - for entry in branch_dir.iterdir(): - if entry.is_dir() and entry.name != current_commit: - shutil.rmtree(entry) - print(f"Removed stale checkpoint dir: {entry.name[:12]}") - checkpoint_volume.commit() + for removed_checkpoint in _checkpoint_store( + branch, + checkpoint_volume, + ).cleanup_other_commits(): + print(f"Removed stale checkpoint dir: {removed_checkpoint.name[:12]}") # Open persistent build log with provenance header commit = get_current_commit() @@ -743,29 +1265,48 @@ def record_skipped_script(script: str, reason: str) -> None: ) raw_data_command_results: list[DatasetCommandResult] = [] + raw_data_outputs = _raw_data_outputs() + raw_data_checkpoint_store = _checkpoint_store(branch, checkpoint_volume) + raw_data_identity_material = _raw_data_identity_material(branch=branch) + raw_data_reuse_decision = _raw_data_reuse_decision( + branch=branch, + checkpoint_store=raw_data_checkpoint_store, + run_id=run_id, + ) + raw_data_checkpoint_decisions = ( + raw_data_checkpoint_store.decisions_for(raw_data_outputs) + if raw_data_reuse_decision.action == "reuse" + else () + ) def run_raw_data_download() -> None: - run_script( - "policyengine_us_data/storage/download_prerequisites.py", - env=env, - log_file=log_file, - command_results=raw_data_command_results, - ) - env["PYTHONUNBUFFERED"] = "1" - log_file.write(f"\n{'=' * 60}\nStarting make database...\n{'=' * 60}\n") - log_file.flush() - run_script_logged( - ["make", "database"], - log_file, - env, - command_results=raw_data_command_results, - ) - # Checkpoint policy_data.db immediately after build so it survives - # test failures and can be restored on retries. - save_checkpoint( - branch, - "policyengine_us_data/storage/calibration/policy_data.db", - checkpoint_volume, + def build_raw_data() -> None: + run_script( + "policyengine_us_data/storage/download_prerequisites.py", + env=env, + log_file=log_file, + command_results=raw_data_command_results, + ) + env["PYTHONUNBUFFERED"] = "1" + log_file.write(f"\n{'=' * 60}\nStarting make database...\n{'=' * 60}\n") + log_file.flush() + run_script_logged( + ["make", "database"], + log_file, + env, + command_results=raw_data_command_results, + ) + + run_command_group_with_checkpoint( + substep_id="1a_raw_data_download", + output_files=raw_data_outputs, + branch=branch, + volume=checkpoint_volume, + action=build_raw_data, + checkpoint_stats=checkpoint_stats, + checkpoint_store=raw_data_checkpoint_store, + reuse_decision=raw_data_reuse_decision, + identity_material=raw_data_identity_material, ) try: @@ -773,12 +1314,13 @@ def run_raw_data_download() -> None: "1a_raw_data_download", stage_1_substep_title("1a_raw_data_download"), run_raw_data_download, - command_names=( - "policyengine_us_data/storage/download_prerequisites.py", - "make database", - ), + command_names=_raw_data_command_names(), command_results=raw_data_command_results, - artifact_paths=("policyengine_us_data/storage/calibration/policy_data.db",), + artifact_paths=raw_data_outputs, + reuse_decision=raw_data_reuse_decision.to_dict(), + checkpoint_decisions=tuple( + decision.to_dict() for decision in raw_data_checkpoint_decisions + ), aggregate=True, ) @@ -1038,6 +1580,7 @@ def run_stage_base_handoff() -> None: package_version=version, branch=branch, diagnostics=diagnostics, + stage_1_status_metadata=_stage_1_status_metadata(coordinator), ) pipeline_volume.commit() print("Pipeline artifacts committed to shared volume") diff --git a/policyengine_us_data/build_datasets/__init__.py b/policyengine_us_data/build_datasets/__init__.py index 48610874e..75308fa9f 100644 --- a/policyengine_us_data/build_datasets/__init__.py +++ b/policyengine_us_data/build_datasets/__init__.py @@ -1,5 +1,8 @@ """Canonical Stage 1 dataset-build specifications.""" +from importlib import import_module +from typing import Any + from .artifacts import ( DatasetArtifactSpec, STAGE_1_ARTIFACT_SPECS, @@ -52,6 +55,9 @@ __all__ = [ "ARTIFACT_SCHEMA_VERSION", + "CheckpointDecision", + "CheckpointReuseSummary", + "CheckpointStore", "CommandBackedSubstepRunner", "CommandRunner", "DatasetArtifactSpec", @@ -70,6 +76,11 @@ "SourceDatasetSchemaSummaryWriter", "Stage1Coordinator", "Stage1ErrorRecord", + "Stage1IdentityMaterial", + "Stage1RerunPlanner", + "Stage1ReuseDecision", + "Stage1ReuseManifest", + "Stage1ReuseManifestRecord", "Stage1StatusRecorder", "Stage1StatusReadError", "Stage1StatusEvent", @@ -91,3 +102,25 @@ "stage_1_step_specs", "write_stage_1_diagnostics", ] + +_LAZY_EXPORTS = { + "CheckpointDecision": (".checkpoints", "CheckpointDecision"), + "CheckpointReuseSummary": (".checkpoints", "CheckpointReuseSummary"), + "CheckpointStore": (".checkpoints", "CheckpointStore"), + "Stage1IdentityMaterial": (".rerun", "Stage1IdentityMaterial"), + "Stage1RerunPlanner": (".rerun", "Stage1RerunPlanner"), + "Stage1ReuseDecision": (".rerun", "Stage1ReuseDecision"), + "Stage1ReuseManifest": (".rerun", "Stage1ReuseManifest"), + "Stage1ReuseManifestRecord": (".rerun", "Stage1ReuseManifestRecord"), +} + + +def __getattr__(name: str) -> Any: + """Load checkpoint and rerun exports without package-import cycles.""" + + if name not in _LAZY_EXPORTS: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + module_name, attribute = _LAZY_EXPORTS[name] + value = getattr(import_module(module_name, __name__), attribute) + globals()[name] = value + return value diff --git a/policyengine_us_data/build_datasets/checkpoints.py b/policyengine_us_data/build_datasets/checkpoints.py new file mode 100644 index 000000000..d9556a67f --- /dev/null +++ b/policyengine_us_data/build_datasets/checkpoints.py @@ -0,0 +1,255 @@ +"""Checkpoint adapter for Stage 1 dataset-build execution.""" + +from __future__ import annotations + +import json +import shutil +import threading +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +from .rerun import ( + STAGE_1_REUSE_MANIFEST_FILENAME, + Stage1ReuseManifest, + Stage1ReuseManifestRecord, +) + + +CheckpointAction = Literal["reuse", "recompute", "blocked"] +_manifest_lock = threading.Lock() + + +@dataclass(frozen=True, kw_only=True) +class CheckpointDecision: + """Physical checkpoint decision for one expected output.""" + + output_file: str + checkpoint_path: Path + action: CheckpointAction + reason: str + size_bytes: int = 0 + + def to_dict(self) -> dict[str, object]: + """Return a JSON-compatible checkpoint decision.""" + + return { + "output_file": self.output_file, + "checkpoint_path": str(self.checkpoint_path), + "action": self.action, + "reason": self.reason, + "size_bytes": self.size_bytes, + } + + +@dataclass(frozen=True, kw_only=True) +class CheckpointReuseSummary: + """Checkpoint counter summary compatible with existing Stage 1 stats.""" + + expected_outputs: int + valid_reused_outputs: int + recomputed_outputs: int + invalid_outputs: int + + @classmethod + def from_decisions( + cls, + decisions: Sequence[CheckpointDecision], + *, + recomputed: bool, + ) -> "CheckpointReuseSummary": + """Build prior-compatible counters from checkpoint decisions.""" + + reusable = sum(decision.action == "reuse" for decision in decisions) + invalid = sum(decision.action != "reuse" for decision in decisions) + return cls( + expected_outputs=len(decisions), + valid_reused_outputs=reusable if not recomputed else 0, + recomputed_outputs=len(decisions) if recomputed else 0, + invalid_outputs=invalid, + ) + + def to_dict(self) -> dict[str, int]: + """Return counters in the Stage 1 contract shape.""" + + return { + "expected_outputs": self.expected_outputs, + "valid_reused_outputs": self.valid_reused_outputs, + "recomputed_outputs": self.recomputed_outputs, + "invalid_outputs": self.invalid_outputs, + } + + +@dataclass(frozen=True, kw_only=True) +class CheckpointStore: + """Adapter around the Stage 1 physical checkpoint volume layout.""" + + root: Path + branch: str + commit_sha: str + commit: Callable[[], None] | None = None + + def checkpoint_path(self, output_file: str) -> Path: + """Return the checkpoint path for an output file.""" + + return self.root / self.branch / self.commit_sha / Path(output_file).name + + def reuse_manifest_path(self) -> Path: + """Return the checkpoint-scoped Stage 1 reuse manifest path.""" + + return ( + self.root / self.branch / self.commit_sha / STAGE_1_REUSE_MANIFEST_FILENAME + ) + + def load_reuse_manifest(self) -> Stage1ReuseManifest: + """Load semantic checkpoint identity, failing closed on invalid data.""" + + path = self.reuse_manifest_path() + if not path.exists(): + return Stage1ReuseManifest.empty( + branch=self.branch, + commit_sha=self.commit_sha, + ) + try: + payload = json.loads(path.read_text()) + if not isinstance(payload, dict): + raise ValueError("Stage 1 reuse manifest must be an object") + return Stage1ReuseManifest.from_dict( + payload, + branch=self.branch, + commit_sha=self.commit_sha, + ) + except (OSError, TypeError, ValueError, json.JSONDecodeError): + return Stage1ReuseManifest.empty( + branch=self.branch, + commit_sha=self.commit_sha, + ) + + def write_reuse_manifest(self, manifest: Stage1ReuseManifest) -> Path: + """Persist the Stage 1 reuse manifest for this checkpoint scope.""" + + path = self.reuse_manifest_path() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(manifest.to_dict(), indent=2, sort_keys=True) + "\n") + if self.commit is not None: + self.commit() + return path + + def record_reuse_manifest( + self, + record: Stage1ReuseManifestRecord, + ) -> Path: + """Merge and persist one substep identity record.""" + + with _manifest_lock: + manifest = self.load_reuse_manifest().with_record(record) + return self.write_reuse_manifest(manifest) + + def decision_for(self, output_file: str) -> CheckpointDecision: + """Return the physical checkpoint decision for one output.""" + + path = self.checkpoint_path(output_file) + if not path.exists(): + return CheckpointDecision( + output_file=output_file, + checkpoint_path=path, + action="recompute", + reason="missing", + ) + size = path.stat().st_size + if size <= 0: + return CheckpointDecision( + output_file=output_file, + checkpoint_path=path, + action="recompute", + reason="empty", + size_bytes=size, + ) + return CheckpointDecision( + output_file=output_file, + checkpoint_path=path, + action="reuse", + reason="valid", + size_bytes=size, + ) + + def decisions_for( + self, output_files: Sequence[str] + ) -> tuple[CheckpointDecision, ...]: + """Return physical checkpoint decisions for expected outputs.""" + + return tuple(self.decision_for(output_file) for output_file in output_files) + + def all_outputs_reusable(self, output_files: Sequence[str]) -> bool: + """Return true only when every expected output has a valid checkpoint.""" + + return all( + decision.action == "reuse" for decision in self.decisions_for(output_files) + ) + + def restore_output(self, output_file: str) -> bool: + """Restore one checkpointed output if it is valid.""" + + decision = self.decision_for(output_file) + if decision.action != "reuse": + return False + local_path = Path(output_file) + local_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(decision.checkpoint_path, local_path) + return True + + def restore_all_outputs(self, output_files: Sequence[str]) -> bool: + """Restore outputs only when all expected checkpoints are valid.""" + + if not self.all_outputs_reusable(output_files): + return False + for output_file in output_files: + self.restore_output(output_file) + return True + + def save_output(self, output_file: str) -> bool: + """Save one local output to the checkpoint store if it is non-empty.""" + + local_path = Path(output_file) + if not local_path.exists() or local_path.stat().st_size <= 0: + return False + checkpoint_path = self.checkpoint_path(output_file) + checkpoint_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(local_path, checkpoint_path) + if self.commit is not None: + self.commit() + return True + + def cleanup_branch(self) -> bool: + """Remove all checkpoint attempts for this branch.""" + + branch_dir = self.root / self.branch + if not branch_dir.exists(): + return False + shutil.rmtree(branch_dir) + if self.commit is not None: + self.commit() + return True + + def cleanup_other_commits(self) -> tuple[Path, ...]: + """Remove stale checkpoint directories for other commits in the branch.""" + + branch_dir = self.root / self.branch + if not branch_dir.exists(): + return () + removed: list[Path] = [] + for entry in branch_dir.iterdir(): + if entry.is_dir() and entry.name != self.commit_sha: + shutil.rmtree(entry) + removed.append(entry) + if removed and self.commit is not None: + self.commit() + return tuple(removed) + + +__all__ = [ + "CheckpointDecision", + "CheckpointReuseSummary", + "CheckpointStore", +] diff --git a/policyengine_us_data/build_datasets/contracts.py b/policyengine_us_data/build_datasets/contracts.py index 13a627daf..781f3c201 100644 --- a/policyengine_us_data/build_datasets/contracts.py +++ b/policyengine_us_data/build_datasets/contracts.py @@ -26,6 +26,7 @@ def build( skip_enhanced_cps: bool, skip_stage_5: bool = False, diagnostics: Sequence[object] = (), + stage_1_status_metadata: Mapping[str, object] | None = None, ): """Build the Stage 1 handoff contract from staged artifacts.""" @@ -47,6 +48,7 @@ def build( skip_enhanced_cps=skip_enhanced_cps, skip_stage_5=skip_stage_5, diagnostics=tuple(diagnostics), + stage_1_status_metadata=stage_1_status_metadata, ) def write(self, **kwargs): diff --git a/policyengine_us_data/build_datasets/coordinator.py b/policyengine_us_data/build_datasets/coordinator.py index 416b79eae..27dc0c4f6 100644 --- a/policyengine_us_data/build_datasets/coordinator.py +++ b/policyengine_us_data/build_datasets/coordinator.py @@ -59,6 +59,8 @@ class _SubstepAggregate: command_names: list[str] = field(default_factory=list) command_results: list[DatasetCommandResult] = field(default_factory=list) artifact_paths: list[str | Path] = field(default_factory=list) + reuse_decisions: list[Mapping[str, Any]] = field(default_factory=list) + checkpoint_decisions: list[Mapping[str, Any]] = field(default_factory=list) metadata: dict[str, Any] = field(default_factory=dict) skip_reasons: list[str] = field(default_factory=list) skipped: bool = False @@ -89,6 +91,8 @@ def run_substep( command_names: Sequence[str] = (), command_results: Sequence[DatasetCommandResult] = (), artifact_paths: Sequence[str | Path] = (), + reuse_decision: Mapping[str, Any] | None = None, + checkpoint_decisions: Sequence[Mapping[str, Any]] = (), skip: bool = False, skip_reason: str | None = None, metadata: Mapping[str, Any] | None = None, @@ -107,6 +111,8 @@ def run_substep( command_names=command_names, command_results=command_results, artifact_paths=artifact_paths, + reuse_decision=reuse_decision, + checkpoint_decisions=checkpoint_decisions, skip=skip, skip_reason=skip_reason, metadata=metadata, @@ -116,6 +122,8 @@ def run_substep( runner=runner, command_names=command_names, command_results=command_results, + reuse_decision=reuse_decision, + checkpoint_decisions=checkpoint_decisions, skip_reason=skip_reason, metadata=metadata, ) @@ -154,6 +162,8 @@ def run_substep( command_names=command_names, command_results=captured_command_results, artifact_paths=artifact_paths, + reuse_decision=reuse_decision, + checkpoint_decisions=checkpoint_decisions, error=error, metadata=metadata, ) @@ -162,11 +172,17 @@ def run_substep( result = self._result( runner=runner, - status="completed", + status=_successful_status( + reuse_decision=reuse_decision, + checkpoint_decisions=checkpoint_decisions, + command_results=command_results, + ), started_dt=started_dt, command_names=command_names, command_results=tuple(command_results), artifact_paths=artifact_paths, + reuse_decision=reuse_decision, + checkpoint_decisions=checkpoint_decisions, metadata=metadata, ) self._record(result) @@ -195,6 +211,8 @@ def _run_aggregated_substep( command_names: Sequence[str], command_results: Sequence[DatasetCommandResult], artifact_paths: Sequence[str | Path], + reuse_decision: Mapping[str, Any] | None, + checkpoint_decisions: Sequence[Mapping[str, Any]], skip: bool, skip_reason: str | None, metadata: Mapping[str, Any] | None, @@ -204,6 +222,8 @@ def _run_aggregated_substep( runner=runner, command_names=command_names, command_results=command_results, + reuse_decision=reuse_decision, + checkpoint_decisions=checkpoint_decisions, skip_reason=skip_reason, metadata=metadata, ) @@ -232,6 +252,8 @@ def _run_aggregated_substep( command_names=command_names, command_results=captured_command_results, artifact_paths=artifact_paths, + reuse_decision=reuse_decision, + checkpoint_decisions=checkpoint_decisions, error=error, metadata=metadata, ) @@ -244,6 +266,8 @@ def _run_aggregated_substep( command_names=command_names, command_results=tuple(command_results), artifact_paths=artifact_paths, + reuse_decision=reuse_decision, + checkpoint_decisions=checkpoint_decisions, metadata=metadata, ) return value @@ -280,6 +304,8 @@ def _record_aggregate_skip( runner: CommandBackedSubstepRunner, command_names: Sequence[str], command_results: Sequence[DatasetCommandResult], + reuse_decision: Mapping[str, Any] | None, + checkpoint_decisions: Sequence[Mapping[str, Any]], skip_reason: str | None, metadata: Mapping[str, Any] | None, ) -> None: @@ -287,6 +313,8 @@ def _record_aggregate_skip( state = self._aggregate_state(runner) _extend_unique(state.command_names, command_names) state.command_results.extend(command_results) + _append_reuse_decision(state, reuse_decision) + state.checkpoint_decisions.extend(checkpoint_decisions) state.metadata.update(dict(metadata or {})) state.skipped = True if skip_reason is not None and skip_reason not in state.skip_reasons: @@ -300,6 +328,8 @@ def _update_aggregate_success( command_names: Sequence[str], command_results: Sequence[DatasetCommandResult], artifact_paths: Sequence[str | Path], + reuse_decision: Mapping[str, Any] | None, + checkpoint_decisions: Sequence[Mapping[str, Any]], metadata: Mapping[str, Any] | None, ) -> None: with self._lock: @@ -309,6 +339,8 @@ def _update_aggregate_success( _extend_unique(state.command_names, command_names) state.command_results.extend(command_results) state.artifact_paths.extend(artifact_paths) + _append_reuse_decision(state, reuse_decision) + state.checkpoint_decisions.extend(checkpoint_decisions) state.metadata.update(dict(metadata or {})) def _finish_aggregate_failure( @@ -319,6 +351,8 @@ def _finish_aggregate_failure( command_names: Sequence[str], command_results: Sequence[DatasetCommandResult], artifact_paths: Sequence[str | Path], + reuse_decision: Mapping[str, Any] | None, + checkpoint_decisions: Sequence[Mapping[str, Any]], error: Stage1ErrorRecord, metadata: Mapping[str, Any] | None, ) -> _SubstepAggregate: @@ -329,6 +363,8 @@ def _finish_aggregate_failure( _extend_unique(state.command_names, command_names) state.command_results.extend(command_results) state.artifact_paths.extend(artifact_paths) + _append_reuse_decision(state, reuse_decision) + state.checkpoint_decisions.extend(checkpoint_decisions) state.metadata.update(dict(metadata or {})) state.error = error return state @@ -353,6 +389,8 @@ def _result_from_aggregate( completed_dt = state.completed_dt or datetime.now(timezone.utc) if state.error is not None: status = "failed" + elif _aggregate_was_reused(state): + status = "reused" elif state.started_dt is None and state.skipped: status = "skipped" else: @@ -362,6 +400,10 @@ def _result_from_aggregate( metadata["skip_reasons"] = list(state.skip_reasons) if len(state.skip_reasons) == 1: metadata["skip_reason"] = state.skip_reasons[0] + if state.reuse_decisions: + metadata["reuse_decisions"] = list(state.reuse_decisions) + if state.checkpoint_decisions: + metadata["checkpoint_decisions"] = list(state.checkpoint_decisions) duration_s = ( (completed_dt - state.started_dt).total_seconds() if state.started_dt is not None @@ -377,6 +419,8 @@ def _result_from_aggregate( command_names=tuple(state.command_names), command_results=tuple(state.command_results), artifact_paths=_existing_artifact_paths(state.artifact_paths), + reuse_decision=_aggregate_reuse_decision(state), + checkpoint_decisions=tuple(state.checkpoint_decisions), error=state.error, metadata=metadata, ) @@ -387,6 +431,8 @@ def _skipped_result( runner: CommandBackedSubstepRunner, command_names: Sequence[str], command_results: Sequence[DatasetCommandResult], + reuse_decision: Mapping[str, Any] | None, + checkpoint_decisions: Sequence[Mapping[str, Any]], skip_reason: str | None, metadata: Mapping[str, Any] | None, ) -> DatasetSubstepResult: @@ -400,6 +446,8 @@ def _skipped_result( duration_s=None, command_names=tuple(command_names), command_results=tuple(command_results), + reuse_decision=reuse_decision, + checkpoint_decisions=tuple(checkpoint_decisions), metadata={**dict(metadata or {}), "skip_reason": skip_reason}, ) @@ -412,6 +460,8 @@ def _result( command_names: Sequence[str], command_results: Sequence[DatasetCommandResult], artifact_paths: Sequence[str | Path], + reuse_decision: Mapping[str, Any] | None, + checkpoint_decisions: Sequence[Mapping[str, Any]], error: Stage1ErrorRecord | None = None, metadata: Mapping[str, Any] | None = None, ) -> DatasetSubstepResult: @@ -426,6 +476,8 @@ def _result( command_names=tuple(command_names), command_results=tuple(command_results), artifact_paths=_existing_artifact_paths(artifact_paths), + reuse_decision=reuse_decision, + checkpoint_decisions=tuple(checkpoint_decisions), error=error, metadata=dict(metadata or {}), ) @@ -436,7 +488,7 @@ def _record(self, result: DatasetSubstepResult) -> None: status=result.status, created_at=result.completed_at, message=f"{result.title}: {result.status}", - metadata=dict(result.metadata), + metadata=_result_event_metadata(result), ) with self._lock: self.results.append(result) @@ -482,6 +534,89 @@ def _extend_unique(target: list[str], values: Sequence[str]) -> None: target.append(value) +def _append_reuse_decision( + state: _SubstepAggregate, + reuse_decision: Mapping[str, Any] | None, +) -> None: + if reuse_decision is None: + return + decision = dict(reuse_decision) + if decision not in state.reuse_decisions: + state.reuse_decisions.append(decision) + + +def _result_event_metadata(result: DatasetSubstepResult) -> dict[str, Any]: + metadata = dict(result.metadata) + if result.reuse_decision is not None: + metadata["reuse_decision"] = dict(result.reuse_decision) + if result.checkpoint_decisions: + metadata["checkpoint_decisions"] = [ + dict(decision) for decision in result.checkpoint_decisions + ] + return metadata + + +def _aggregate_reuse_decision( + state: _SubstepAggregate, +) -> Mapping[str, Any] | None: + if not state.reuse_decisions: + return None + if len(state.reuse_decisions) == 1: + return state.reuse_decisions[0] + return { + "substep_id": state.substep_id, + "decisions": list(state.reuse_decisions), + } + + +def _successful_status( + *, + reuse_decision: Mapping[str, Any] | None, + checkpoint_decisions: Sequence[Mapping[str, Any]], + command_results: Sequence[DatasetCommandResult], +) -> str: + if command_results: + return "completed" + if _reuse_action(reuse_decision) != "reuse": + return "completed" + if not checkpoint_decisions: + return "completed" + if all( + dict(decision).get("action") == "reuse" for decision in checkpoint_decisions + ): + return "reused" + return "completed" + + +def _aggregate_was_reused(state: _SubstepAggregate) -> bool: + return ( + _successful_status( + reuse_decision=_aggregate_reuse_decision(state), + checkpoint_decisions=state.checkpoint_decisions, + command_results=state.command_results, + ) + == "reused" + ) + + +def _reuse_action(reuse_decision: Mapping[str, Any] | None) -> object: + if reuse_decision is None: + return None + decision = dict(reuse_decision) + if "action" in decision: + return decision["action"] + nested = decision.get("decisions") + if isinstance(nested, list) and nested: + actions = { + item.get("action") + for item in nested + if isinstance(item, Mapping) and "action" in item + } + if actions == {"reuse"}: + return "reuse" + return None + + def _command_results_with_exception( command_results: Sequence[DatasetCommandResult], exc: Exception, diff --git a/policyengine_us_data/build_datasets/rerun.py b/policyengine_us_data/build_datasets/rerun.py new file mode 100644 index 000000000..97eb8804e --- /dev/null +++ b/policyengine_us_data/build_datasets/rerun.py @@ -0,0 +1,292 @@ +"""Rerun and semantic reuse planning for Stage 1 dataset builds.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from typing import Any, Literal + +from policyengine_us_data.stage_contracts.fingerprints import fingerprint_material + + +Stage1ReuseAction = Literal["reuse", "recompute", "blocked"] +STAGE_1_REUSE_MANIFEST_FILENAME = "stage_1_reuse_manifest.json" +STAGE_1_REUSE_MANIFEST_SCHEMA_VERSION = "stage-1-reuse-manifest-v1" + + +def _json_safe(value: Any) -> Any: + if isinstance(value, Mapping): + return {str(key): _json_safe(value[key]) for key in sorted(value, key=str)} + if isinstance(value, tuple | list): + return [_json_safe(item) for item in value] + if isinstance(value, str | int | float | bool) or value is None: + return value + return str(value) + + +@dataclass(frozen=True, kw_only=True) +class Stage1IdentityMaterial: + """Semantic identity material for one Stage 1 execution unit.""" + + substep_id: str + identity_key: str + code_sha: str + schema_version: str + inputs: Mapping[str, Any] = field(default_factory=dict) + parameters: Mapping[str, Any] = field(default_factory=dict) + artifact_specs: Sequence[Mapping[str, Any]] = () + upstream_contract_fingerprints: Sequence[str] = () + randomness: Mapping[str, Any] = field(default_factory=dict) + blocked_reason: str | None = None + + def fingerprint(self) -> str: + """Return the deterministic semantic identity fingerprint.""" + + return fingerprint_material(self.to_dict()).value + + def to_dict(self) -> dict[str, Any]: + """Return JSON-safe semantic identity material.""" + + return { + "substep_id": self.substep_id, + "identity_key": self.identity_key, + "inputs": _json_safe(self.inputs), + "parameters": _json_safe(self.parameters), + "artifact_specs": _json_safe(self.artifact_specs), + "code_sha": self.code_sha, + "schema_version": self.schema_version, + "upstream_contract_fingerprints": _json_safe( + self.upstream_contract_fingerprints + ), + "randomness": _json_safe(self.randomness), + } + + +@dataclass(frozen=True, kw_only=True) +class Stage1ReuseDecision: + """Semantic rerun/reuse decision for a Stage 1 substep.""" + + run_id: str + rerun_id: str | None + artifact_namespace: str + substep_id: str + identity_key: str + action: Stage1ReuseAction + reason: str + identity_fingerprint: str + + def to_dict(self) -> dict[str, str | None]: + """Return a JSON-compatible reuse decision.""" + + return { + "run_id": self.run_id, + "rerun_id": self.rerun_id, + "artifact_namespace": self.artifact_namespace, + "substep_id": self.substep_id, + "identity_key": self.identity_key, + "action": self.action, + "reason": self.reason, + "identity_fingerprint": self.identity_fingerprint, + } + + +@dataclass(frozen=True, kw_only=True) +class Stage1ReuseManifestRecord: + """Persisted semantic identity for one checkpointed Stage 1 execution unit.""" + + substep_id: str + identity_key: str + identity_fingerprint: str + identity_material: Mapping[str, Any] + reuse_decision: Mapping[str, Any] | None = None + checkpoint_summary: Mapping[str, int] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-compatible manifest record.""" + + return { + "substep_id": self.substep_id, + "identity_key": self.identity_key, + "identity_fingerprint": self.identity_fingerprint, + "identity_material": _json_safe(self.identity_material), + "reuse_decision": _json_safe(self.reuse_decision or {}), + "checkpoint_summary": _json_safe(self.checkpoint_summary), + } + + @classmethod + def from_dict(cls, payload: Mapping[str, Any]) -> "Stage1ReuseManifestRecord": + """Load one manifest record or raise ValueError for invalid payloads.""" + + substep_id = payload.get("substep_id") + identity_key = payload.get("identity_key") + identity_fingerprint = payload.get("identity_fingerprint") + identity_material = payload.get("identity_material") + if not isinstance(substep_id, str) or not substep_id: + raise ValueError("reuse manifest record missing substep_id") + if not isinstance(identity_key, str) or not identity_key: + raise ValueError("reuse manifest record missing identity_key") + if not isinstance(identity_fingerprint, str) or not identity_fingerprint: + raise ValueError("reuse manifest record missing identity_fingerprint") + if not isinstance(identity_material, Mapping): + raise ValueError("reuse manifest record missing identity_material") + reuse_decision = payload.get("reuse_decision") + checkpoint_summary = payload.get("checkpoint_summary") + return cls( + substep_id=substep_id, + identity_key=identity_key, + identity_fingerprint=identity_fingerprint, + identity_material=dict(identity_material), + reuse_decision=dict(reuse_decision) + if isinstance(reuse_decision, Mapping) + else None, + checkpoint_summary={ + str(key): int(value) + for key, value in dict(checkpoint_summary or {}).items() + } + if isinstance(checkpoint_summary, Mapping) + else {}, + ) + + +@dataclass(frozen=True, kw_only=True) +class Stage1ReuseManifest: + """Checkpoint-scoped semantic identity manifest for Stage 1 reruns.""" + + branch: str + commit_sha: str + records: Mapping[str, Stage1ReuseManifestRecord] = field(default_factory=dict) + schema_version: str = STAGE_1_REUSE_MANIFEST_SCHEMA_VERSION + + @classmethod + def empty(cls, *, branch: str, commit_sha: str) -> "Stage1ReuseManifest": + """Return an empty manifest for a checkpoint scope.""" + + return cls(branch=branch, commit_sha=commit_sha) + + @classmethod + def from_dict( + cls, + payload: Mapping[str, Any], + *, + branch: str, + commit_sha: str, + ) -> "Stage1ReuseManifest": + """Load a manifest or raise ValueError for invalid payloads.""" + + if payload.get("schema_version") != STAGE_1_REUSE_MANIFEST_SCHEMA_VERSION: + raise ValueError("unsupported Stage 1 reuse manifest schema") + records_payload = payload.get("records", ()) + if not isinstance(records_payload, list): + raise ValueError("Stage 1 reuse manifest records must be a list") + records: dict[str, Stage1ReuseManifestRecord] = {} + for record_payload in records_payload: + if not isinstance(record_payload, Mapping): + raise ValueError("Stage 1 reuse manifest contains invalid records") + record = Stage1ReuseManifestRecord.from_dict(record_payload) + if record.identity_key in records: + raise ValueError("Stage 1 reuse manifest contains duplicate keys") + records[record.identity_key] = record + return cls( + branch=str(payload.get("branch") or branch), + commit_sha=str(payload.get("commit_sha") or commit_sha), + records=records, + ) + + def previous_identities(self) -> dict[str, str]: + """Return prior semantic fingerprints keyed by execution identity key.""" + + return { + identity_key: record.identity_fingerprint + for identity_key, record in self.records.items() + } + + def with_record( + self, + record: Stage1ReuseManifestRecord, + ) -> "Stage1ReuseManifest": + """Return a manifest with one record added or replaced.""" + + records = dict(self.records) + records[record.identity_key] = record + return Stage1ReuseManifest( + branch=self.branch, + commit_sha=self.commit_sha, + records=records, + schema_version=self.schema_version, + ) + + def to_dict(self) -> dict[str, Any]: + """Return a deterministic JSON-compatible manifest payload.""" + + return { + "schema_version": self.schema_version, + "branch": self.branch, + "commit_sha": self.commit_sha, + "records": [ + self.records[identity_key].to_dict() + for identity_key in sorted(self.records) + ], + } + + +@dataclass(frozen=True, kw_only=True) +class Stage1RerunPlanner: + """Decide whether Stage 1 substeps may reuse semantic work.""" + + previous_identities: Mapping[str, str] = field(default_factory=dict) + + def decide( + self, + material: Stage1IdentityMaterial, + *, + run_id: str, + rerun_id: str | None = None, + ) -> Stage1ReuseDecision: + """Return a semantic reuse, recompute, or blocked decision.""" + + fingerprint = material.fingerprint() + if material.blocked_reason: + return Stage1ReuseDecision( + run_id=run_id, + rerun_id=rerun_id, + artifact_namespace=run_id, + substep_id=material.substep_id, + identity_key=material.identity_key, + action="blocked", + reason=material.blocked_reason, + identity_fingerprint=fingerprint, + ) + + previous = self.previous_identities.get(material.identity_key) + if previous == fingerprint: + action = "reuse" + reason = "identity_match" + elif previous is None: + action = "recompute" + reason = "no_previous_identity" + else: + action = "recompute" + reason = "identity_mismatch" + + return Stage1ReuseDecision( + run_id=run_id, + rerun_id=rerun_id, + artifact_namespace=run_id, + substep_id=material.substep_id, + identity_key=material.identity_key, + action=action, + reason=reason, + identity_fingerprint=fingerprint, + ) + + +__all__ = [ + "STAGE_1_REUSE_MANIFEST_FILENAME", + "STAGE_1_REUSE_MANIFEST_SCHEMA_VERSION", + "Stage1IdentityMaterial", + "Stage1RerunPlanner", + "Stage1ReuseAction", + "Stage1ReuseDecision", + "Stage1ReuseManifest", + "Stage1ReuseManifestRecord", +] diff --git a/policyengine_us_data/build_datasets/results.py b/policyengine_us_data/build_datasets/results.py index 5afc53079..4b2bf096c 100644 --- a/policyengine_us_data/build_datasets/results.py +++ b/policyengine_us_data/build_datasets/results.py @@ -75,6 +75,8 @@ class DatasetSubstepResult: command_names: tuple[str, ...] = () command_results: tuple[DatasetCommandResult, ...] = () artifact_paths: tuple[str, ...] = () + reuse_decision: Mapping[str, Any] | None = None + checkpoint_decisions: tuple[Mapping[str, Any], ...] = () error: Stage1ErrorRecord | None = None metadata: Mapping[str, Any] = field(default_factory=dict) @@ -100,6 +102,8 @@ def from_dict(cls, data: Mapping[str, Any]) -> "DatasetSubstepResult": for result in command_results ), artifact_paths=_string_tuple(data.get("artifact_paths", ())), + reuse_decision=_optional_mapping(data.get("reuse_decision")), + checkpoint_decisions=_mapping_tuple(data.get("checkpoint_decisions", ())), error=_error_record_from_payload(data.get("error")), metadata=_metadata_mapping(data.get("metadata", {})), ) @@ -117,6 +121,12 @@ def to_dict(self) -> dict[str, Any]: "command_names": list(self.command_names), "command_results": [result.to_dict() for result in self.command_results], "artifact_paths": list(self.artifact_paths), + "reuse_decision": ( + dict(self.reuse_decision) if self.reuse_decision is not None else None + ), + "checkpoint_decisions": [ + dict(decision) for decision in self.checkpoint_decisions + ], "error": self.error.to_dict() if self.error else None, "metadata": dict(self.metadata), } @@ -129,7 +139,7 @@ def _command_execution_status(value: Any) -> CommandExecutionStatus: def _stage_1_substep_status(value: Any) -> Stage1SubstepStatus: - if value in ("started", "completed", "skipped", "failed"): + if value in ("started", "completed", "reused", "skipped", "failed"): return cast(Stage1SubstepStatus, value) raise ValueError(f"Invalid Stage 1 substep status: {value!r}") @@ -146,6 +156,20 @@ def _mapping_payload(value: Any) -> Mapping[str, Any]: return value +def _optional_mapping(value: Any) -> Mapping[str, Any] | None: + if value is None: + return None + return dict(_mapping_payload(value)) + + +def _mapping_tuple(value: Any) -> tuple[Mapping[str, Any], ...]: + if value is None: + return () + if not isinstance(value, Sequence) or isinstance(value, str): + raise TypeError("Expected a sequence") + return tuple(dict(_mapping_payload(item)) for item in value) + + def _metadata_mapping(value: Any) -> Mapping[str, Any]: if not isinstance(value, Mapping): raise TypeError("metadata must be a mapping") diff --git a/policyengine_us_data/build_datasets/status.py b/policyengine_us_data/build_datasets/status.py index 3277561ea..09dcc29a4 100644 --- a/policyengine_us_data/build_datasets/status.py +++ b/policyengine_us_data/build_datasets/status.py @@ -11,7 +11,7 @@ from modal_app.step_manifests.errors import PipelineErrorRecord -Stage1SubstepStatus = Literal["started", "completed", "skipped", "failed"] +Stage1SubstepStatus = Literal["started", "completed", "reused", "skipped", "failed"] def utc_timestamp(value: datetime | None = None) -> str: @@ -187,7 +187,7 @@ def _pipeline_traceback_text(error: Stage1ErrorRecord) -> str: def _stage_1_substep_status(value: Any) -> Stage1SubstepStatus: - if value in ("started", "completed", "skipped", "failed"): + if value in ("started", "completed", "reused", "skipped", "failed"): return cast(Stage1SubstepStatus, value) raise ValueError(f"Invalid Stage 1 substep status: {value!r}") diff --git a/policyengine_us_data/stage_contracts/dataset_build.py b/policyengine_us_data/stage_contracts/dataset_build.py index 53843f953..8ed3095ba 100644 --- a/policyengine_us_data/stage_contracts/dataset_build.py +++ b/policyengine_us_data/stage_contracts/dataset_build.py @@ -39,6 +39,7 @@ def build_dataset_build_output_contract( skip_enhanced_cps: bool = False, skip_stage_5: bool = False, diagnostics: tuple[DiagnosticRef, ...] = (), + stage_1_status_metadata: Mapping[str, Any] | None = None, ) -> StageContract: """Build the Stage 1 handoff contract from copied pipeline artifacts.""" @@ -89,6 +90,7 @@ def build_dataset_build_output_contract( "artifact_directory": str(artifacts_dir), "contract_file": DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, "diagnostic_count": len(diagnostics), + "stage_1_status": stage_1_status_metadata or {}, }, ) diff --git a/tests/unit/test_build_dataset_checkpoints.py b/tests/unit/test_build_dataset_checkpoints.py new file mode 100644 index 000000000..fb0026735 --- /dev/null +++ b/tests/unit/test_build_dataset_checkpoints.py @@ -0,0 +1,182 @@ +import json +from pathlib import Path + +from policyengine_us_data.build_datasets import ( + CheckpointReuseSummary, + CheckpointStore, + Stage1ReuseManifestRecord, +) + + +def _store(tmp_path: Path) -> CheckpointStore: + return CheckpointStore(root=tmp_path, branch="stage-1", commit_sha="abc123") + + +def test_checkpoint_store_scopes_paths_by_branch_and_commit(tmp_path): + store = _store(tmp_path) + + assert store.checkpoint_path("policyengine_us_data/storage/cps_2024.h5") == ( + tmp_path / "stage-1" / "abc123" / "cps_2024.h5" + ) + assert store.reuse_manifest_path() == ( + tmp_path / "stage-1" / "abc123" / "stage_1_reuse_manifest.json" + ) + + +def test_missing_reuse_manifest_loads_empty(tmp_path): + manifest = _store(tmp_path).load_reuse_manifest() + + assert manifest.previous_identities() == {} + + +def test_reuse_manifest_round_trips_deterministically(tmp_path): + store = _store(tmp_path) + record = Stage1ReuseManifestRecord( + substep_id="1b_base_dataset_construction", + identity_key=( + "1b_base_dataset_construction:" + "script:policyengine_us_data/datasets/cps/cps.py:cps_2024.h5" + ), + identity_fingerprint="sha256:abc", + identity_material={"substep_id": "1b_base_dataset_construction"}, + reuse_decision={"action": "recompute", "reason": "no_previous_identity"}, + checkpoint_summary={"total": 1, "reusable": 1}, + ) + + path = store.record_reuse_manifest(record) + loaded = store.load_reuse_manifest() + + assert path == store.reuse_manifest_path() + assert loaded.previous_identities() == { + ( + "1b_base_dataset_construction:" + "script:policyengine_us_data/datasets/cps/cps.py:cps_2024.h5" + ): "sha256:abc" + } + assert json.loads(path.read_text()) == loaded.to_dict() + + +def test_malformed_reuse_manifest_fails_closed(tmp_path): + store = _store(tmp_path) + path = store.reuse_manifest_path() + path.parent.mkdir(parents=True) + path.write_text("{not-json") + + manifest = store.load_reuse_manifest() + + assert manifest.previous_identities() == {} + + +def test_reuse_manifest_without_identity_key_fails_closed(tmp_path): + store = _store(tmp_path) + path = store.reuse_manifest_path() + path.parent.mkdir(parents=True) + path.write_text( + json.dumps( + { + "schema_version": "stage-1-reuse-manifest-v1", + "branch": "stage-1", + "commit_sha": "abc123", + "records": [ + { + "substep_id": "1b_base_dataset_construction", + "identity_fingerprint": "sha256:abc", + "identity_material": { + "substep_id": "1b_base_dataset_construction", + }, + } + ], + } + ) + ) + + manifest = store.load_reuse_manifest() + + assert manifest.previous_identities() == {} + + +def test_checkpoint_store_requires_all_outputs_for_restore(tmp_path, monkeypatch): + store = _store(tmp_path) + checkpoint = store.checkpoint_path("a.txt") + checkpoint.parent.mkdir(parents=True) + checkpoint.write_text("cached") + monkeypatch.chdir(tmp_path) + + assert store.restore_all_outputs(("a.txt", "b.txt")) is False + assert not (tmp_path / "a.txt").exists() + + store.checkpoint_path("b.txt").write_text("cached-b") + assert store.restore_all_outputs(("a.txt", "b.txt")) is True + assert (tmp_path / "a.txt").read_text() == "cached" + assert (tmp_path / "b.txt").read_text() == "cached-b" + + +def test_missing_or_empty_checkpoint_invalidates_reuse(tmp_path): + store = _store(tmp_path) + empty = store.checkpoint_path("empty.txt") + empty.parent.mkdir(parents=True) + empty.write_text("") + + missing_decision = store.decision_for("missing.txt") + empty_decision = store.decision_for("empty.txt") + + assert missing_decision.action == "recompute" + assert missing_decision.reason == "missing" + assert empty_decision.action == "recompute" + assert empty_decision.reason == "empty" + assert store.all_outputs_reusable(("missing.txt", "empty.txt")) is False + + +def test_checkpoint_summary_matches_prior_counters(tmp_path): + store = _store(tmp_path) + checkpoint = store.checkpoint_path("valid.txt") + checkpoint.parent.mkdir(parents=True) + checkpoint.write_text("cached") + decisions = store.decisions_for(("valid.txt", "missing.txt")) + + reused_summary = CheckpointReuseSummary.from_decisions( + decisions, + recomputed=False, + ) + recomputed_summary = CheckpointReuseSummary.from_decisions( + decisions, + recomputed=True, + ) + + assert reused_summary.to_dict() == { + "expected_outputs": 2, + "valid_reused_outputs": 1, + "recomputed_outputs": 0, + "invalid_outputs": 1, + } + assert recomputed_summary.to_dict() == { + "expected_outputs": 2, + "valid_reused_outputs": 0, + "recomputed_outputs": 2, + "invalid_outputs": 1, + } + + +def test_checkpoint_cleanup_removes_only_branch_scope(tmp_path): + store = _store(tmp_path) + (tmp_path / "stage-1" / "abc123").mkdir(parents=True) + (tmp_path / "other-branch" / "abc123").mkdir(parents=True) + + assert store.cleanup_branch() is True + + assert not (tmp_path / "stage-1").exists() + assert (tmp_path / "other-branch" / "abc123").exists() + + +def test_checkpoint_cleanup_removes_only_other_commits(tmp_path): + store = _store(tmp_path) + current = tmp_path / "stage-1" / "abc123" + stale = tmp_path / "stage-1" / "old" + current.mkdir(parents=True) + stale.mkdir(parents=True) + + removed = store.cleanup_other_commits() + + assert removed == (stale,) + assert current.exists() + assert not stale.exists() diff --git a/tests/unit/test_build_dataset_rerun.py b/tests/unit/test_build_dataset_rerun.py new file mode 100644 index 000000000..8c76a7f56 --- /dev/null +++ b/tests/unit/test_build_dataset_rerun.py @@ -0,0 +1,214 @@ +from policyengine_us_data.build_datasets import ( + Stage1Coordinator, + Stage1IdentityMaterial, + Stage1RerunPlanner, + Stage1ReuseManifest, + Stage1ReuseManifestRecord, +) + + +def test_stage_contracts_and_lazy_reuse_exports_import_together(): + import policyengine_us_data.stage_contracts as stage_contracts + from policyengine_us_data.build_datasets import CheckpointStore + + assert stage_contracts.fingerprint_material + assert CheckpointStore + + +def _material(**overrides) -> Stage1IdentityMaterial: + values = { + "substep_id": "1b_base_dataset_construction", + "identity_key": ( + "1b_base_dataset_construction:" + "script:policyengine_us_data/datasets/cps/cps.py:cps_2024.h5" + ), + "code_sha": "abc123", + "schema_version": "stage-1-rerun-v1", + "inputs": {"dataset": "cps"}, + "parameters": {"period": 2024}, + "artifact_specs": ({"filename": "cps_2024.h5"},), + "upstream_contract_fingerprints": ("sha256:upstream",), + "randomness": {"seed": 1}, + } + values.update(overrides) + return Stage1IdentityMaterial(**values) + + +def test_rerun_planner_reuses_matching_identity(): + material = _material() + manifest = Stage1ReuseManifest.empty( + branch="stage-1", + commit_sha="abc123", + ).with_record( + Stage1ReuseManifestRecord( + substep_id=material.substep_id, + identity_key=material.identity_key, + identity_fingerprint=material.fingerprint(), + identity_material=material.to_dict(), + ) + ) + planner = Stage1RerunPlanner(previous_identities=manifest.previous_identities()) + + decision = planner.decide(material, run_id="run-a", rerun_id="attempt-2") + + assert decision.action == "reuse" + assert decision.reason == "identity_match" + assert decision.rerun_id == "attempt-2" + + +def test_rerun_planner_recomputes_without_manifest_identity(): + material = _material() + planner = Stage1RerunPlanner( + previous_identities=Stage1ReuseManifest.empty( + branch="stage-1", + commit_sha="abc123", + ).previous_identities() + ) + + decision = planner.decide(material, run_id="run-a") + + assert decision.action == "recompute" + assert decision.reason == "no_previous_identity" + + +def test_rerun_planner_recomputes_mismatched_parameters(): + previous = _material() + current = _material(parameters={"period": 2025}) + planner = Stage1RerunPlanner( + previous_identities={previous.identity_key: previous.fingerprint()} + ) + + decision = planner.decide(current, run_id="run-a") + + assert decision.action == "recompute" + assert decision.reason == "identity_mismatch" + + +def test_rerun_planner_recomputes_mismatched_schema_or_upstream_fingerprint(): + previous = _material() + planner = Stage1RerunPlanner( + previous_identities={previous.identity_key: previous.fingerprint()} + ) + + schema_decision = planner.decide( + _material(schema_version="stage-1-rerun-v2"), + run_id="run-a", + ) + upstream_decision = planner.decide( + _material(upstream_contract_fingerprints=("sha256:changed",)), + run_id="run-a", + ) + + assert schema_decision.action == "recompute" + assert upstream_decision.action == "recompute" + + +def test_rerun_id_does_not_change_artifact_namespace(): + material = _material() + + decision = Stage1RerunPlanner().decide( + material, + run_id="canonical-run", + rerun_id="attempt-2", + ) + + assert decision.artifact_namespace == "canonical-run" + assert decision.rerun_id == "attempt-2" + + +def test_reuse_manifest_keeps_same_substep_identity_keys_distinct(): + cps = _material( + identity_key=( + "1b_base_dataset_construction:" + "script:policyengine_us_data/datasets/cps/cps.py:cps_2024.h5" + ), + inputs={"dataset": "cps"}, + artifact_specs=({"filename": "cps_2024.h5"},), + ) + puf = _material( + identity_key=( + "1b_base_dataset_construction:" + "script:policyengine_us_data/datasets/puf/puf.py:puf_2024.h5" + ), + inputs={"dataset": "puf"}, + artifact_specs=({"filename": "puf_2024.h5"},), + ) + manifest = ( + Stage1ReuseManifest.empty(branch="stage-1", commit_sha="abc123") + .with_record( + Stage1ReuseManifestRecord( + substep_id=cps.substep_id, + identity_key=cps.identity_key, + identity_fingerprint=cps.fingerprint(), + identity_material=cps.to_dict(), + ) + ) + .with_record( + Stage1ReuseManifestRecord( + substep_id=puf.substep_id, + identity_key=puf.identity_key, + identity_fingerprint=puf.fingerprint(), + identity_material=puf.to_dict(), + ) + ) + ) + + assert manifest.previous_identities() == { + cps.identity_key: cps.fingerprint(), + puf.identity_key: puf.fingerprint(), + } + planner = Stage1RerunPlanner(previous_identities=manifest.previous_identities()) + + assert planner.decide(cps, run_id="run-a").action == "reuse" + assert planner.decide(puf, run_id="run-a").action == "reuse" + + +def test_blocked_decision_serializes_into_substep_status(): + material = _material(blocked_reason="missing upstream contract") + decision = Stage1RerunPlanner().decide(material, run_id="run-a") + coordinator = Stage1Coordinator() + + coordinator.run_substep( + material.substep_id, + "Base dataset construction", + lambda: None, + reuse_decision=decision.to_dict(), + checkpoint_decisions=(), + skip=True, + skip_reason="blocked", + ) + + [result] = coordinator.results + assert decision.action == "blocked" + assert result.reuse_decision == decision.to_dict() + assert coordinator.status_events[-1].metadata["reuse_decision"] == ( + decision.to_dict() + ) + + +def test_reused_decision_becomes_first_class_substep_status(): + material = _material() + decision = Stage1RerunPlanner( + previous_identities={material.identity_key: material.fingerprint()} + ).decide(material, run_id="run-a") + coordinator = Stage1Coordinator() + + coordinator.run_substep( + material.substep_id, + "Base dataset construction", + lambda: None, + reuse_decision=decision.to_dict(), + checkpoint_decisions=( + { + "output_file": "cps_2024.h5", + "checkpoint_path": "/checkpoints/branch/sha/cps_2024.h5", + "action": "reuse", + "reason": "valid", + "size_bytes": 3, + }, + ), + ) + + [result] = coordinator.results + assert result.status == "reused" + assert coordinator.status_events[-1].status == "reused" diff --git a/tests/unit/test_build_dataset_status_store.py b/tests/unit/test_build_dataset_status_store.py index 47c302515..2c6fdd8e1 100644 --- a/tests/unit/test_build_dataset_status_store.py +++ b/tests/unit/test_build_dataset_status_store.py @@ -24,11 +24,24 @@ def test_stage_1_status_recorder_persists_events_results_and_current(tmp_path): result = DatasetSubstepResult( substep_id="1c_extended_cps_puf_clone", title="Extended CPS PUF clone", - status="completed", + status="reused", started_at="2026-05-22T12:00:00Z", completed_at="2026-05-22T12:05:00Z", duration_s=300.0, command_names=("extended-cps",), + reuse_decision={ + "identity_key": "1c_extended_cps_puf_clone:extended-cps", + "action": "reuse", + "reason": "identity_match", + }, + checkpoint_decisions=( + { + "output_file": "extended_cps_2024.h5", + "action": "reuse", + "reason": "valid", + "size_bytes": 10, + }, + ), ) recorder.record_event(event) @@ -49,7 +62,20 @@ def test_stage_1_status_recorder_persists_events_results_and_current(tmp_path): } assert snapshot.events == (snapshot.current,) assert snapshot.results[0].substep_id == "1c_extended_cps_puf_clone" - assert snapshot.results[0].status == "completed" + assert snapshot.results[0].status == "reused" + assert snapshot.results[0].reuse_decision == { + "identity_key": "1c_extended_cps_puf_clone:extended-cps", + "action": "reuse", + "reason": "identity_match", + } + assert snapshot.results[0].checkpoint_decisions == ( + { + "output_file": "extended_cps_2024.h5", + "action": "reuse", + "reason": "valid", + "size_bytes": 10, + }, + ) def test_stage_1_status_recorder_is_best_effort_by_default(tmp_path): @@ -125,6 +151,47 @@ def test_stage_1_coordinator_writes_to_status_recorder(tmp_path): assert snapshot.results[0].substep_id == "1b_base_dataset_construction" +def test_stage_1_coordinator_writes_reuse_metadata_to_status_recorder(tmp_path): + recorder = Stage1StatusRecorder(tmp_path / "runs" / "run-1") + coordinator = Stage1Coordinator(status_recorder=recorder) + + coordinator.run_substep( + "1a_raw_data_download", + "Raw data download", + lambda: None, + reuse_decision={ + "identity_key": "1a_raw_data_download:raw-data", + "action": "reuse", + "reason": "identity_match", + }, + checkpoint_decisions=( + { + "output_file": "policy_data.db", + "action": "reuse", + "reason": "valid", + "size_bytes": 3, + }, + ), + ) + + snapshot = read_stage_1_status_snapshot(tmp_path / "runs" / "run-1") + + assert snapshot.current is not None + assert snapshot.current.status == "reused" + assert snapshot.current.metadata["reuse_decision"]["identity_key"] == ( + "1a_raw_data_download:raw-data" + ) + assert snapshot.current.metadata["checkpoint_decisions"] == [ + { + "output_file": "policy_data.db", + "action": "reuse", + "reason": "valid", + "size_bytes": 3, + } + ] + assert snapshot.results[0].status == "reused" + + def test_stage_1_coordinator_writes_aggregated_status_on_finalize(tmp_path): recorder = Stage1StatusRecorder(tmp_path / "runs" / "run-1") coordinator = Stage1Coordinator(status_recorder=recorder) diff --git a/tests/unit/test_dataset_build_stage_contract.py b/tests/unit/test_dataset_build_stage_contract.py index 3e69cb691..cc5ad4db2 100644 --- a/tests/unit/test_dataset_build_stage_contract.py +++ b/tests/unit/test_dataset_build_stage_contract.py @@ -229,3 +229,29 @@ def test_dataset_build_contract_records_diagnostic_refs(tmp_path): assert contract.diagnostics == (diagnostic,) assert contract.metadata["diagnostic_count"] == 1 + + +def test_dataset_build_contract_records_stage_1_status_metadata(tmp_path): + _write_artifacts(tmp_path) + + contract = build_dataset_build_output_contract( + artifacts_dir=tmp_path, + run_id="run-a", + code_sha="abc123", + package_version="1.98.2", + checkpoint_stats={"expected_outputs": 4}, + started_at="2026-05-08T12:00:00Z", + completed_at="2026-05-08T12:01:00Z", + stage_1_status_metadata={ + "substep_results": [ + { + "substep_id": "1b_base_dataset_construction", + "reuse_decision": {"action": "reuse"}, + } + ] + }, + ) + + assert contract.metadata["stage_1_status"]["substep_results"][0][ + "reuse_decision" + ] == {"action": "reuse"} diff --git a/tests/unit/test_modal_data_build.py b/tests/unit/test_modal_data_build.py index 1ebfe3ebd..7144b80a3 100644 --- a/tests/unit/test_modal_data_build.py +++ b/tests/unit/test_modal_data_build.py @@ -1,9 +1,15 @@ import importlib import sys from datetime import datetime, timedelta, timezone +from pathlib import Path from types import ModuleType, SimpleNamespace -from policyengine_us_data.build_datasets import stage_1_script_outputs +from policyengine_us_data.build_datasets import ( + CheckpointStore, + Stage1ReuseDecision, + Stage1ReuseManifestRecord, + stage_1_script_outputs, +) from policyengine_us_data.stage_contracts import read_contract @@ -71,6 +77,29 @@ def test_script_outputs_are_generated_from_stage_1_artifact_specs(): assert data_build.SCRIPT_OUTPUTS == stage_1_script_outputs() +def test_script_identity_material_distinguishes_same_substep_units(): + data_build = _load_data_build_module() + + cps = data_build._script_identity_material( + script_path=data_build.CPS_BUILD_SCRIPT, + output_files=["cps_2024.h5"], + branch="branch", + ) + puf = data_build._script_identity_material( + script_path=data_build.PUF_BUILD_SCRIPT, + output_files=["puf_2024.h5"], + branch="branch", + ) + + assert cps.substep_id == "1b_base_dataset_construction" + assert puf.substep_id == "1b_base_dataset_construction" + assert cps.identity_key != puf.identity_key + assert data_build.CPS_BUILD_SCRIPT in cps.identity_key + assert "cps_2024.h5" in cps.identity_key + assert data_build.PUF_BUILD_SCRIPT in puf.identity_key + assert "puf_2024.h5" in puf.identity_key + + def test_build_datasets_records_stage_base_handoff_substep(tmp_path, monkeypatch): data_build = _load_data_build_module() calls = [] @@ -165,6 +194,10 @@ def fake_run_script_logged(cmd, log_file, env, **kwargs): "1g_stage_base_datasets", ] assert created_coordinators[0].status_recorder is not None + raw_data_kwargs = calls[1][1] + assert raw_data_kwargs["reuse_decision"]["substep_id"] == "1a_raw_data_download" + assert raw_data_kwargs["reuse_decision"]["reason"] == "no_previous_identity" + assert raw_data_kwargs["checkpoint_decisions"] == () stage_base_kwargs = calls[-1][1] assert stage_base_kwargs["command_names"] == ("stage_base_datasets",) assert any( @@ -433,6 +466,342 @@ def test_write_dataset_build_contract_writes_stage_1_handoff(tmp_path): assert contract.parameters["stage_only"] is True +def test_run_script_with_checkpoint_rejects_blocked_reuse_decision( + tmp_path, + monkeypatch, +): + data_build = _load_data_build_module() + + def fail_run_script(*args, **kwargs): + raise AssertionError("blocked decisions must not run commands") + + monkeypatch.setattr(data_build, "run_script", fail_run_script) + decision = Stage1ReuseDecision( + run_id="run-a", + rerun_id="attempt-2", + artifact_namespace="run-a", + substep_id="1b_base_dataset_construction", + identity_key=( + "1b_base_dataset_construction:" + "script:policyengine_us_data/datasets/cps/cps.py:cps_2024.h5" + ), + action="blocked", + reason="identity_mismatch", + identity_fingerprint="sha256:abc", + ) + + try: + data_build.run_script_with_checkpoint( + "policyengine_us_data/datasets/cps/cps.py", + "cps_2024.h5", + "branch", + object(), + checkpoint_store=CheckpointStore( + root=tmp_path, + branch="branch", + commit_sha="abc123", + ), + reuse_decision=decision, + ) + except RuntimeError as exc: + assert "identity_mismatch" in str(exc) + else: + raise AssertionError("blocked reuse decision should fail") + + +def test_command_group_with_checkpoint_recomputes_without_manifest( + tmp_path, + monkeypatch, +): + data_build = _load_data_build_module() + output_file = "policyengine_us_data/storage/calibration/policy_data.db" + store = CheckpointStore( + root=tmp_path / "checkpoints", branch="branch", commit_sha="abc" + ) + checkpoint = store.checkpoint_path(output_file) + checkpoint.parent.mkdir(parents=True) + checkpoint.write_bytes(b"old-db") + stats = data_build.CheckpointStats() + identity_material = data_build._raw_data_identity_material(branch="branch") + + def build_raw_data(): + Path(output_file).parent.mkdir(parents=True) + Path(output_file).write_bytes(b"new-db") + + monkeypatch.chdir(tmp_path) + + data_build.run_command_group_with_checkpoint( + substep_id="1a_raw_data_download", + output_files=(output_file,), + branch="branch", + volume=object(), + action=build_raw_data, + checkpoint_stats=stats, + checkpoint_store=store, + identity_material=identity_material, + ) + + assert (tmp_path / output_file).read_bytes() == b"new-db" + assert stats.snapshot() == { + "expected_outputs": 1, + "valid_reused_outputs": 0, + "recomputed_outputs": 1, + "invalid_outputs": 0, + } + assert store.load_reuse_manifest().previous_identities() == { + identity_material.identity_key: identity_material.fingerprint() + } + + +def test_command_group_with_checkpoint_restores_manifest_matched_outputs( + tmp_path, + monkeypatch, +): + data_build = _load_data_build_module() + output_file = "policyengine_us_data/storage/calibration/policy_data.db" + store = CheckpointStore( + root=tmp_path / "checkpoints", branch="branch", commit_sha="abc" + ) + checkpoint = store.checkpoint_path(output_file) + checkpoint.parent.mkdir(parents=True) + checkpoint.write_bytes(b"db") + stats = data_build.CheckpointStats() + identity_material = data_build._raw_data_identity_material(branch="branch") + store.record_reuse_manifest( + Stage1ReuseManifestRecord( + substep_id=identity_material.substep_id, + identity_key=identity_material.identity_key, + identity_fingerprint=identity_material.fingerprint(), + identity_material=identity_material.to_dict(), + ) + ) + + def fail_action(): + raise AssertionError("reused command groups must not run commands") + + monkeypatch.chdir(tmp_path) + + data_build.run_command_group_with_checkpoint( + substep_id="1a_raw_data_download", + output_files=(output_file,), + branch="branch", + volume=object(), + action=fail_action, + checkpoint_stats=stats, + checkpoint_store=store, + identity_material=identity_material, + ) + + assert (tmp_path / output_file).read_bytes() == b"db" + assert stats.snapshot() == { + "expected_outputs": 1, + "valid_reused_outputs": 1, + "recomputed_outputs": 0, + "invalid_outputs": 0, + } + + +def test_command_group_with_checkpoint_recomputes_missing_physical_checkpoint( + tmp_path, + monkeypatch, +): + data_build = _load_data_build_module() + output_file = "policyengine_us_data/storage/calibration/policy_data.db" + store = CheckpointStore( + root=tmp_path / "checkpoints", branch="branch", commit_sha="abc" + ) + stats = data_build.CheckpointStats() + identity_material = data_build._raw_data_identity_material(branch="branch") + store.record_reuse_manifest( + Stage1ReuseManifestRecord( + substep_id=identity_material.substep_id, + identity_key=identity_material.identity_key, + identity_fingerprint=identity_material.fingerprint(), + identity_material=identity_material.to_dict(), + ) + ) + + def build_raw_data(): + Path(output_file).parent.mkdir(parents=True) + Path(output_file).write_bytes(b"new-db") + + monkeypatch.chdir(tmp_path) + + data_build.run_command_group_with_checkpoint( + substep_id="1a_raw_data_download", + output_files=(output_file,), + branch="branch", + volume=object(), + action=build_raw_data, + checkpoint_stats=stats, + checkpoint_store=store, + identity_material=identity_material, + ) + + assert (tmp_path / output_file).read_bytes() == b"new-db" + assert stats.snapshot() == { + "expected_outputs": 1, + "valid_reused_outputs": 0, + "recomputed_outputs": 1, + "invalid_outputs": 1, + } + + +def test_stage_1_status_metadata_publishes_reuse_reasoning(): + data_build = _load_data_build_module() + + class Record: + def __init__(self, payload): + self.payload = payload + + def to_dict(self): + return self.payload + + coordinator = SimpleNamespace( + results=[ + Record( + { + "substep_id": "1a_raw_data_download", + "title": "Raw Data Download", + "status": "reused", + "reuse_decision": { + "action": "reuse", + "reason": "identity_match", + }, + "checkpoint_decisions": [ + { + "action": "reuse", + "reason": "valid", + } + ], + } + ), + Record( + { + "substep_id": "1b_base_dataset_construction", + "title": "Base Dataset Construction", + "status": "completed", + "reuse_decision": { + "action": "recompute", + "reason": "no_previous_identity", + }, + "checkpoint_decisions": [], + } + ), + Record( + { + "substep_id": "1c_extended_cps_puf_clone", + "title": "Extended CPS And PUF Clone", + "status": "completed", + "reuse_decision": { + "action": "reuse", + "reason": "identity_match", + }, + "checkpoint_decisions": [ + { + "action": "recompute", + "reason": "missing", + } + ], + } + ), + Record( + { + "substep_id": "1d_enhanced_cps_reweighting", + "title": "Enhanced CPS Reweighting", + "status": "completed", + "reuse_decision": { + "action": "recompute", + "reason": "identity_mismatch", + }, + "checkpoint_decisions": [], + } + ), + Record( + { + "substep_id": "1e_stratified_cps", + "title": "Stratified CPS", + "status": "completed", + "reuse_decision": { + "substep_id": "1e_stratified_cps", + "decisions": [ + { + "substep_id": "1e_stratified_cps", + "identity_key": ( + "1e_stratified_cps:enhanced:enhanced_cps_2024.h5" + ), + "action": "reuse", + "reason": "identity_match", + "identity_fingerprint": "sha256:enhanced", + }, + { + "substep_id": "1e_stratified_cps", + "identity_key": ( + "1e_stratified_cps:stratified:stratified_cps_2024.h5" + ), + "action": "recompute", + "reason": "no_previous_identity", + "identity_fingerprint": "sha256:stratified", + }, + ], + }, + "checkpoint_decisions": [], + } + ), + ], + status_events=[], + error_records=[], + ) + + metadata = data_build._stage_1_status_metadata(coordinator) + reasons = { + item["substep_id"]: item["outcome_reason"] + for item in metadata["reuse_reasoning"] + } + + assert reasons["1a_raw_data_download"] == ( + "Prior semantic identity matched and every expected checkpoint output " + "was present and non-empty." + ) + assert reasons["1b_base_dataset_construction"] == ( + "Recomputed because no prior semantic identity manifest record existed " + "for this Stage 1 execution unit." + ) + assert reasons["1c_extended_cps_puf_clone"] == ( + "Recomputed because at least one expected checkpoint output was missing " + "or empty." + ) + assert reasons["1d_enhanced_cps_reweighting"] == ( + "Recomputed because persisted semantic identity did not match the " + "current Stage 1 execution-unit identity." + ) + mixed = next( + item + for item in metadata["reuse_reasoning"] + if item["substep_id"] == "1e_stratified_cps" + ) + assert mixed["outcome_reason"] == ( + "Recomputed with per-identity semantic reuse reasons: " + "identity_match=1, no_previous_identity=1." + ) + assert mixed["identity_decisions"] == [ + { + "substep_id": "1e_stratified_cps", + "identity_key": "1e_stratified_cps:enhanced:enhanced_cps_2024.h5", + "action": "reuse", + "reason": "identity_match", + "identity_fingerprint": "sha256:enhanced", + }, + { + "substep_id": "1e_stratified_cps", + "identity_key": "1e_stratified_cps:stratified:stratified_cps_2024.h5", + "action": "recompute", + "reason": "no_previous_identity", + "identity_fingerprint": "sha256:stratified", + }, + ] + + def test_utc_timestamp_renders_zulu_time_for_build_log(): data_build = _load_data_build_module() budapest_summer = timezone(timedelta(hours=2))