diff --git a/changelog.d/1048.added b/changelog.d/1048.added new file mode 100644 index 000000000..054951215 --- /dev/null +++ b/changelog.d/1048.added @@ -0,0 +1 @@ +Added a Stage 1 dataset-build context, artifact stager, and diagnostic artifact writers for the pipeline handoff. diff --git a/changelog.d/1069.added b/changelog.d/1069.added new file mode 100644 index 000000000..ca3d6ce2a --- /dev/null +++ b/changelog.d/1069.added @@ -0,0 +1 @@ +Added a Stage 1 command runner and substep status boundary around dataset-build execution. diff --git a/changelog.d/1074.added b/changelog.d/1074.added new file mode 100644 index 000000000..d1801445d --- /dev/null +++ b/changelog.d/1074.added @@ -0,0 +1 @@ +Added Stage 1 checkpoint adapter and rerun reuse planning boundaries. diff --git a/changelog.d/1078.added b/changelog.d/1078.added new file mode 100644 index 000000000..95def99f2 --- /dev/null +++ b/changelog.d/1078.added @@ -0,0 +1 @@ +Added ordered Stage 1 validation adapters that emit canonical validation reports. diff --git a/modal_app/data_build.py b/modal_app/data_build.py index e629ed615..08c1ec851 100644 --- a/modal_app/data_build.py +++ b/modal_app/data_build.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import functools -import json import os import shutil import subprocess @@ -22,14 +23,31 @@ from modal_app.images import cpu_image as image # noqa: E402 from policyengine_us_data.__version__ import __version__ as DATA_PACKAGE_VERSION # noqa: E402 -from policyengine_us_data.build_datasets import stage_1_script_outputs # noqa: E402 +from policyengine_us_data.build_datasets import ( # noqa: E402 + CheckpointStore, + CommandRunner, + DatasetCommand, + DatasetBuildContext, + DatasetBuildOutputContractBuilder, + PipelineArtifactStager, + Stage1Coordinator, + Stage1IdentityMaterial, + Stage1RerunPlanner, + Stage1ReuseDecision, + Stage1ValidationResultWriter, + Stage1ValidationRunner, + ValidationTargetCatalog, + stage_1_artifact_specs, + stage_1_script_outputs, + stage_1_substep_id_for_script, + stage_1_substep_title, + write_stage_1_diagnostics, +) from policyengine_us_data.pipeline_metadata import pipeline_node # noqa: E402 from policyengine_us_data.pipeline_schema import PipelineNode # noqa: E402 from policyengine_us_data.stage_contracts import ( # noqa: E402 - DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, StageContract, - build_dataset_build_output_contract, - write_contract, + ValidationReport, ) from policyengine_us_data.utils.run_context import ( # noqa: E402 CANDIDATE_VERSION_ENV, @@ -103,16 +121,6 @@ def snapshot(self) -> dict[str, int]: CPS_BUILD_SCRIPT = "policyengine_us_data/datasets/cps/cps.py" PUF_BUILD_SCRIPT = "policyengine_us_data/datasets/puf/puf.py" -# Post-build validation modules to run individually for checkpoint tracking. -VALIDATION_MODULES = [ - "validation/stage_1/", -] - - -def _python_cmd(*args: str) -> list[str]: - """Build a command that uses the current interpreter.""" - return [sys.executable, *args] - def _utc_timestamp(value: datetime | None = None) -> str: """Render a UTC timestamp for persisted pipeline metadata.""" @@ -164,30 +172,20 @@ def get_current_commit() -> str: def get_checkpoint_path(branch: str, output_file: str) -> Path: """Get the checkpoint path for an output file, scoped by branch and commit.""" - commit = get_current_commit() - return Path(VOLUME_MOUNT) / branch / commit / Path(output_file).name + return _checkpoint_store(branch).checkpoint_path(output_file) def is_checkpointed(branch: str, output_file: str) -> bool: """Check if output file exists in checkpoint volume and is valid.""" - checkpoint_path = get_checkpoint_path(branch, output_file) - if checkpoint_path.exists(): - # Verify file is not empty/corrupted - if checkpoint_path.stat().st_size > 0: - return True - return False + return _checkpoint_store(branch).decision_for(output_file).action == "reuse" def restore_from_checkpoint(branch: str, output_file: str) -> bool: """Restore output file from checkpoint volume if it exists.""" - checkpoint_path = get_checkpoint_path(branch, output_file) - if checkpoint_path.exists() and checkpoint_path.stat().st_size > 0: - local_path = Path(output_file) - local_path.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(checkpoint_path, local_path) + restored = _checkpoint_store(branch).restore_output(output_file) + if restored: print(f"Restored from checkpoint: {output_file}") - return True - return False + return restored def save_checkpoint( @@ -196,25 +194,35 @@ def save_checkpoint( volume: modal.Volume, ) -> None: """Save output file to checkpoint volume.""" - local_path = Path(output_file) - if local_path.exists() and local_path.stat().st_size > 0: - checkpoint_path = get_checkpoint_path(branch, output_file) - checkpoint_path.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(local_path, checkpoint_path) - with _volume_lock: - volume.commit() + saved = _checkpoint_store(branch, volume).save_output(output_file) + if saved: print(f"Checkpointed: {output_file}") def cleanup_checkpoints(branch: str, volume: modal.Volume) -> None: """Delete checkpoints for this branch after successful completion.""" - branch_dir = Path(VOLUME_MOUNT) / branch - if branch_dir.exists(): - shutil.rmtree(branch_dir) - volume.commit() + cleaned = _checkpoint_store(branch, volume).cleanup_branch() + if cleaned: print(f"Cleaned up checkpoints for branch: {branch}") +def _checkpoint_store( + branch: str, + volume: modal.Volume | None = None, +) -> CheckpointStore: + def commit_volume() -> None: + if volume is not None: + with _volume_lock: + volume.commit() + + return CheckpointStore( + root=Path(VOLUME_MOUNT), + branch=branch, + commit_sha=get_current_commit(), + commit=commit_volume if volume is not None else None, + ) + + def run_script_logged( cmd: list, log_file: IO, @@ -222,21 +230,19 @@ def run_script_logged( check: bool = True, ) -> subprocess.CompletedProcess: """Run a command, streaming output to both stdout and a log file.""" - proc = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, + command = DatasetCommand( + name=" ".join(cmd), + argv=tuple(cmd), + kind="side_effect", + metadata={"command": cmd}, + ) + result = CommandRunner().run( + command, env=env, + log_file=log_file, + check=check, ) - for line in proc.stdout: - sys.stdout.write(line) - sys.stdout.flush() - log_file.write(line) - proc.wait() - if check and proc.returncode != 0: - raise subprocess.CalledProcessError(proc.returncode, cmd) - return subprocess.CompletedProcess(cmd, proc.returncode) + return subprocess.CompletedProcess(cmd, result.returncode) def run_script( @@ -256,28 +262,22 @@ def run_script( The script_path that was executed. Raises: - subprocess.CalledProcessError: If the script fails. + DatasetCommandError: If the script fails. """ - script = Path(script_path) - if ( - script.suffix == ".py" - and script.parts - and script.parts[0] in {"policyengine_us_data", "modal_app"} - ): - cmd = _python_cmd("-u", "-m", ".".join(script.with_suffix("").parts)) - else: - cmd = _python_cmd("-u", script_path) - if args: - cmd.extend(args) + command = DatasetCommand.from_script( + script_path, + args=tuple(args or ()), + python_executable=sys.executable, + ) run_env = env or os.environ.copy() run_env["PYTHONUNBUFFERED"] = "1" print(f"Starting {script_path}...") if log_file: log_file.write(f"\n{'=' * 60}\nStarting {script_path}...\n{'=' * 60}\n") log_file.flush() - run_script_logged(cmd, log_file, run_env) + CommandRunner().run(command, env=run_env, log_file=log_file) else: - subprocess.run(cmd, check=True, env=run_env) + CommandRunner().run(command, env=run_env) print(f"Completed {script_path}") return script_path @@ -333,6 +333,8 @@ def run_script_with_checkpoint( env: Optional[dict] = None, log_file: IO = None, checkpoint_stats: CheckpointStats | None = None, + checkpoint_store: CheckpointStore | None = None, + reuse_decision: Stage1ReuseDecision | None = None, ) -> str: """Run script if output not checkpointed, then checkpoint result. @@ -352,14 +354,29 @@ def run_script_with_checkpoint( if isinstance(output_files, str): output_files = [output_files] expected_count = len(output_files) + checkpoint_store = checkpoint_store or _checkpoint_store(branch, volume) + reuse_decision = reuse_decision or _compat_reuse_decision( + script_path=script_path, + output_files=output_files, + branch=branch, + ) + if reuse_decision.action == "blocked": + raise RuntimeError( + "Stage 1 checkpoint reuse is blocked for " + f"{script_path}: {reuse_decision.reason}" + ) # Check if ALL outputs are checkpointed - all_checkpointed = all(is_checkpointed(branch, f) for f in output_files) + checkpoint_decisions = checkpoint_store.decisions_for(output_files) + all_checkpointed = reuse_decision.action == "reuse" and all( + decision.action == "reuse" for decision in checkpoint_decisions + ) if all_checkpointed: # Restore all files from checkpoint + checkpoint_store.restore_all_outputs(output_files) for output_file in output_files: - restore_from_checkpoint(branch, output_file) + print(f"Restored from checkpoint: {output_file}") print(f"Skipping {script_path} (restored from checkpoint)") if checkpoint_stats is not None: checkpoint_stats.record( @@ -369,7 +386,7 @@ def run_script_with_checkpoint( return script_path missing_or_invalid = sum( - 1 for output_file in output_files if not is_checkpointed(branch, output_file) + 1 for decision in checkpoint_decisions if decision.action != "reuse" ) # Run the script @@ -377,7 +394,9 @@ def run_script_with_checkpoint( # Checkpoint all outputs for output_file in output_files: - save_checkpoint(branch, output_file, volume) + saved = checkpoint_store.save_output(output_file) + if saved: + print(f"Checkpointed: {output_file}") if checkpoint_stats is not None: checkpoint_stats.record( expected_outputs=expected_count, @@ -388,6 +407,109 @@ def run_script_with_checkpoint( return script_path +def _output_paths(output_files: str | list[str]) -> tuple[Path, ...]: + paths = output_files if isinstance(output_files, list) else [output_files] + return tuple(Path(path) for path in paths) + + +def _compat_reuse_decision( + *, + script_path: str, + output_files: list[str], + branch: str, + run_id: str | None = None, + rerun_id: str | None = None, +) -> Stage1ReuseDecision: + """Return the current compatible semantic reuse decision for a script.""" + + substep_id = stage_1_substep_id_for_script(script_path) + material = Stage1IdentityMaterial( + substep_id=substep_id, + inputs={"script_path": script_path}, + parameters={"branch": branch, "outputs": output_files}, + artifact_specs=[ + { + "filename": spec.filename, + "logical_name": spec.logical_name, + "storage_path": spec.storage_path, + "substage_id": spec.substage_id, + } + for spec in stage_1_artifact_specs() + if spec.script_path == script_path + ], + code_sha=get_current_commit(), + schema_version="stage-1-rerun-v1", + upstream_contract_fingerprints=(), + randomness={"checkpoint_scope": "branch_commit"}, + ) + planner = Stage1RerunPlanner( + previous_identities={substep_id: material.fingerprint()} + ) + return planner.decide( + material, + run_id=run_id or os.environ.get("US_DATA_RUN_ID", "unknown"), + rerun_id=rerun_id, + ) + + +def _run_checkpointed_substep( + *, + coordinator: Stage1Coordinator | None, + script_path: str, + output_files: str | list[str], + branch: str, + volume: modal.Volume, + env: dict, + log_file: IO = None, + checkpoint_stats: CheckpointStats | None = None, +) -> str: + output_list = output_files if isinstance(output_files, list) else [output_files] + if coordinator is None: + return run_script_with_checkpoint( + script_path, + output_files, + branch, + volume, + env=env, + log_file=log_file, + checkpoint_stats=checkpoint_stats, + ) + + checkpoint_store = _checkpoint_store(branch, volume) + reuse_decision = _compat_reuse_decision( + script_path=script_path, + output_files=output_list, + branch=branch, + ) + checkpoint_decisions = checkpoint_store.decisions_for(output_list) + + def action() -> str: + return run_script_with_checkpoint( + script_path, + output_files, + branch, + volume, + env=env, + log_file=log_file, + checkpoint_stats=checkpoint_stats, + checkpoint_store=checkpoint_store, + reuse_decision=reuse_decision, + ) + + substep_id = stage_1_substep_id_for_script(script_path) + return coordinator.run_substep( + substep_id, + stage_1_substep_title(substep_id), + action, + command_names=(script_path,), + artifact_paths=_output_paths(output_files), + reuse_decision=reuse_decision.to_dict(), + checkpoint_decisions=tuple( + decision.to_dict() for decision in checkpoint_decisions + ), + ) + + @pipeline_node( PipelineNode( id="cps_puf_build_phase", @@ -408,68 +530,22 @@ def run_cps_then_puf_phase( env: dict, log_file: IO = None, checkpoint_stats: CheckpointStats | None = None, + coordinator: Stage1Coordinator | None = None, ) -> None: """Build CPS before PUF because PUF pension imputation loads CPS_2024.""" for script in (CPS_BUILD_SCRIPT, PUF_BUILD_SCRIPT): - run_script_with_checkpoint( - script, - SCRIPT_OUTPUTS[script], - branch, - volume, + _run_checkpointed_substep( + coordinator=coordinator, + script_path=script, + output_files=SCRIPT_OUTPUTS[script], + branch=branch, + volume=volume, env=env, log_file=log_file, checkpoint_stats=checkpoint_stats, ) -def run_tests_with_checkpoints( - branch: str, - volume: modal.Volume, - env: dict, -) -> None: - """Run post-build validators module-by-module, checkpointing progress. - - Args: - branch: Git branch name for checkpoint scoping. - volume: Modal volume for checkpointing. - env: Environment variables dict. - - Raises: - RuntimeError: If any validation module fails. - """ - commit = get_current_commit() - checkpoint_dir = Path(VOLUME_MOUNT) / branch / commit / "tests" - checkpoint_dir.mkdir(parents=True, exist_ok=True) - - for module in VALIDATION_MODULES: - # Use stem for files, or last component for directories - module_path = Path(module) - if module_path.suffix: - module_name = module_path.stem - else: - module_name = module_path.name.rstrip("/") - - marker_file = checkpoint_dir / f"{module_name}.passed" - - if marker_file.exists(): - print(f"Skipping {module} (already passed)") - continue - - print(f"Running validation: {module}") - result = subprocess.run( - _python_cmd("-u", "-m", "pytest", module, "-v"), - env=env, - ) - - if result.returncode != 0: - raise RuntimeError(f"Validation failed: {module}") - - # Mark as passed - marker_file.touch() - volume.commit() - print(f"Checkpointed: {module} passed") - - def write_dataset_build_contract( *, artifacts_dir: Path, @@ -484,13 +560,21 @@ def write_dataset_build_contract( skip_enhanced_cps: bool, skip_stage_5: bool = False, package_version: str = DATA_PACKAGE_VERSION, + branch: str = "unknown", + diagnostics: tuple = (), + validation: ValidationReport | None = None, + substage_validation: Mapping[str, ValidationReport] | None = None, + stage_1_status_metadata: Mapping[str, Any] | None = None, ) -> StageContract: """Write the Stage 1 semantic handoff contract next to copied artifacts.""" - contract = build_dataset_build_output_contract( - artifacts_dir=artifacts_dir, + context = DatasetBuildContext( run_id=run_id, + branch=branch, code_sha=code_sha, package_version=package_version, + artifacts_dir=artifacts_dir, + ) + return DatasetBuildOutputContractBuilder(context=context).write( checkpoint_stats=checkpoint_stats, started_at=started_at, completed_at=completed_at, @@ -499,12 +583,11 @@ def write_dataset_build_contract( stage_only=stage_only, skip_enhanced_cps=skip_enhanced_cps, skip_stage_5=skip_stage_5, + diagnostics=diagnostics, + validation=validation, + substage_validation=substage_validation, + stage_1_status_metadata=stage_1_status_metadata, ) - write_contract( - contract, - artifacts_dir / DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, - ) - return contract @app.function( @@ -587,14 +670,11 @@ def build_datasets( os.chdir("/root/policyengine-us-data") # Clean stale checkpoints from other commits - branch_dir = Path(VOLUME_MOUNT) / branch - if branch_dir.exists(): - current_commit = get_current_commit() - for entry in branch_dir.iterdir(): - if entry.is_dir() and entry.name != current_commit: - shutil.rmtree(entry) - print(f"Removed stale checkpoint dir: {entry.name[:12]}") - checkpoint_volume.commit() + for removed_checkpoint in _checkpoint_store( + branch, + checkpoint_volume, + ).cleanup_other_commits(): + print(f"Removed stale checkpoint dir: {removed_checkpoint.name[:12]}") env = os.environ.copy() @@ -613,24 +693,63 @@ def build_datasets( f"{'=' * 40}\n" ) log_file.flush() + validation_runner = None + if not skip_tests: + validation_runner = Stage1ValidationRunner( + run_id=run_id, + catalog=ValidationTargetCatalog.from_stage_1_specs( + skip_enhanced_cps=skip_enhanced_cps, + skip_stage_5=skip_stage_5, + ), + metadata={"branch": branch, "code_sha": commit}, + ) + coordinator = Stage1Coordinator(validation_runner=validation_runner) + recorded_skips: set[tuple[str, str]] = set() + + def record_skipped_script(script: str, reason: str) -> None: + substep_id = stage_1_substep_id_for_script(script) + if reason == "--skip-stage-5" and substep_id != "1f_source_imputation": + return + key = (substep_id, reason) + if key in recorded_skips: + return + recorded_skips.add(key) + coordinator.run_substep( + substep_id, + stage_1_substep_title(substep_id), + lambda: None, + command_names=(script,), + skip=True, + skip_reason=reason, + ) - # Download prerequisites - run_script( - "policyengine_us_data/storage/download_prerequisites.py", - env=env, - log_file=log_file, - ) - # Build policy_data.db from source - env["PYTHONUNBUFFERED"] = "1" - log_file.write(f"\n{'=' * 60}\nStarting make database...\n{'=' * 60}\n") - log_file.flush() - run_script_logged(["make", "database"], log_file, env) - # Checkpoint policy_data.db immediately after build so it survives - # test failures and can be restored on retries. - save_checkpoint( - branch, - "policyengine_us_data/storage/calibration/policy_data.db", - checkpoint_volume, + def run_raw_data_download() -> None: + run_script( + "policyengine_us_data/storage/download_prerequisites.py", + env=env, + log_file=log_file, + ) + env["PYTHONUNBUFFERED"] = "1" + log_file.write(f"\n{'=' * 60}\nStarting make database...\n{'=' * 60}\n") + log_file.flush() + run_script_logged(["make", "database"], log_file, env) + # Checkpoint policy_data.db immediately after build so it survives + # test failures and can be restored on retries. + save_checkpoint( + branch, + "policyengine_us_data/storage/calibration/policy_data.db", + checkpoint_volume, + ) + + coordinator.run_substep( + "1a_raw_data_download", + stage_1_substep_title("1a_raw_data_download"), + run_raw_data_download, + command_names=( + "policyengine_us_data/storage/download_prerequisites.py", + "make database", + ), + artifact_paths=("policyengine_us_data/storage/calibration/policy_data.db",), ) if sequential: @@ -640,18 +759,21 @@ def build_datasets( "policyengine_us_data/datasets/cps/small_enhanced_cps.py", ): print(f"Skipping {script} (--skip-stage-5)") + record_skipped_script(script, "--skip-stage-5") continue if skip_enhanced_cps and script in ( "policyengine_us_data/datasets/cps/enhanced_cps.py", "policyengine_us_data/datasets/cps/small_enhanced_cps.py", ): print(f"Skipping {script} (--skip-enhanced-cps)") + record_skipped_script(script, "--skip-enhanced-cps") continue - run_script_with_checkpoint( - script, - output, - branch, - checkpoint_volume, + _run_checkpointed_substep( + coordinator=coordinator, + script_path=script, + output_files=output, + branch=branch, + volume=checkpoint_volume, env=env, log_file=log_file, checkpoint_stats=checkpoint_stats, @@ -677,11 +799,12 @@ def build_datasets( with ThreadPoolExecutor(max_workers=3) as executor: futures = { executor.submit( - run_script_with_checkpoint, - script, - output, - branch, - checkpoint_volume, + _run_checkpointed_substep, + coordinator=coordinator, + script_path=script, + output_files=output, + branch=branch, + volume=checkpoint_volume, env=env, log_file=log_file, checkpoint_stats=checkpoint_stats, @@ -701,15 +824,19 @@ def build_datasets( env=env, log_file=log_file, checkpoint_stats=checkpoint_stats, + coordinator=coordinator, ) # SEQUENTIAL: Extended CPS (needs both cps and puf) print("=== Phase 3: Building extended CPS ===") - run_script_with_checkpoint( - "policyengine_us_data/datasets/cps/extended_cps.py", - SCRIPT_OUTPUTS["policyengine_us_data/datasets/cps/extended_cps.py"], - branch, - checkpoint_volume, + _run_checkpointed_substep( + coordinator=coordinator, + script_path="policyengine_us_data/datasets/cps/extended_cps.py", + output_files=SCRIPT_OUTPUTS[ + "policyengine_us_data/datasets/cps/extended_cps.py" + ], + branch=branch, + volume=checkpoint_volume, env=env, log_file=log_file, checkpoint_stats=checkpoint_stats, @@ -723,13 +850,14 @@ def build_datasets( if not skip_enhanced_cps: phase4_futures.append( executor.submit( - run_script_with_checkpoint, - "policyengine_us_data/datasets/cps/enhanced_cps.py", - SCRIPT_OUTPUTS[ + _run_checkpointed_substep, + coordinator=coordinator, + script_path="policyengine_us_data/datasets/cps/enhanced_cps.py", + output_files=SCRIPT_OUTPUTS[ "policyengine_us_data/datasets/cps/enhanced_cps.py" ], - branch, - checkpoint_volume, + branch=branch, + volume=checkpoint_volume, env=env, log_file=log_file, checkpoint_stats=checkpoint_stats, @@ -737,15 +865,22 @@ def build_datasets( ) else: print("Skipping enhanced_cps.py (--skip-enhanced-cps)") + record_skipped_script( + "policyengine_us_data/datasets/cps/enhanced_cps.py", + "--skip-enhanced-cps", + ) phase4_futures.append( executor.submit( - run_script_with_checkpoint, - "policyengine_us_data/calibration/create_stratified_cps.py", - SCRIPT_OUTPUTS[ + _run_checkpointed_substep, + coordinator=coordinator, + script_path=( + "policyengine_us_data/calibration/create_stratified_cps.py" + ), + output_files=SCRIPT_OUTPUTS[ "policyengine_us_data/calibration/create_stratified_cps.py" ], - branch, - checkpoint_volume, + branch=branch, + volume=checkpoint_volume, env=env, log_file=log_file, checkpoint_stats=checkpoint_stats, @@ -759,6 +894,10 @@ def build_datasets( # small_enhanced_cps needs enhanced_cps if skip_stage_5: print("Skipping Phase 5 (--skip-stage-5)") + record_skipped_script( + "policyengine_us_data/calibration/create_source_imputed_cps.py", + "--skip-stage-5", + ) else: print( "=== Phase 5: Building source imputed CPS " @@ -768,13 +907,16 @@ def build_datasets( with ThreadPoolExecutor(max_workers=2) as executor: phase5_futures.append( executor.submit( - run_script_with_checkpoint, - "policyengine_us_data/calibration/create_source_imputed_cps.py", - SCRIPT_OUTPUTS[ + _run_checkpointed_substep, + coordinator=coordinator, + script_path=( + "policyengine_us_data/calibration/create_source_imputed_cps.py" + ), + output_files=SCRIPT_OUTPUTS[ "policyengine_us_data/calibration/create_source_imputed_cps.py" ], - branch, - checkpoint_volume, + branch=branch, + volume=checkpoint_volume, env=env, log_file=log_file, checkpoint_stats=checkpoint_stats, @@ -783,13 +925,16 @@ def build_datasets( if not skip_enhanced_cps: phase5_futures.append( executor.submit( - run_script_with_checkpoint, - "policyengine_us_data/datasets/cps/small_enhanced_cps.py", - SCRIPT_OUTPUTS[ + _run_checkpointed_substep, + coordinator=coordinator, + script_path=( + "policyengine_us_data/datasets/cps/small_enhanced_cps.py" + ), + output_files=SCRIPT_OUTPUTS[ "policyengine_us_data/datasets/cps/small_enhanced_cps.py" ], - branch, - checkpoint_volume, + branch=branch, + volume=checkpoint_volume, env=env, log_file=log_file, checkpoint_stats=checkpoint_stats, @@ -797,6 +942,10 @@ def build_datasets( ) else: print("Skipping small_enhanced_cps.py (--skip-enhanced-cps)") + record_skipped_script( + "policyengine_us_data/datasets/cps/small_enhanced_cps.py", + "--skip-enhanced-cps", + ) for future in as_completed(phase5_futures): future.result() @@ -810,41 +959,63 @@ def build_datasets( artifacts_dir = Path(PIPELINE_MOUNT) / "artifacts" if run_id: artifacts_dir = artifacts_dir / run_id - artifacts_dir.mkdir(parents=True, exist_ok=True) - - # Copy all intermediate H5 datasets for lineage tracing - for output in SCRIPT_OUTPUTS.values(): - paths = output if isinstance(output, list) else [output] - for p in paths: - src = Path(p) - if src.suffix == ".h5" and src.exists(): - shutil.copy2(src, artifacts_dir / src.name) - print( - f" Copied {src.name} ({src.stat().st_size / 1024 / 1024:.1f} MB)" - ) - - # Yearless alias for pipeline consumers (remote_calibration_runner, local_area) - si = artifacts_dir / "source_imputed_stratified_extended_cps_2024.h5" - if si.exists(): - shutil.copy2(si, artifacts_dir / "source_imputed_stratified_extended_cps.h5") - - shutil.copy2( - "policyengine_us_data/storage/calibration/policy_data.db", - artifacts_dir / "policy_data.db", + build_context = DatasetBuildContext( + run_id=run_id, + branch=branch, + code_sha=commit, + package_version=version, + artifacts_dir=artifacts_dir, ) - cal_weights = Path("policyengine_us_data/storage/calibration_weights.npy") - if cal_weights.exists(): - shutil.copy2( - cal_weights, - artifacts_dir / "calibration_weights.npy", + stager = PipelineArtifactStager(context=build_context) + staged_paths = stager.stage_declared_artifacts( + skip_enhanced_cps=skip_enhanced_cps, + skip_stage_5=skip_stage_5, + ) + for staged_path in staged_paths: + print( + f" Copied {staged_path.name} " + f"({staged_path.stat().st_size / 1024 / 1024:.1f} MB)" ) - print(" Copied calibration_weights.npy") - shutil.copy2(log_path, artifacts_dir / "build_log.txt") checkpoint_snapshot = checkpoint_stats.snapshot() - with open(artifacts_dir / "data_build_checkpoint_stats.json", "w") as f: - json.dump(checkpoint_snapshot, f, indent=2, sort_keys=True) + stager.write_checkpoint_stats(checkpoint_snapshot) log_file.close() completed_at_dt = datetime.now(timezone.utc) + diagnostics = write_stage_1_diagnostics( + context=build_context, + skip_enhanced_cps=skip_enhanced_cps, + skip_stage_5=skip_stage_5, + ) + validation_report: ValidationReport | None = None + validation_diagnostics: tuple = () + substage_validation: Mapping[str, ValidationReport] = {} + if skip_tests: + print("Skipping Stage 1 validation (--skip-tests)") + validation_report = ValidationReport( + status="not_run", + metadata={ + "stage_id": "1_build_datasets", + "run_id": run_id, + "skip_reason": "--skip-tests", + }, + ) + else: + print("Writing Stage 1 validation artifacts...") + validation_reports = [ + ValidationReport.from_dict(result.validation_report) + for result in coordinator.results + if result.validation_report is not None + ] + validation_summary = Stage1ValidationResultWriter( + output_dir=Path(PIPELINE_MOUNT) / "runs" / run_id / "validation" + ).write(validation_reports) + validation_report = validation_summary.report + validation_diagnostics = validation_summary.diagnostics + substage_validation = validation_summary.substage_reports + stage_1_status_metadata = { + "substep_results": [result.to_dict() for result in coordinator.results], + "status_events": [event.to_dict() for event in coordinator.status_events], + "error_records": [error.to_dict() for error in coordinator.error_records], + } write_dataset_build_contract( artifacts_dir=artifacts_dir, run_id=run_id, @@ -858,17 +1029,15 @@ def build_datasets( skip_enhanced_cps=skip_enhanced_cps, skip_stage_5=skip_stage_5, package_version=version, + branch=branch, + diagnostics=(*diagnostics, *validation_diagnostics), + validation=validation_report, + substage_validation=substage_validation, + stage_1_status_metadata=stage_1_status_metadata, ) pipeline_volume.commit() print("Pipeline artifacts committed to shared volume") - # Run post-build validators with checkpointing. - if skip_tests: - print("Skipping tests (--skip-tests)") - else: - print("=== Running post-build validation with checkpointing ===") - run_tests_with_checkpoints(branch, checkpoint_volume, env) - validate_and_maybe_upload_datasets( upload=upload, skip_enhanced_cps=skip_enhanced_cps, diff --git a/policyengine_us_data/build_datasets/__init__.py b/policyengine_us_data/build_datasets/__init__.py index e6b86fabf..a5876e6c6 100644 --- a/policyengine_us_data/build_datasets/__init__.py +++ b/policyengine_us_data/build_datasets/__init__.py @@ -5,23 +5,109 @@ STAGE_1_ARTIFACT_SPECS, stage_1_artifact_specs, stage_1_contract_artifact_specs, + stage_1_diagnostic_artifact_specs, + stage_1_pipeline_artifact_specs, stage_1_script_outputs, ) +from .checkpoints import ( + CheckpointDecision, + CheckpointReuseSummary, + CheckpointStore, +) +from .commands import CommandRunner, DatasetCommand, DatasetCommandError +from .context import DatasetBuildContext +from .contracts import DatasetBuildOutputContractBuilder +from .coordinator import ( + CommandBackedSubstepRunner, + Stage1Coordinator, + Stage1SubstepRunner, + Stage1ValidationAdapter, + stage_1_substep_id_for_script, + stage_1_substep_title, +) +from .diagnostics import ( + ARTIFACT_SCHEMA_VERSION, + DatasetInventoryWriter, + SourceDatasetSchemaSummaryWriter, + TargetDatabaseSchemaSummaryWriter, + write_stage_1_diagnostics, +) from .specs import ( DatasetBuildStepSpec, STAGE_1_BUILD_DATASETS, STAGE_1_BUILD_STEP_SPECS, stage_1_step_specs, ) +from .results import DatasetCommandResult, DatasetSubstepResult +from .rerun import ( + Stage1IdentityMaterial, + Stage1RerunPlanner, + Stage1ReuseDecision, +) +from .staging import PipelineArtifactStager +from .status import Stage1ErrorRecord, Stage1StatusEvent +from .validation import ( + Stage1ValidationContext, + Stage1ValidationError, + Stage1ValidationRunner, + Stage1Validator, + Stage1ValidatorSpec, + iter_stage_1_validators, + run_stage_1_validators, + validators_for_substage, +) +from .validation_results import Stage1ValidationResultWriter, Stage1ValidationSummary +from .validation_targets import ValidationTarget, ValidationTargetCatalog __all__ = [ + "ARTIFACT_SCHEMA_VERSION", + "CheckpointDecision", + "CheckpointReuseSummary", + "CheckpointStore", + "CommandBackedSubstepRunner", + "CommandRunner", "DatasetArtifactSpec", + "DatasetBuildContext", + "DatasetBuildOutputContractBuilder", "DatasetBuildStepSpec", + "DatasetCommand", + "DatasetCommandError", + "DatasetCommandResult", + "DatasetInventoryWriter", + "DatasetSubstepResult", + "PipelineArtifactStager", "STAGE_1_ARTIFACT_SPECS", "STAGE_1_BUILD_DATASETS", "STAGE_1_BUILD_STEP_SPECS", + "SourceDatasetSchemaSummaryWriter", + "Stage1Coordinator", + "Stage1ErrorRecord", + "Stage1IdentityMaterial", + "Stage1RerunPlanner", + "Stage1ReuseDecision", + "Stage1StatusEvent", + "Stage1SubstepRunner", + "Stage1ValidationAdapter", + "Stage1ValidationContext", + "Stage1ValidationError", + "Stage1ValidationResultWriter", + "Stage1ValidationRunner", + "Stage1ValidationSummary", + "Stage1Validator", + "Stage1ValidatorSpec", + "TargetDatabaseSchemaSummaryWriter", + "ValidationTarget", + "ValidationTargetCatalog", + "iter_stage_1_validators", + "run_stage_1_validators", "stage_1_artifact_specs", "stage_1_contract_artifact_specs", + "stage_1_diagnostic_artifact_specs", + "stage_1_pipeline_artifact_specs", "stage_1_script_outputs", + "stage_1_substep_id_for_script", + "stage_1_substep_title", "stage_1_step_specs", + "validators_for_substage", + "write_stage_1_diagnostics", ] diff --git a/policyengine_us_data/build_datasets/artifacts.py b/policyengine_us_data/build_datasets/artifacts.py index b68c5ae67..07d4cd36c 100644 --- a/policyengine_us_data/build_datasets/artifacts.py +++ b/policyengine_us_data/build_datasets/artifacts.py @@ -26,6 +26,9 @@ class DatasetArtifactSpec: required_for_stage_2: bool = False yearless_alias: bool = False contract_output: bool = True + pipeline_output: bool = True + diagnostic_output: bool = False + diagnostic_kind: str | None = None skip_when_enhanced_cps_skipped: bool = False skip_when_stage_5_skipped: bool = False @@ -53,6 +56,7 @@ class DatasetArtifactSpec: storage_path="policyengine_us_data/storage/uprating_factors.csv", script_path=_UPRATING_SCRIPT, contract_output=False, + pipeline_output=False, ), DatasetArtifactSpec( filename="acs_2022.h5", @@ -120,6 +124,7 @@ class DatasetArtifactSpec: ), script_path=_ENHANCED_CPS_SCRIPT, contract_output=False, + pipeline_output=False, skip_when_enhanced_cps_skipped=True, ), DatasetArtifactSpec( @@ -130,6 +135,7 @@ class DatasetArtifactSpec: storage_path="calibration_log.csv", script_path=_ENHANCED_CPS_SCRIPT, contract_output=False, + pipeline_output=False, skip_when_enhanced_cps_skipped=True, ), DatasetArtifactSpec( @@ -184,11 +190,21 @@ class DatasetArtifactSpec: storage_path="policyengine_us_data/storage/calibration/policy_data.db", required_for_stage_2=True, ), + DatasetArtifactSpec( + filename="calibration_weights.npy", + logical_name="calibration_weights", + artifact_family="legacy_optional_weight", + substage_id="1g_stage_base_datasets", + storage_path="policyengine_us_data/storage/calibration_weights.npy", + required=False, + contract_output=False, + ), DatasetArtifactSpec( filename="build_log.txt", logical_name="build_log", artifact_family="log", substage_id="1g_stage_base_datasets", + storage_path="build_log.txt", ), DatasetArtifactSpec( filename="data_build_checkpoint_stats.json", @@ -196,6 +212,37 @@ class DatasetArtifactSpec: artifact_family="execution_metadata", substage_id="1g_stage_base_datasets", ), + DatasetArtifactSpec( + filename="dataset_inventory.json", + logical_name="dataset_inventory", + artifact_family="diagnostic", + substage_id="1g_stage_base_datasets", + required=False, + contract_output=False, + diagnostic_output=True, + diagnostic_kind="dataset_inventory", + ), + DatasetArtifactSpec( + filename="source_dataset_schema_summary.json", + logical_name="source_dataset_schema_summary", + artifact_family="diagnostic", + substage_id="1f_source_imputation", + required=False, + contract_output=False, + diagnostic_output=True, + diagnostic_kind="source_dataset_schema_summary", + skip_when_stage_5_skipped=True, + ), + DatasetArtifactSpec( + filename="target_database_schema_summary.json", + logical_name="target_database_schema_summary", + artifact_family="diagnostic", + substage_id="1g_stage_base_datasets", + required=False, + contract_output=False, + diagnostic_output=True, + diagnostic_kind="target_database_schema_summary", + ), ) _STAGE_1_CONTRACT_OUTPUT_FILENAMES = ( @@ -253,6 +300,18 @@ def stage_1_contract_artifact_specs() -> tuple[DatasetArtifactSpec, ...]: ) +def stage_1_pipeline_artifact_specs() -> tuple[DatasetArtifactSpec, ...]: + """Return artifact specs staged into the run-scoped pipeline directory.""" + + return tuple(spec for spec in STAGE_1_ARTIFACT_SPECS if spec.pipeline_output) + + +def stage_1_diagnostic_artifact_specs() -> tuple[DatasetArtifactSpec, ...]: + """Return diagnostic artifact specs emitted by Stage 1 writers.""" + + return tuple(spec for spec in STAGE_1_ARTIFACT_SPECS if spec.diagnostic_output) + + def stage_1_script_outputs() -> Mapping[str, ScriptOutput]: """Return the checkpoint output mapping consumed by Modal data-build.""" @@ -274,5 +333,7 @@ def stage_1_script_outputs() -> Mapping[str, ScriptOutput]: "ScriptOutput", "stage_1_artifact_specs", "stage_1_contract_artifact_specs", + "stage_1_diagnostic_artifact_specs", + "stage_1_pipeline_artifact_specs", "stage_1_script_outputs", ] diff --git a/policyengine_us_data/build_datasets/checkpoints.py b/policyengine_us_data/build_datasets/checkpoints.py new file mode 100644 index 000000000..a5519d3a4 --- /dev/null +++ b/policyengine_us_data/build_datasets/checkpoints.py @@ -0,0 +1,195 @@ +"""Checkpoint adapter for Stage 1 dataset-build execution.""" + +from __future__ import annotations + +import shutil +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + + +CheckpointAction = Literal["reuse", "recompute", "blocked"] + + +@dataclass(frozen=True, kw_only=True) +class CheckpointDecision: + """Physical checkpoint decision for one expected output.""" + + output_file: str + checkpoint_path: Path + action: CheckpointAction + reason: str + size_bytes: int = 0 + + def to_dict(self) -> dict[str, object]: + """Return a JSON-compatible checkpoint decision.""" + + return { + "output_file": self.output_file, + "checkpoint_path": str(self.checkpoint_path), + "action": self.action, + "reason": self.reason, + "size_bytes": self.size_bytes, + } + + +@dataclass(frozen=True, kw_only=True) +class CheckpointReuseSummary: + """Checkpoint counter summary compatible with existing Stage 1 stats.""" + + expected_outputs: int + valid_reused_outputs: int + recomputed_outputs: int + invalid_outputs: int + + @classmethod + def from_decisions( + cls, + decisions: Sequence[CheckpointDecision], + *, + recomputed: bool, + ) -> "CheckpointReuseSummary": + """Build prior-compatible counters from checkpoint decisions.""" + + reusable = sum(decision.action == "reuse" for decision in decisions) + invalid = sum(decision.action != "reuse" for decision in decisions) + return cls( + expected_outputs=len(decisions), + valid_reused_outputs=reusable if not recomputed else 0, + recomputed_outputs=len(decisions) if recomputed else 0, + invalid_outputs=invalid, + ) + + def to_dict(self) -> dict[str, int]: + """Return counters in the Stage 1 contract shape.""" + + return { + "expected_outputs": self.expected_outputs, + "valid_reused_outputs": self.valid_reused_outputs, + "recomputed_outputs": self.recomputed_outputs, + "invalid_outputs": self.invalid_outputs, + } + + +@dataclass(frozen=True, kw_only=True) +class CheckpointStore: + """Adapter around the Stage 1 physical checkpoint volume layout.""" + + root: Path + branch: str + commit_sha: str + commit: Callable[[], None] | None = None + + def checkpoint_path(self, output_file: str) -> Path: + """Return the checkpoint path for an output file.""" + + return self.root / self.branch / self.commit_sha / Path(output_file).name + + def decision_for(self, output_file: str) -> CheckpointDecision: + """Return the physical checkpoint decision for one output.""" + + path = self.checkpoint_path(output_file) + if not path.exists(): + return CheckpointDecision( + output_file=output_file, + checkpoint_path=path, + action="recompute", + reason="missing", + ) + size = path.stat().st_size + if size <= 0: + return CheckpointDecision( + output_file=output_file, + checkpoint_path=path, + action="recompute", + reason="empty", + size_bytes=size, + ) + return CheckpointDecision( + output_file=output_file, + checkpoint_path=path, + action="reuse", + reason="valid", + size_bytes=size, + ) + + def decisions_for( + self, output_files: Sequence[str] + ) -> tuple[CheckpointDecision, ...]: + """Return physical checkpoint decisions for expected outputs.""" + + return tuple(self.decision_for(output_file) for output_file in output_files) + + def all_outputs_reusable(self, output_files: Sequence[str]) -> bool: + """Return true only when every expected output has a valid checkpoint.""" + + return all( + decision.action == "reuse" for decision in self.decisions_for(output_files) + ) + + def restore_output(self, output_file: str) -> bool: + """Restore one checkpointed output if it is valid.""" + + decision = self.decision_for(output_file) + if decision.action != "reuse": + return False + local_path = Path(output_file) + local_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(decision.checkpoint_path, local_path) + return True + + def restore_all_outputs(self, output_files: Sequence[str]) -> bool: + """Restore outputs only when all expected checkpoints are valid.""" + + if not self.all_outputs_reusable(output_files): + return False + for output_file in output_files: + self.restore_output(output_file) + return True + + def save_output(self, output_file: str) -> bool: + """Save one local output to the checkpoint store if it is non-empty.""" + + local_path = Path(output_file) + if not local_path.exists() or local_path.stat().st_size <= 0: + return False + checkpoint_path = self.checkpoint_path(output_file) + checkpoint_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(local_path, checkpoint_path) + if self.commit is not None: + self.commit() + return True + + def cleanup_branch(self) -> bool: + """Remove all checkpoint attempts for this branch.""" + + branch_dir = self.root / self.branch + if not branch_dir.exists(): + return False + shutil.rmtree(branch_dir) + if self.commit is not None: + self.commit() + return True + + def cleanup_other_commits(self) -> tuple[Path, ...]: + """Remove stale checkpoint directories for other commits in the branch.""" + + branch_dir = self.root / self.branch + if not branch_dir.exists(): + return () + removed: list[Path] = [] + for entry in branch_dir.iterdir(): + if entry.is_dir() and entry.name != self.commit_sha: + shutil.rmtree(entry) + removed.append(entry) + if removed and self.commit is not None: + self.commit() + return tuple(removed) + + +__all__ = [ + "CheckpointDecision", + "CheckpointReuseSummary", + "CheckpointStore", +] diff --git a/policyengine_us_data/build_datasets/commands.py b/policyengine_us_data/build_datasets/commands.py new file mode 100644 index 000000000..127aec20e --- /dev/null +++ b/policyengine_us_data/build_datasets/commands.py @@ -0,0 +1,173 @@ +"""Command construction and execution for Stage 1 dataset builds.""" + +from __future__ import annotations + +import subprocess +import sys +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import IO, Any + +from .results import DatasetCommandResult +from .status import Stage1ErrorRecord, utc_timestamp + + +@dataclass(frozen=True, kw_only=True) +class DatasetCommand: + """A side-effecting command used by Stage 1 dataset builds.""" + + name: str + argv: tuple[str, ...] + kind: str = "python" + side_effecting: bool = True + metadata: Mapping[str, Any] = field(default_factory=dict) + + @classmethod + def from_script( + cls, + script_path: str, + *, + args: Sequence[str] | None = None, + python_executable: str | None = None, + ) -> "DatasetCommand": + """Build the command used to run a Python script or module.""" + + script = Path(script_path) + executable = python_executable or sys.executable + if ( + script.suffix == ".py" + and script.parts + and script.parts[0] in {"policyengine_us_data", "modal_app"} + ): + argv = ( + executable, + "-u", + "-m", + ".".join(script.with_suffix("").parts), + ) + else: + argv = (executable, "-u", script_path) + if args: + argv = (*argv, *tuple(args)) + return cls( + name=script_path, + argv=argv, + kind="python_module" if "-m" in argv else "python_script", + metadata={"script_path": script_path}, + ) + + +class DatasetCommandError(RuntimeError): + """Raised when a Stage 1 command exits unsuccessfully.""" + + def __init__(self, result: DatasetCommandResult): + self.result = result + super().__init__(f"Command failed ({result.returncode}): {result.command_name}") + + +@dataclass(frozen=True, kw_only=True) +class CommandRunner: + """Run Stage 1 commands while streaming and capturing output.""" + + output_tail_lines: int = 200 + + def run( + self, + command: DatasetCommand, + *, + env: Mapping[str, str] | None = None, + log_file: IO[str] | None = None, + check: bool = True, + ) -> DatasetCommandResult: + """Run a command and return a structured execution result.""" + + started_dt = datetime.now(timezone.utc) + combined_output: list[str] = [] + run_env = dict(env) if env is not None else None + if run_env is not None: + run_env["PYTHONUNBUFFERED"] = "1" + + try: + proc = subprocess.Popen( + list(command.argv), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + env=run_env, + ) + if proc.stdout is not None: + for line in proc.stdout: + sys.stdout.write(line) + sys.stdout.flush() + if log_file is not None: + log_file.write(line) + combined_output.append(line) + proc.wait() + result = self._result( + command=command, + started_dt=started_dt, + returncode=proc.returncode, + combined_output=combined_output, + ) + if check and proc.returncode != 0: + raise DatasetCommandError(result) + return result + except DatasetCommandError: + raise + except Exception as exc: + result = self._result( + command=command, + started_dt=started_dt, + returncode=None, + combined_output=combined_output, + exception=exc, + ) + if check: + raise DatasetCommandError(result) from exc + return result + + def _result( + self, + *, + command: DatasetCommand, + started_dt: datetime, + returncode: int | None, + combined_output: Sequence[str], + exception: BaseException | None = None, + ) -> DatasetCommandResult: + completed_dt = datetime.now(timezone.utc) + status = "completed" if returncode == 0 and exception is None else "failed" + error = None + if status == "failed": + error = Stage1ErrorRecord.from_exception( + exception or RuntimeError(f"Command exited with {returncode}"), + command_name=command.name, + returncode=returncode, + metadata={"argv": list(command.argv), "kind": command.kind}, + ) + return DatasetCommandResult( + command_name=command.name, + argv=command.argv, + status=status, + returncode=returncode, + started_at=utc_timestamp(started_dt), + completed_at=utc_timestamp(completed_dt), + duration_s=(completed_dt - started_dt).total_seconds(), + combined_output_tail=tuple(combined_output[-self.output_tail_lines :]), + error=error, + metadata={ + **dict(command.metadata), + "kind": command.kind, + "side_effecting": command.side_effecting, + "stderr_merged": True, + }, + ) + + +__all__ = [ + "CommandRunner", + "DatasetCommand", + "DatasetCommandError", +] diff --git a/policyengine_us_data/build_datasets/context.py b/policyengine_us_data/build_datasets/context.py new file mode 100644 index 000000000..16e57825d --- /dev/null +++ b/policyengine_us_data/build_datasets/context.py @@ -0,0 +1,66 @@ +"""Run context for the Stage 1 dataset-build handoff.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +from .specs import STAGE_1_BUILD_DATASETS + + +@dataclass(frozen=True, kw_only=True) +class DatasetBuildContext: + """Identity and filesystem context for one Stage 1 dataset-build run.""" + + run_id: str + branch: str + code_sha: str + package_version: str + artifacts_dir: Path + storage_dir: Path = Path("policyengine_us_data/storage") + work_dir: Path = Path(".") + stage_id: str = STAGE_1_BUILD_DATASETS + + def __post_init__(self) -> None: + if not self.run_id: + raise ValueError("run_id is required") + if not self.branch: + raise ValueError("branch is required") + if not self.code_sha: + raise ValueError("code_sha is required") + if not self.package_version: + raise ValueError("package_version is required") + object.__setattr__(self, "artifacts_dir", Path(self.artifacts_dir)) + object.__setattr__(self, "storage_dir", Path(self.storage_dir)) + object.__setattr__(self, "work_dir", Path(self.work_dir)) + + def source_path(self, storage_path: str) -> Path: + """Resolve a declared storage or working-directory source path.""" + + path = Path(storage_path) + if path.is_absolute(): + return path + storage_prefix = Path("policyengine_us_data/storage") + try: + return self.storage_dir / path.relative_to(storage_prefix) + except ValueError: + return self.work_dir / path + + def artifact_path(self, filename: str) -> Path: + """Return the run-scoped destination path for a staged artifact.""" + + return self.artifacts_dir / filename + + def identity(self) -> dict[str, str]: + """Return stable identity fields for Stage 1 diagnostic payloads.""" + + return { + "run_id": self.run_id, + "stage_id": self.stage_id, + "branch": self.branch, + "code_sha": self.code_sha, + "package_version": self.package_version, + } + + +__all__ = ["DatasetBuildContext"] diff --git a/policyengine_us_data/build_datasets/contracts.py b/policyengine_us_data/build_datasets/contracts.py new file mode 100644 index 000000000..c544abdb6 --- /dev/null +++ b/policyengine_us_data/build_datasets/contracts.py @@ -0,0 +1,74 @@ +"""Contract builder facade for Stage 1 dataset-build outputs.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass + +from .context import DatasetBuildContext + + +@dataclass(frozen=True, kw_only=True) +class DatasetBuildOutputContractBuilder: + """Build and persist the Stage 1 dataset-build handoff contract.""" + + context: DatasetBuildContext + + def build( + self, + *, + checkpoint_stats: Mapping[str, int], + started_at: str | None, + completed_at: str, + duration_s: float | None, + upload_requested: bool, + stage_only: bool, + skip_enhanced_cps: bool, + skip_stage_5: bool = False, + diagnostics: Sequence[object] = (), + validation: object | None = None, + substage_validation: Mapping[str, object] | None = None, + stage_1_status_metadata: Mapping[str, object] | None = None, + ): + """Build the Stage 1 handoff contract from staged artifacts.""" + + from policyengine_us_data.stage_contracts import ( + build_dataset_build_output_contract, + ) + + return build_dataset_build_output_contract( + artifacts_dir=self.context.artifacts_dir, + run_id=self.context.run_id, + code_sha=self.context.code_sha, + package_version=self.context.package_version, + checkpoint_stats=checkpoint_stats, + started_at=started_at, + completed_at=completed_at, + duration_s=duration_s, + upload_requested=upload_requested, + stage_only=stage_only, + skip_enhanced_cps=skip_enhanced_cps, + skip_stage_5=skip_stage_5, + diagnostics=tuple(diagnostics), + validation=validation, + substage_validation=substage_validation, + stage_1_status_metadata=stage_1_status_metadata, + ) + + def write(self, **kwargs): + """Build and write the Stage 1 handoff contract next to artifacts.""" + + from policyengine_us_data.stage_contracts import ( + DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, + write_contract, + ) + + contract = self.build(**kwargs) + write_contract( + contract, + self.context.artifacts_dir / DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, + ) + return contract + + +__all__ = ["DatasetBuildOutputContractBuilder"] diff --git a/policyengine_us_data/build_datasets/coordinator.py b/policyengine_us_data/build_datasets/coordinator.py new file mode 100644 index 000000000..14c2b7bc2 --- /dev/null +++ b/policyengine_us_data/build_datasets/coordinator.py @@ -0,0 +1,282 @@ +"""Substep coordination for Stage 1 dataset builds.""" + +from __future__ import annotations + +import threading +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass, field, replace +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Protocol + +from policyengine_us_data.stage_contracts import ValidationReport + +from .artifacts import stage_1_artifact_specs +from .results import DatasetSubstepResult +from .specs import STAGE_1_BUILD_STEP_SPECS +from .status import Stage1ErrorRecord, Stage1StatusEvent, utc_timestamp + + +class Stage1SubstepRunner(Protocol): + """Callable runner for one Stage 1 substep.""" + + substep_id: str + title: str + + def run(self) -> Any: + """Run the substep action.""" + + +class Stage1ValidationAdapter(Protocol): + """Adapter that validates a completed Stage 1 substep result.""" + + def run_for_substep_result( + self, + result: DatasetSubstepResult, + ) -> ValidationReport: + """Run validation for one substep result.""" + + def should_stop(self, report: ValidationReport) -> bool: + """Return whether validation failure should stop downstream work.""" + + +@dataclass(frozen=True, kw_only=True) +class CommandBackedSubstepRunner: + """Run a Stage 1 substep backed by existing side-effecting commands.""" + + substep_id: str + title: str + action: Callable[[], Any] + + def run(self) -> Any: + """Run the wrapped substep action.""" + + return self.action() + + +@dataclass +class Stage1Coordinator: + """Collect Stage 1 substep status events, errors, and results.""" + + validation_runner: Stage1ValidationAdapter | None = None + results: list[DatasetSubstepResult] = field(default_factory=list) + status_events: list[Stage1StatusEvent] = field(default_factory=list) + error_records: list[Stage1ErrorRecord] = field(default_factory=list) + _lock: threading.Lock = field(default_factory=threading.Lock, repr=False) + + def run_substep( + self, + substep_id: str, + title: str | None, + action: Callable[[], Any], + *, + command_names: Sequence[str] = (), + artifact_paths: Sequence[str | Path] = (), + reuse_decision: Mapping[str, Any] | None = None, + checkpoint_decisions: Sequence[Mapping[str, Any]] = (), + skip: bool = False, + skip_reason: str | None = None, + metadata: Mapping[str, Any] | None = None, + ) -> Any: + """Run one declared substep and record structured status.""" + + runner = CommandBackedSubstepRunner( + substep_id=substep_id, + title=title or stage_1_substep_title(substep_id), + action=action, + ) + if skip: + result = self._skipped_result( + runner=runner, + command_names=command_names, + reuse_decision=reuse_decision, + checkpoint_decisions=checkpoint_decisions, + skip_reason=skip_reason, + metadata=metadata, + ) + self._record(result) + return None + + started_dt = datetime.now(timezone.utc) + started_at = utc_timestamp(started_dt) + self._record_event( + Stage1StatusEvent( + substep_id=substep_id, + status="started", + created_at=started_at, + message=f"Started {runner.title}", + metadata=dict(metadata or {}), + ) + ) + try: + value = runner.run() + except Exception as exc: + error = Stage1ErrorRecord.from_exception( + exc, + substep_id=substep_id, + command_name=command_names[0] if command_names else None, + metadata=dict(metadata or {}), + ) + result = self._result( + runner=runner, + status="failed", + started_dt=started_dt, + command_names=command_names, + artifact_paths=artifact_paths, + reuse_decision=reuse_decision, + checkpoint_decisions=checkpoint_decisions, + error=error, + metadata=metadata, + ) + self._record(result) + raise + + result = self._result( + runner=runner, + status="completed", + started_dt=started_dt, + command_names=command_names, + artifact_paths=artifact_paths, + reuse_decision=reuse_decision, + checkpoint_decisions=checkpoint_decisions, + metadata=metadata, + ) + result = self._validated_result(result) + self._record(result) + return value + + def _skipped_result( + self, + *, + runner: CommandBackedSubstepRunner, + command_names: Sequence[str], + reuse_decision: Mapping[str, Any] | None, + checkpoint_decisions: Sequence[Mapping[str, Any]], + skip_reason: str | None, + metadata: Mapping[str, Any] | None, + ) -> DatasetSubstepResult: + completed_at = utc_timestamp() + return DatasetSubstepResult( + substep_id=runner.substep_id, + title=runner.title, + status="skipped", + started_at=None, + completed_at=completed_at, + duration_s=None, + command_names=tuple(command_names), + reuse_decision=reuse_decision, + checkpoint_decisions=tuple(checkpoint_decisions), + metadata={**dict(metadata or {}), "skip_reason": skip_reason}, + ) + + def _result( + self, + *, + runner: CommandBackedSubstepRunner, + status: str, + started_dt: datetime, + command_names: Sequence[str], + artifact_paths: Sequence[str | Path], + reuse_decision: Mapping[str, Any] | None, + checkpoint_decisions: Sequence[Mapping[str, Any]], + error: Stage1ErrorRecord | None = None, + metadata: Mapping[str, Any] | None = None, + ) -> DatasetSubstepResult: + completed_dt = datetime.now(timezone.utc) + return DatasetSubstepResult( + substep_id=runner.substep_id, + title=runner.title, + status=status, + started_at=utc_timestamp(started_dt), + completed_at=utc_timestamp(completed_dt), + duration_s=(completed_dt - started_dt).total_seconds(), + command_names=tuple(command_names), + artifact_paths=_artifact_paths(artifact_paths), + reuse_decision=reuse_decision, + checkpoint_decisions=tuple(checkpoint_decisions), + error=error, + metadata=dict(metadata or {}), + ) + + def _validated_result(self, result: DatasetSubstepResult) -> DatasetSubstepResult: + if self.validation_runner is None: + return result + + report = self.validation_runner.run_for_substep_result(result) + result = replace(result, validation_report=report.to_dict()) + if self.validation_runner.should_stop(report): + exc = RuntimeError(f"Stage 1 validation failed for {result.substep_id}") + error = Stage1ErrorRecord.from_exception( + exc, + substep_id=result.substep_id, + command_name=result.command_names[0] if result.command_names else None, + metadata={"validation_report": report.to_dict()}, + ) + return replace(result, status="failed", error=error) + return result + + def _record(self, result: DatasetSubstepResult) -> None: + with self._lock: + self.results.append(result) + self.status_events.append( + Stage1StatusEvent( + substep_id=result.substep_id, + status=result.status, + created_at=result.completed_at, + message=f"{result.title}: {result.status}", + metadata={ + **dict(result.metadata), + "reuse_decision": result.reuse_decision, + "checkpoint_decisions": [ + dict(decision) for decision in result.checkpoint_decisions + ], + "validation_report": ( + dict(result.validation_report) + if result.validation_report is not None + else None + ), + }, + ) + ) + if result.error is not None: + self.error_records.append(result.error) + if result.validation_report is not None: + raise RuntimeError( + f"Stage 1 validation failed for {result.substep_id}" + ) + + def _record_event(self, event: Stage1StatusEvent) -> None: + with self._lock: + self.status_events.append(event) + + +def stage_1_substep_id_for_script(script_path: str) -> str: + """Return the Stage 1 substep id associated with a script path.""" + + for spec in stage_1_artifact_specs(): + if spec.script_path == script_path: + return spec.substage_id + return "1g_stage_base_datasets" + + +def stage_1_substep_title(substep_id: str) -> str: + """Return the configured Stage 1 title for a substep id.""" + + for spec in STAGE_1_BUILD_STEP_SPECS: + if spec.id == substep_id: + return spec.title + return substep_id + + +def _artifact_paths(paths: Sequence[str | Path]) -> tuple[str, ...]: + return tuple(str(Path(path)) for path in paths) + + +__all__ = [ + "CommandBackedSubstepRunner", + "Stage1Coordinator", + "Stage1SubstepRunner", + "Stage1ValidationAdapter", + "stage_1_substep_id_for_script", + "stage_1_substep_title", +] diff --git a/policyengine_us_data/build_datasets/diagnostics.py b/policyengine_us_data/build_datasets/diagnostics.py new file mode 100644 index 000000000..bc836a79c --- /dev/null +++ b/policyengine_us_data/build_datasets/diagnostics.py @@ -0,0 +1,393 @@ +"""Diagnostic artifact writers for Stage 1 dataset-build outputs.""" + +from __future__ import annotations + +import json +import sqlite3 +from collections.abc import Mapping +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .artifacts import ( + DatasetArtifactSpec, + stage_1_diagnostic_artifact_specs, + stage_1_pipeline_artifact_specs, +) +from .context import DatasetBuildContext +from policyengine_us_data.utils.step_manifest import sha256_file + + +ARTIFACT_SCHEMA_VERSION = "1" + + +def _json_default(value: Any) -> Any: + if isinstance(value, Path): + return str(value) + raise TypeError(f"Object is not JSON serializable: {type(value).__name__}") + + +def _write_json(path: Path, payload: Mapping[str, Any]) -> None: + path.write_text( + json.dumps( + payload, + default=_json_default, + indent=2, + sort_keys=True, + ) + + "\n" + ) + + +def _media_type_for_path(path: Path) -> str: + suffix = path.suffix.lower() + if suffix == ".h5": + return "application/x-hdf5" + if suffix == ".db": + return "application/vnd.sqlite3" + if suffix == ".json": + return "application/json" + if suffix == ".npy": + return "application/x-numpy-array" + if suffix == ".txt": + return "text/plain" + return "application/octet-stream" + + +def _artifact_ref_for_path( + *, + logical_name: str, + path: Path, + metadata: Mapping[str, Any], +): + from policyengine_us_data.stage_contracts import ArtifactRef + + return ArtifactRef( + logical_name=logical_name, + uri=path.resolve().as_uri(), + sha256=f"sha256:{sha256_file(path)}", + size_bytes=path.stat().st_size, + media_type=_media_type_for_path(path), + metadata=metadata, + ) + + +def _diagnostic_ref_for_path( + *, + spec: DatasetArtifactSpec, + path: Path, + summary: Mapping[str, Any], +): + from policyengine_us_data.stage_contracts import DiagnosticRef + + return DiagnosticRef( + name=spec.logical_name, + kind=spec.diagnostic_kind or spec.artifact_family, + artifact=_artifact_ref_for_path( + logical_name=spec.logical_name, + path=path, + metadata={ + "artifact_family": spec.artifact_family, + "artifact_schema_version": ARTIFACT_SCHEMA_VERSION, + "substage_id": spec.substage_id, + }, + ), + summary=summary, + severity="info", + ) + + +def _diagnostic_spec(logical_name: str) -> DatasetArtifactSpec: + for spec in stage_1_diagnostic_artifact_specs(): + if spec.logical_name == logical_name: + return spec + raise KeyError(f"Unknown Stage 1 diagnostic spec: {logical_name}") + + +def _cheap_h5_summary(path: Path) -> dict[str, Any]: + import h5py + + datasets: list[dict[str, Any]] = [] + entities: dict[str, dict[str, Any]] = {} + + with h5py.File(path, "r") as h5_file: + + def visit(name: str, obj: Any) -> None: + if not isinstance(obj, h5py.Dataset): + return + parts = name.split("/") + entity = parts[0] if parts else "" + variable = parts[-2] if len(parts) > 1 else parts[-1] + period = parts[-1] if parts[-1].isdigit() else None + row_count = int(obj.shape[0]) if obj.shape else None + datasets.append( + { + "path": name, + "entity": entity, + "variable": variable, + "period": period, + "dtype": str(obj.dtype), + "shape": list(obj.shape), + "row_count": row_count, + } + ) + entity_summary = entities.setdefault( + entity, + { + "dataset_count": 0, + "variables": set(), + "periods": set(), + "row_counts": {}, + }, + ) + entity_summary["dataset_count"] += 1 + entity_summary["variables"].add(variable) + if period is not None: + entity_summary["periods"].add(period) + if row_count is not None: + entity_summary["row_counts"][name] = row_count + + h5_file.visititems(visit) + + return { + "datasets": datasets, + "entities": { + entity: { + "dataset_count": summary["dataset_count"], + "variables": sorted(summary["variables"]), + "periods": sorted(summary["periods"]), + "row_counts": summary["row_counts"], + } + for entity, summary in sorted(entities.items()) + }, + } + + +def _sqlite_summary(path: Path) -> dict[str, Any]: + tables = [] + with sqlite3.connect(f"file:{path}?mode=ro", uri=True) as conn: + conn.row_factory = sqlite3.Row + table_names = [ + row["name"] + for row in conn.execute( + """ + SELECT name + FROM sqlite_master + WHERE type = 'table' AND name NOT LIKE 'sqlite_%' + ORDER BY name + """ + ) + ] + checksum_material = [] + for table_name in table_names: + quoted_table_name = _quote_sql_identifier(table_name) + columns = [ + { + "name": row["name"], + "type": row["type"], + "notnull": int(row["notnull"]), + "pk": int(row["pk"]), + } + for row in conn.execute(f"PRAGMA table_info({quoted_table_name})") + ] + row_count = conn.execute( + f"SELECT COUNT(*) AS row_count FROM {quoted_table_name}" + ).fetchone()["row_count"] + table_summary = { + "name": table_name, + "columns": columns, + "row_count": int(row_count), + } + tables.append(table_summary) + checksum_material.append(table_summary) + + digest_payload = json.dumps( + checksum_material, + sort_keys=True, + separators=(",", ":"), + ).encode() + import hashlib + + return { + "tables": tables, + "known_target_tables": [ + table["name"] + for table in tables + if table["name"] in {"targets", "strata", "stratum_constraints"} + ], + "schema_checksum": hashlib.sha256(digest_payload).hexdigest(), + } + + +def _quote_sql_identifier(identifier: str) -> str: + return '"' + identifier.replace('"', '""') + '"' + + +@dataclass(frozen=True, kw_only=True) +class DatasetInventoryWriter: + """Write a compact inventory of Stage 1 artifacts staged for a run.""" + + context: DatasetBuildContext + + def write( + self, + *, + skip_enhanced_cps: bool = False, + skip_stage_5: bool = False, + ): + spec = _diagnostic_spec("dataset_inventory") + artifacts = [] + seen_logical_names: set[str] = set() + for artifact_spec in stage_1_pipeline_artifact_specs(): + if artifact_spec.diagnostic_output: + continue + if skip_enhanced_cps and artifact_spec.skip_when_enhanced_cps_skipped: + continue + if skip_stage_5 and artifact_spec.skip_when_stage_5_skipped: + continue + path = self.context.artifact_path(artifact_spec.filename) + if not path.exists(): + if artifact_spec.required: + raise FileNotFoundError(f"Missing staged artifact: {path}") + continue + if artifact_spec.logical_name in seen_logical_names: + raise ValueError( + f"Duplicate Stage 1 artifact: {artifact_spec.logical_name}" + ) + seen_logical_names.add(artifact_spec.logical_name) + artifacts.append(_inventory_entry(artifact_spec, path)) + + payload = { + "artifact_schema_version": ARTIFACT_SCHEMA_VERSION, + **self.context.identity(), + "artifacts": artifacts, + } + path = self.context.artifact_path(spec.filename) + _write_json(path, payload) + return _diagnostic_ref_for_path( + spec=spec, + path=path, + summary={"artifact_count": len(artifacts)}, + ) + + +@dataclass(frozen=True, kw_only=True) +class SourceDatasetSchemaSummaryWriter: + """Write a metadata-only schema summary for the source-imputed H5 handoff.""" + + context: DatasetBuildContext + + def write(self): + spec = _diagnostic_spec("source_dataset_schema_summary") + source_path = self.context.artifact_path( + "source_imputed_stratified_extended_cps.h5" + ) + if not source_path.exists(): + raise FileNotFoundError(f"Missing source dataset artifact: {source_path}") + h5_summary = _cheap_h5_summary(source_path) + payload = { + "artifact_schema_version": ARTIFACT_SCHEMA_VERSION, + **self.context.identity(), + "logical_name": "source_imputed_stratified_extended_cps", + "path": source_path.name, + **h5_summary, + } + path = self.context.artifact_path(spec.filename) + _write_json(path, payload) + return _diagnostic_ref_for_path( + spec=spec, + path=path, + summary={ + "entity_count": len(h5_summary["entities"]), + "dataset_count": len(h5_summary["datasets"]), + }, + ) + + +@dataclass(frozen=True, kw_only=True) +class TargetDatabaseSchemaSummaryWriter: + """Write a schema and row-count summary for the Stage 1 target database.""" + + context: DatasetBuildContext + + def write(self): + spec = _diagnostic_spec("target_database_schema_summary") + db_path = self.context.artifact_path("policy_data.db") + if not db_path.exists(): + raise FileNotFoundError(f"Missing target database artifact: {db_path}") + db_summary = _sqlite_summary(db_path) + payload = { + "artifact_schema_version": ARTIFACT_SCHEMA_VERSION, + **self.context.identity(), + "logical_name": "policy_data_db", + "path": db_path.name, + **db_summary, + } + path = self.context.artifact_path(spec.filename) + _write_json(path, payload) + return _diagnostic_ref_for_path( + spec=spec, + path=path, + summary={ + "table_count": len(db_summary["tables"]), + "known_target_tables": db_summary["known_target_tables"], + "schema_checksum": db_summary["schema_checksum"], + }, + ) + + +def _inventory_entry(spec: DatasetArtifactSpec, path: Path) -> dict[str, Any]: + entry: dict[str, Any] = { + "artifact_schema_version": ARTIFACT_SCHEMA_VERSION, + "logical_name": spec.logical_name, + "artifact_family": spec.artifact_family, + "substage_id": spec.substage_id, + "path": path.name, + "sha256": f"sha256:{sha256_file(path)}", + "size_bytes": path.stat().st_size, + "media_type": _media_type_for_path(path), + } + if spec.period is not None: + entry["period"] = spec.period + if path.suffix == ".h5": + entry["row_counts"] = { + dataset["path"]: dataset["row_count"] + for dataset in _cheap_h5_summary(path)["datasets"] + if dataset["row_count"] is not None + } + elif path.suffix == ".db": + db_summary = _sqlite_summary(path) + entry["row_counts"] = { + table["name"]: table["row_count"] for table in db_summary["tables"] + } + entry["schema_checksum"] = db_summary["schema_checksum"] + return entry + + +def write_stage_1_diagnostics( + *, + context: DatasetBuildContext, + skip_enhanced_cps: bool = False, + skip_stage_5: bool = False, +) -> tuple[Any, ...]: + """Write Stage 1 diagnostic artifacts and return their contract refs.""" + + refs = [ + DatasetInventoryWriter(context=context).write( + skip_enhanced_cps=skip_enhanced_cps, + skip_stage_5=skip_stage_5, + ), + TargetDatabaseSchemaSummaryWriter(context=context).write(), + ] + if not skip_stage_5: + refs.insert(1, SourceDatasetSchemaSummaryWriter(context=context).write()) + return tuple(refs) + + +__all__ = [ + "ARTIFACT_SCHEMA_VERSION", + "DatasetInventoryWriter", + "SourceDatasetSchemaSummaryWriter", + "TargetDatabaseSchemaSummaryWriter", + "write_stage_1_diagnostics", +] diff --git a/policyengine_us_data/build_datasets/rerun.py b/policyengine_us_data/build_datasets/rerun.py new file mode 100644 index 000000000..8f4aa3792 --- /dev/null +++ b/policyengine_us_data/build_datasets/rerun.py @@ -0,0 +1,128 @@ +"""Rerun and semantic reuse planning for Stage 1 dataset builds.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from typing import Any, Literal + +from policyengine_us_data.stage_contracts import fingerprint_material + + +Stage1ReuseAction = Literal["reuse", "recompute", "blocked"] + + +@dataclass(frozen=True, kw_only=True) +class Stage1IdentityMaterial: + """Semantic identity material for a Stage 1 substep attempt.""" + + substep_id: str + code_sha: str + schema_version: str + inputs: Mapping[str, Any] = field(default_factory=dict) + parameters: Mapping[str, Any] = field(default_factory=dict) + artifact_specs: Sequence[Mapping[str, Any]] = () + upstream_contract_fingerprints: Sequence[str] = () + randomness: Mapping[str, Any] = field(default_factory=dict) + blocked_reason: str | None = None + + def fingerprint(self) -> str: + """Return the deterministic semantic identity fingerprint.""" + + return fingerprint_material( + { + "substep_id": self.substep_id, + "inputs": dict(self.inputs), + "parameters": dict(self.parameters), + "artifact_specs": list(self.artifact_specs), + "code_sha": self.code_sha, + "schema_version": self.schema_version, + "upstream_contract_fingerprints": list( + self.upstream_contract_fingerprints + ), + "randomness": dict(self.randomness), + } + ).value + + +@dataclass(frozen=True, kw_only=True) +class Stage1ReuseDecision: + """Semantic rerun/reuse decision for a Stage 1 substep.""" + + run_id: str + rerun_id: str | None + artifact_namespace: str + substep_id: str + action: Stage1ReuseAction + reason: str + identity_fingerprint: str + + def to_dict(self) -> dict[str, str | None]: + """Return a JSON-compatible reuse decision.""" + + return { + "run_id": self.run_id, + "rerun_id": self.rerun_id, + "artifact_namespace": self.artifact_namespace, + "substep_id": self.substep_id, + "action": self.action, + "reason": self.reason, + "identity_fingerprint": self.identity_fingerprint, + } + + +@dataclass(frozen=True, kw_only=True) +class Stage1RerunPlanner: + """Decide whether Stage 1 substeps may reuse semantic work.""" + + previous_identities: Mapping[str, str] = field(default_factory=dict) + + def decide( + self, + material: Stage1IdentityMaterial, + *, + run_id: str, + rerun_id: str | None = None, + ) -> Stage1ReuseDecision: + """Return a semantic reuse, recompute, or blocked decision.""" + + fingerprint = material.fingerprint() + if material.blocked_reason: + return Stage1ReuseDecision( + run_id=run_id, + rerun_id=rerun_id, + artifact_namespace=run_id, + substep_id=material.substep_id, + action="blocked", + reason=material.blocked_reason, + identity_fingerprint=fingerprint, + ) + + previous = self.previous_identities.get(material.substep_id) + if previous == fingerprint: + action = "reuse" + reason = "identity_match" + elif previous is None: + action = "recompute" + reason = "no_previous_identity" + else: + action = "recompute" + reason = "identity_mismatch" + + return Stage1ReuseDecision( + run_id=run_id, + rerun_id=rerun_id, + artifact_namespace=run_id, + substep_id=material.substep_id, + action=action, + reason=reason, + identity_fingerprint=fingerprint, + ) + + +__all__ = [ + "Stage1IdentityMaterial", + "Stage1RerunPlanner", + "Stage1ReuseAction", + "Stage1ReuseDecision", +] diff --git a/policyengine_us_data/build_datasets/results.py b/policyengine_us_data/build_datasets/results.py new file mode 100644 index 000000000..820bef32a --- /dev/null +++ b/policyengine_us_data/build_datasets/results.py @@ -0,0 +1,97 @@ +"""Structured execution results for Stage 1 dataset-build commands.""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any, Literal + +from .status import Stage1ErrorRecord, Stage1SubstepStatus + + +CommandExecutionStatus = Literal["completed", "failed"] + + +@dataclass(frozen=True, kw_only=True) +class DatasetCommandResult: + """Result of running one Stage 1 command.""" + + command_name: str + argv: tuple[str, ...] + status: CommandExecutionStatus + returncode: int | None + started_at: str + completed_at: str + duration_s: float + combined_output_tail: tuple[str, ...] = () + error: Stage1ErrorRecord | None = None + metadata: Mapping[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-compatible command result payload.""" + + return { + "command_name": self.command_name, + "argv": list(self.argv), + "status": self.status, + "returncode": self.returncode, + "started_at": self.started_at, + "completed_at": self.completed_at, + "duration_s": self.duration_s, + "combined_output_tail": list(self.combined_output_tail), + "error": self.error.to_dict() if self.error else None, + "metadata": dict(self.metadata), + } + + +@dataclass(frozen=True, kw_only=True) +class DatasetSubstepResult: + """Result of running or skipping a Stage 1 substep.""" + + substep_id: str + title: str + status: Stage1SubstepStatus + started_at: str | None + completed_at: str + duration_s: float | None + command_names: tuple[str, ...] = () + artifact_paths: tuple[str, ...] = () + reuse_decision: Mapping[str, Any] | None = None + checkpoint_decisions: tuple[Mapping[str, Any], ...] = () + validation_report: Mapping[str, Any] | None = None + error: Stage1ErrorRecord | None = None + metadata: Mapping[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-compatible substep result payload.""" + + return { + "substep_id": self.substep_id, + "title": self.title, + "status": self.status, + "started_at": self.started_at, + "completed_at": self.completed_at, + "duration_s": self.duration_s, + "command_names": list(self.command_names), + "artifact_paths": list(self.artifact_paths), + "reuse_decision": ( + dict(self.reuse_decision) if self.reuse_decision is not None else None + ), + "checkpoint_decisions": [ + dict(decision) for decision in self.checkpoint_decisions + ], + "validation_report": ( + dict(self.validation_report) + if self.validation_report is not None + else None + ), + "error": self.error.to_dict() if self.error else None, + "metadata": dict(self.metadata), + } + + +__all__ = [ + "CommandExecutionStatus", + "DatasetCommandResult", + "DatasetSubstepResult", +] diff --git a/policyengine_us_data/build_datasets/specs.py b/policyengine_us_data/build_datasets/specs.py index 37a4eb36c..e4744eafe 100644 --- a/policyengine_us_data/build_datasets/specs.py +++ b/policyengine_us_data/build_datasets/specs.py @@ -22,6 +22,7 @@ class DatasetBuildStepSpec: reuse_mode: str = "checkpointable" skip_when_enhanced_cps_skipped: bool = False skip_when_stage_5_skipped: bool = False + validation_ids: tuple[str, ...] = () STAGE_1_BUILD_STEP_SPECS: tuple[DatasetBuildStepSpec, ...] = ( @@ -30,28 +31,33 @@ class DatasetBuildStepSpec: title="Raw data download", legacy_stage_id="0", reuse_mode="observed_only", + validation_ids=("stage_1.1a_raw_data_download.artifact_contract",), ), DatasetBuildStepSpec( id="1b_base_dataset_construction", title="Base dataset construction", legacy_stage_id="1", + validation_ids=("stage_1.1b_base_dataset_construction.artifact_contract",), ), DatasetBuildStepSpec( id="1c_extended_cps_puf_clone", title="Extended CPS PUF clone", legacy_stage_id="2", + validation_ids=("stage_1.1c_extended_cps_puf_clone.artifact_contract",), ), DatasetBuildStepSpec( id="1d_enhanced_cps_reweighting", title="Enhanced CPS reweighting", legacy_stage_id="3a", skip_when_enhanced_cps_skipped=True, + validation_ids=("stage_1.1d_enhanced_cps_reweighting.artifact_contract",), ), DatasetBuildStepSpec( id="1e_stratified_cps", title="Stratified CPS", legacy_stage_id="3b", reuse_mode="handoff", + validation_ids=("stage_1.1e_stratified_cps.artifact_contract",), ), DatasetBuildStepSpec( id="1f_source_imputation", @@ -59,6 +65,7 @@ class DatasetBuildStepSpec: legacy_stage_id="4", reuse_mode="handoff", skip_when_stage_5_skipped=True, + validation_ids=("stage_1.1f_source_imputation.artifact_contract",), ), DatasetBuildStepSpec( id="1g_stage_base_datasets", @@ -66,6 +73,7 @@ class DatasetBuildStepSpec: legacy_stage_id="7", manifest_step_ids=("04_stage_base_datasets",), reuse_mode="handoff", + validation_ids=("stage_1.1g_stage_base_datasets.artifact_contract",), ), ) diff --git a/policyengine_us_data/build_datasets/staging.py b/policyengine_us_data/build_datasets/staging.py new file mode 100644 index 000000000..aaa5a8421 --- /dev/null +++ b/policyengine_us_data/build_datasets/staging.py @@ -0,0 +1,85 @@ +"""Artifact staging helpers for Stage 1 dataset-build outputs.""" + +from __future__ import annotations + +import json +import shutil +from collections.abc import Mapping +from dataclasses import dataclass +from pathlib import Path + +from .artifacts import ( + DatasetArtifactSpec, + stage_1_pipeline_artifact_specs, +) +from .context import DatasetBuildContext + + +@dataclass(frozen=True, kw_only=True) +class PipelineArtifactStager: + """Stage declared Stage 1 artifacts into a run-scoped pipeline directory.""" + + context: DatasetBuildContext + + def stage_declared_artifacts( + self, + *, + skip_enhanced_cps: bool = False, + skip_stage_5: bool = False, + ) -> tuple[Path, ...]: + self.context.artifacts_dir.mkdir(parents=True, exist_ok=True) + staged: list[Path] = [] + missing_required: list[str] = [] + + for spec in stage_1_pipeline_artifact_specs(): + if spec.diagnostic_output: + continue + if skip_enhanced_cps and spec.skip_when_enhanced_cps_skipped: + continue + if skip_stage_5 and spec.skip_when_stage_5_skipped: + continue + if spec.yearless_alias: + alias = self._stage_yearless_alias(spec) + if alias is not None: + staged.append(alias) + continue + if spec.storage_path is None: + continue + + source = self.context.source_path(spec.storage_path) + destination = self.context.artifact_path(spec.filename) + if not source.exists(): + if spec.required: + missing_required.append(spec.filename) + continue + shutil.copy2(source, destination) + staged.append(destination) + + if missing_required: + raise FileNotFoundError( + "Missing Stage 1 pipeline artifact(s): " + + ", ".join(sorted(missing_required)) + ) + return tuple(staged) + + def write_checkpoint_stats(self, checkpoint_stats: Mapping[str, int]) -> Path: + """Write checkpoint reuse metadata as an explicit Stage 1 artifact.""" + + path = self.context.artifact_path("data_build_checkpoint_stats.json") + path.write_text( + json.dumps(dict(checkpoint_stats), indent=2, sort_keys=True) + "\n" + ) + return path + + def _stage_yearless_alias(self, spec: DatasetArtifactSpec) -> Path | None: + source = self.context.artifact_path( + "source_imputed_stratified_extended_cps_2024.h5" + ) + if not source.exists(): + return None + destination = self.context.artifact_path(spec.filename) + shutil.copy2(source, destination) + return destination + + +__all__ = ["PipelineArtifactStager"] diff --git a/policyengine_us_data/build_datasets/status.py b/policyengine_us_data/build_datasets/status.py new file mode 100644 index 000000000..a69c9bf25 --- /dev/null +++ b/policyengine_us_data/build_datasets/status.py @@ -0,0 +1,102 @@ +"""Structured status records for Stage 1 dataset-build execution.""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Literal + + +Stage1SubstepStatus = Literal["started", "completed", "skipped", "failed"] + + +def utc_timestamp(value: datetime | None = None) -> str: + """Render a UTC timestamp for Stage 1 execution status records.""" + + value = value or datetime.now(timezone.utc) + return ( + value.astimezone(timezone.utc) + .replace(microsecond=0) + .isoformat() + .replace("+00:00", "Z") + ) + + +@dataclass(frozen=True, kw_only=True) +class Stage1StatusEvent: + """A timestamped status transition for a Stage 1 substep or command.""" + + substep_id: str + status: Stage1SubstepStatus + created_at: str + message: str | None = None + command_name: str | None = None + metadata: Mapping[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-compatible status event payload.""" + + return { + "substep_id": self.substep_id, + "status": self.status, + "created_at": self.created_at, + "message": self.message, + "command_name": self.command_name, + "metadata": dict(self.metadata), + } + + +@dataclass(frozen=True, kw_only=True) +class Stage1ErrorRecord: + """Structured command or substep failure details.""" + + substep_id: str | None + command_name: str | None + error_type: str + message: str + returncode: int | None = None + created_at: str = field(default_factory=utc_timestamp) + metadata: Mapping[str, Any] = field(default_factory=dict) + + @classmethod + def from_exception( + cls, + exc: BaseException, + *, + substep_id: str | None = None, + command_name: str | None = None, + returncode: int | None = None, + metadata: Mapping[str, Any] | None = None, + ) -> "Stage1ErrorRecord": + """Build an error record from an exception without parsing logs.""" + + return cls( + substep_id=substep_id, + command_name=command_name, + error_type=type(exc).__name__, + message=str(exc), + returncode=returncode, + metadata=dict(metadata or {}), + ) + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-compatible error payload.""" + + return { + "substep_id": self.substep_id, + "command_name": self.command_name, + "error_type": self.error_type, + "message": self.message, + "returncode": self.returncode, + "created_at": self.created_at, + "metadata": dict(self.metadata), + } + + +__all__ = [ + "Stage1ErrorRecord", + "Stage1StatusEvent", + "Stage1SubstepStatus", + "utc_timestamp", +] diff --git a/policyengine_us_data/build_datasets/validation.py b/policyengine_us_data/build_datasets/validation.py new file mode 100644 index 000000000..4befb359d --- /dev/null +++ b/policyengine_us_data/build_datasets/validation.py @@ -0,0 +1,406 @@ +"""Stage 1 validation adapters over the shared validation core.""" + +from __future__ import annotations + +from collections.abc import Callable, Mapping +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, TypeAlias +from urllib.parse import unquote, urlparse + +from policyengine_us_data.pipeline_metadata import pipeline_node +from policyengine_us_data.stage_contracts import ( + ArtifactRef, + StageContract, + ValidationFinding, + ValidationReport, +) +from policyengine_us_data.utils.step_manifest import sha256_file +from policyengine_us_data.validation_core import ( + ValidationArtifactResolver, + ValidationCheck, + ValidationContext, + ValidationRunner, + ValidationSuite, +) + +from .artifacts import DatasetArtifactSpec, stage_1_artifact_specs +from .results import DatasetSubstepResult +from .specs import STAGE_1_BUILD_DATASETS, stage_1_step_specs +from .validation_targets import ValidationTargetCatalog + + +Stage1Validator: TypeAlias = Callable[[ValidationContext], ValidationFinding | None] + + +class Stage1ValidationError(RuntimeError): + """Raised when an error-level Stage 1 validation report fails.""" + + +@dataclass(frozen=True, kw_only=True) +class Stage1ValidationContext: + """Stage 1-specific context adapted into validation_core objects.""" + + run_id: str + substage_id: str + artifact_refs: Mapping[str, ArtifactRef] = field(default_factory=dict) + metadata: Mapping[str, Any] = field(default_factory=dict) + output_contract: StageContract | None = None + + @classmethod + def from_contract( + cls, + *, + contract: StageContract, + substage_id: str, + metadata: Mapping[str, Any] | None = None, + ) -> "Stage1ValidationContext": + """Build a context from a Stage 1 output contract.""" + + artifacts = { + artifact.logical_name: artifact + for artifact in contract.outputs + if artifact.metadata.get("substage_id") == substage_id + } + return cls( + run_id=contract.run_id or "unknown", + substage_id=substage_id, + artifact_refs=artifacts, + metadata=dict(metadata or {}), + output_contract=contract, + ) + + @classmethod + def from_substep_result( + cls, + *, + run_id: str, + result: DatasetSubstepResult, + metadata: Mapping[str, Any] | None = None, + ) -> "Stage1ValidationContext": + """Build a context from one completed Stage 1 substep result.""" + + artifacts = { + artifact.logical_name: artifact + for artifact in ( + _artifact_ref_for_path(Path(path), result.substep_id) + for path in result.artifact_paths + ) + if artifact is not None + } + context_metadata = { + "substep_status": result.status, + "command_names": list(result.command_names), + "artifact_paths": list(result.artifact_paths), + } + context_metadata.update(dict(metadata or {})) + return cls( + run_id=run_id, + substage_id=result.substep_id, + artifact_refs=artifacts, + metadata=context_metadata, + ) + + def to_core_context(self) -> ValidationContext: + """Return the shared validation_core context.""" + + return ValidationContext( + run_id=self.run_id, + stage_id=STAGE_1_BUILD_DATASETS, + substage_id=self.substage_id, + resolver=ValidationArtifactResolver(artifacts=self.artifact_refs), + metadata=self.metadata, + ) + + +@dataclass(frozen=True, kw_only=True) +class Stage1ValidatorSpec: + """Stage 1 wrapper around a shared ValidationCheck.""" + + validator_id: str + substage_id: str + description: str + run: Stage1Validator + severity: str = "error" + + def to_check(self, *, required_artifacts: tuple[str, ...]) -> ValidationCheck: + """Return the validation_core check represented by this spec.""" + + return ValidationCheck( + check_id=self.validator_id, + stage_id=STAGE_1_BUILD_DATASETS, + substage_id=self.substage_id, + description=self.description, + severity="warning" if self.severity == "warning" else "error", + required_artifacts=required_artifacts, + run=self.run, + ) + + +@pipeline_node( + id="stage_1_validation_runner", + label="Stage 1 Validation Runner", + node_type="library", + description="Stage 1 adapter that runs ordered validators through validation_core.", + source_file="policyengine_us_data/build_datasets/validation.py", + status="current", + stability="stable", + pathways=["data_build", "stage_contracts", "cross_stage_validation"], + artifacts_in=["Stage 1 artifacts", "dataset_build_output.json"], + artifacts_out=["ValidationReport"], + validation_commands=["uv run pytest tests/unit/test_build_dataset_validation.py"], +) +@dataclass(frozen=True, kw_only=True) +class Stage1ValidationRunner: + """Run Stage 1 validators and aggregate canonical validation reports.""" + + run_id: str + catalog: ValidationTargetCatalog = field( + default_factory=ValidationTargetCatalog.from_stage_1_specs + ) + runner: ValidationRunner = field(default_factory=ValidationRunner) + metadata: Mapping[str, Any] = field(default_factory=dict) + + def run_for_context( + self, + context: Stage1ValidationContext, + *, + required_artifacts: tuple[str, ...] | None = None, + ) -> ValidationReport: + """Run validators for a prepared Stage 1 context.""" + + specs = validators_for_substage(context.substage_id) + if not specs: + return ValidationReport( + status="not_run", + metadata=self._report_metadata(context, skip_reason="no_validators"), + ) + required = ( + tuple(required_artifacts) + if required_artifacts is not None + else tuple(context.artifact_refs) + ) + if not required: + return ValidationReport( + status="not_run", + metadata=self._report_metadata( + context, + skip_reason="no_artifacts_for_substep", + ), + ) + suite = ValidationSuite( + suite_id=f"stage_1.{context.substage_id}", + stage_id=STAGE_1_BUILD_DATASETS, + substage_id=context.substage_id, + checks=tuple(spec.to_check(required_artifacts=required) for spec in specs), + ) + return self.runner.run(suite, context.to_core_context()) + + def run_for_substep_result( + self, + result: DatasetSubstepResult, + ) -> ValidationReport: + """Run validators for one completed coordinator substep result.""" + + context = Stage1ValidationContext.from_substep_result( + run_id=self.run_id, + result=result, + metadata=self.metadata, + ) + return self.run_for_context(context) + + def run_for_contract( + self, + contract: StageContract, + substage_id: str, + ) -> ValidationReport: + """Run validators for one substage using a Stage 1 output contract.""" + + context = Stage1ValidationContext.from_contract( + contract=contract, + substage_id=substage_id, + metadata=self.metadata, + ) + required = self.catalog.required_logical_names(substage_id) + return self.run_for_context(context, required_artifacts=required) + + def should_stop(self, report: ValidationReport) -> bool: + """Return whether a report should stop downstream Stage 1 execution.""" + + return report.status == "fail" + + def _report_metadata( + self, + context: Stage1ValidationContext, + **extra: Any, + ) -> dict[str, Any]: + metadata = { + "stage_id": STAGE_1_BUILD_DATASETS, + "substage_id": context.substage_id, + "run_id": context.run_id, + "context_metadata": dict(context.metadata), + } + metadata.update(extra) + return metadata + + +def iter_stage_1_validators() -> tuple[Stage1ValidatorSpec, ...]: + """Return all registered Stage 1 validator specs.""" + + return _STAGE_1_VALIDATORS + + +def validators_for_substage(substage_id: str) -> tuple[Stage1ValidatorSpec, ...]: + """Return validator specs wired to one Stage 1 substage.""" + + validation_ids = () + for spec in stage_1_step_specs(): + if spec.id == substage_id: + validation_ids = spec.validation_ids + break + validators_by_id = {spec.validator_id: spec for spec in _STAGE_1_VALIDATORS} + return tuple(validators_by_id[validator_id] for validator_id in validation_ids) + + +def run_stage_1_validators( + context: Stage1ValidationContext, + *, + runner: Stage1ValidationRunner | None = None, +) -> ValidationReport: + """Run registered validators for a Stage 1 context.""" + + active_runner = runner or Stage1ValidationRunner(run_id=context.run_id) + return active_runner.run_for_context(context) + + +def _validate_artifact_refs( + context: ValidationContext, + *, + check_id: str, +) -> ValidationFinding | None: + for logical_name, artifact in sorted(context.resolver.artifacts.items()): + path = _file_uri_to_path(artifact.uri) + if path is None: + continue + if not path.exists(): + return ValidationFinding( + check_id=check_id, + status="fail", + message=f"Validation artifact does not exist: {logical_name}", + metric="artifact_exists", + value=str(path), + metadata={ + "logical_name": logical_name, + "uri": artifact.uri, + "substage_id": context.substage_id, + }, + ) + if path.is_file() and path.stat().st_size == 0: + return ValidationFinding( + check_id=check_id, + status="fail", + message=f"Validation artifact is empty: {logical_name}", + metric="artifact_size_bytes", + value=0, + metadata={ + "logical_name": logical_name, + "uri": artifact.uri, + "substage_id": context.substage_id, + }, + ) + return None + + +def _artifact_ref_for_path(path: Path, substage_id: str) -> ArtifactRef | None: + spec = _spec_for_path(path, substage_id) + if spec is None and not path.exists(): + return None + logical_name = spec.logical_name if spec else path.stem + metadata: dict[str, Any] = {"substage_id": substage_id} + if spec is not None: + metadata.update( + { + "artifact_family": spec.artifact_family, + "filename": spec.filename, + "period": spec.period, + } + ) + return ArtifactRef( + logical_name=logical_name, + uri=path.resolve().as_uri(), + sha256=f"sha256:{sha256_file(path)}" + if path.exists() and path.is_file() + else None, + size_bytes=path.stat().st_size if path.exists() and path.is_file() else None, + media_type=_media_type_for_path(path), + metadata=metadata, + ) + + +def _spec_for_path(path: Path, substage_id: str) -> DatasetArtifactSpec | None: + path_name = path.name + path_text = str(path) + fallback: DatasetArtifactSpec | None = None + for spec in stage_1_artifact_specs(): + candidates = {spec.filename} + if spec.storage_path is not None: + candidates.add(spec.storage_path) + candidates.add(Path(spec.storage_path).name) + if path_text in candidates or path_name in candidates: + if spec.substage_id == substage_id: + return spec + fallback = fallback or spec + return fallback + + +def _file_uri_to_path(uri: str) -> Path | None: + parsed = urlparse(uri) + if parsed.scheme != "file": + return None + return Path(unquote(parsed.path)) + + +def _media_type_for_path(path: Path) -> str: + suffix = path.suffix.lower() + if suffix == ".h5": + return "application/x-hdf5" + if suffix in {".db", ".sqlite", ".sqlite3"}: + return "application/vnd.sqlite3" + if suffix == ".json": + return "application/json" + if suffix == ".txt": + return "text/plain" + return "application/octet-stream" + + +def _artifact_contract_validator(substage_id: str) -> Stage1ValidatorSpec: + validator_id = f"stage_1.{substage_id}.artifact_contract" + + def run(context: ValidationContext) -> ValidationFinding | None: + return _validate_artifact_refs(context, check_id=validator_id) + + return Stage1ValidatorSpec( + validator_id=validator_id, + substage_id=substage_id, + description="Validate Stage 1 artifact references for this substage.", + run=run, + ) + + +_STAGE_1_VALIDATORS: tuple[Stage1ValidatorSpec, ...] = tuple( + _artifact_contract_validator(spec.id) + for spec in stage_1_step_specs() + if spec.validation_ids +) + + +__all__ = [ + "Stage1ValidationContext", + "Stage1ValidationError", + "Stage1ValidationRunner", + "Stage1Validator", + "Stage1ValidatorSpec", + "iter_stage_1_validators", + "run_stage_1_validators", + "validators_for_substage", +] diff --git a/policyengine_us_data/build_datasets/validation_results.py b/policyengine_us_data/build_datasets/validation_results.py new file mode 100644 index 000000000..570f9eb9c --- /dev/null +++ b/policyengine_us_data/build_datasets/validation_results.py @@ -0,0 +1,336 @@ +"""Durable Stage 1 validation result writing.""" + +from __future__ import annotations + +import json +import sqlite3 +from collections import defaultdict +from collections.abc import Iterable, Mapping +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from policyengine_us_data.pipeline_metadata import pipeline_node +from policyengine_us_data.stage_contracts import ( + ArtifactRef, + DiagnosticRef, + ValidationFinding, + ValidationReport, +) +from policyengine_us_data.utils.step_manifest import sha256_file +from policyengine_us_data.validation_core import ValidationReportWriter + + +@dataclass(frozen=True, kw_only=True) +class Stage1ValidationSummary: + """Result of writing Stage 1 validation outputs.""" + + report: ValidationReport + substage_reports: Mapping[str, ValidationReport] + diagnostics: tuple[DiagnosticRef, ...] + paths: Mapping[str, Path] + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-compatible summary payload.""" + + return { + "report": self.report.to_dict(), + "substage_reports": { + substage_id: report.to_dict() + for substage_id, report in self.substage_reports.items() + }, + "diagnostics": [diagnostic.to_dict() for diagnostic in self.diagnostics], + "paths": {key: str(path) for key, path in self.paths.items()}, + } + + +@pipeline_node( + id="stage_1_validation_result_writer", + label="Stage 1 Validation Result Writer", + node_type="library", + description="Write Stage 1 validation reports, findings, metrics, and SQLite rows.", + source_file="policyengine_us_data/build_datasets/validation_results.py", + status="current", + stability="stable", + pathways=["data_build", "stage_contracts", "cross_stage_validation"], + artifacts_in=["ValidationReport"], + artifacts_out=[ + "validation/summary.json", + "validation/findings.jsonl", + "validation/metrics.jsonl", + "validation/validation_results.sqlite", + ], + validation_commands=[ + "uv run pytest tests/unit/test_build_dataset_validation_results.py" + ], +) +@dataclass(frozen=True, kw_only=True) +class Stage1ValidationResultWriter: + """Write durable Stage 1 validation artifacts from canonical reports.""" + + output_dir: Path + + def write( + self, + reports: Iterable[ValidationReport], + ) -> Stage1ValidationSummary: + """Write aggregate and per-substage validation outputs.""" + + output_dir = Path(self.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + substage_reports = _merge_reports_by_substage(tuple(reports)) + aggregate = _aggregate_report(substage_reports.values()) + + paths: dict[str, Path] = {} + for substage_id, report in substage_reports.items(): + writer = ValidationReportWriter( + output_dir=output_dir, + strategies=("report",), + report_filename=f"{substage_id}.json", + ) + paths[f"{substage_id}.report"] = writer.write(report)["report"] + + paths["summary"] = output_dir / "summary.json" + paths["summary"].write_text( + json.dumps(_summary_payload(aggregate, substage_reports), indent=2) + "\n", + encoding="utf-8", + ) + paths["findings"] = output_dir / "findings.jsonl" + paths["findings"].write_text( + "".join( + json.dumps(finding.to_dict(), sort_keys=True) + "\n" + for finding in aggregate.findings + ), + encoding="utf-8", + ) + paths["metrics"] = output_dir / "metrics.jsonl" + paths["metrics"].write_text( + "".join( + json.dumps(_metrics_payload(substage_id, report), sort_keys=True) + "\n" + for substage_id, report in substage_reports.items() + ), + encoding="utf-8", + ) + paths["sqlite"] = output_dir / "validation_results.sqlite" + _write_sqlite(paths["sqlite"], substage_reports) + + diagnostics = tuple( + _diagnostic_ref(name=key, kind=_diagnostic_kind(path), path=path) + for key, path in paths.items() + ) + aggregate = ValidationReport( + status=aggregate.status, + findings=aggregate.findings, + diagnostics=diagnostics, + metadata=aggregate.metadata, + ) + return Stage1ValidationSummary( + report=aggregate, + substage_reports=substage_reports, + diagnostics=diagnostics, + paths=paths, + ) + + +def _merge_reports_by_substage( + reports: tuple[ValidationReport, ...], +) -> dict[str, ValidationReport]: + grouped: dict[str, list[ValidationReport]] = defaultdict(list) + for report in reports: + substage_id = report.metadata.get("substage_id") + if isinstance(substage_id, str) and report.status != "not_run": + grouped[substage_id].append(report) + + merged: dict[str, ValidationReport] = {} + for substage_id in sorted(grouped): + findings: list[ValidationFinding] = [] + diagnostics: list[DiagnosticRef] = [] + check_ids: list[str] = [] + for report in grouped[substage_id]: + findings.extend(report.findings) + diagnostics.extend(report.diagnostics) + check_ids.extend(report.metadata.get("check_ids", ())) + merged[substage_id] = ValidationReport( + status=_status_from_findings(findings), + findings=tuple(findings), + diagnostics=tuple(diagnostics), + metadata={ + "stage_id": "1_build_datasets", + "substage_id": substage_id, + "check_ids": sorted(set(check_ids)), + "report_count": len(grouped[substage_id]), + }, + ) + return merged + + +def _aggregate_report(reports: Iterable[ValidationReport]) -> ValidationReport: + reports = tuple(reports) + findings: list[ValidationFinding] = [] + diagnostics: list[DiagnosticRef] = [] + for report in reports: + findings.extend(report.findings) + diagnostics.extend(report.diagnostics) + if not reports: + return ValidationReport( + status="not_run", + metadata={"stage_id": "1_build_datasets", "report_count": 0}, + ) + return ValidationReport( + status=_status_from_findings(findings), + findings=tuple(findings), + diagnostics=tuple(diagnostics), + metadata={ + "stage_id": "1_build_datasets", + "report_count": len(reports), + "substage_ids": [ + report.metadata.get("substage_id") + for report in reports + if report.metadata.get("substage_id") is not None + ], + }, + ) + + +def _status_from_findings(findings: Iterable[ValidationFinding]) -> str: + statuses = tuple(finding.status for finding in findings) + if any(status == "fail" for status in statuses): + return "fail" + if any(status == "warn" for status in statuses): + return "warn" + return "pass" + + +def _summary_payload( + aggregate: ValidationReport, + substage_reports: Mapping[str, ValidationReport], +) -> dict[str, Any]: + return { + "status": aggregate.status, + "finding_count": len(aggregate.findings), + "fail_count": sum( + 1 for finding in aggregate.findings if finding.status == "fail" + ), + "warn_count": sum( + 1 for finding in aggregate.findings if finding.status == "warn" + ), + "pass_count": sum( + 1 for finding in aggregate.findings if finding.status == "pass" + ), + "substages": { + substage_id: _metrics_payload(substage_id, report) + for substage_id, report in substage_reports.items() + }, + } + + +def _metrics_payload(substage_id: str, report: ValidationReport) -> dict[str, Any]: + return { + "substage_id": substage_id, + "status": report.status, + "finding_count": len(report.findings), + "fail_count": sum(1 for finding in report.findings if finding.status == "fail"), + "warn_count": sum(1 for finding in report.findings if finding.status == "warn"), + "pass_count": sum(1 for finding in report.findings if finding.status == "pass"), + } + + +def _write_sqlite( + path: Path, + substage_reports: Mapping[str, ValidationReport], +) -> None: + with sqlite3.connect(path) as connection: + connection.execute( + """ + CREATE TABLE IF NOT EXISTS reports ( + substage_id TEXT PRIMARY KEY, + status TEXT NOT NULL, + finding_count INTEGER NOT NULL, + report_json TEXT NOT NULL + ) + """ + ) + connection.execute( + """ + CREATE TABLE IF NOT EXISTS findings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + substage_id TEXT NOT NULL, + check_id TEXT NOT NULL, + status TEXT NOT NULL, + message TEXT NOT NULL, + finding_json TEXT NOT NULL + ) + """ + ) + connection.execute("DELETE FROM findings") + connection.execute("DELETE FROM reports") + for substage_id, report in substage_reports.items(): + connection.execute( + """ + INSERT INTO reports + (substage_id, status, finding_count, report_json) + VALUES (?, ?, ?, ?) + """, + ( + substage_id, + report.status, + len(report.findings), + json.dumps(report.to_dict(), sort_keys=True), + ), + ) + for finding in report.findings: + connection.execute( + """ + INSERT INTO findings + (substage_id, check_id, status, message, finding_json) + VALUES (?, ?, ?, ?, ?) + """, + ( + substage_id, + finding.check_id, + finding.status, + finding.message, + json.dumps(finding.to_dict(), sort_keys=True), + ), + ) + + +def _diagnostic_ref(*, name: str, kind: str, path: Path) -> DiagnosticRef: + artifact = ArtifactRef( + logical_name=f"stage_1_validation_{name}", + uri=path.resolve().as_uri(), + sha256=f"sha256:{sha256_file(path)}", + size_bytes=path.stat().st_size, + media_type=_media_type_for_path(path), + metadata={"stage_id": "1_build_datasets", "artifact_family": "validation"}, + ) + return DiagnosticRef( + name=f"stage_1_validation_{name}", + kind=kind, + artifact=artifact, + severity="info", + ) + + +def _diagnostic_kind(path: Path) -> str: + if path.suffix == ".jsonl": + return "jsonl" + if path.suffix in {".sqlite", ".db"}: + return "sqlite" + return "json" + + +def _media_type_for_path(path: Path) -> str: + if path.suffix == ".json": + return "application/json" + if path.suffix == ".jsonl": + return "application/x-ndjson" + if path.suffix in {".sqlite", ".db"}: + return "application/vnd.sqlite3" + return "application/octet-stream" + + +__all__ = [ + "Stage1ValidationResultWriter", + "Stage1ValidationSummary", +] diff --git a/policyengine_us_data/build_datasets/validation_targets.py b/policyengine_us_data/build_datasets/validation_targets.py new file mode 100644 index 000000000..def039ecf --- /dev/null +++ b/policyengine_us_data/build_datasets/validation_targets.py @@ -0,0 +1,193 @@ +"""Validation target catalog for Stage 1 dataset-build artifacts.""" + +from __future__ import annotations + +import csv +import json +import sqlite3 +from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from policyengine_us_data.pipeline_metadata import pipeline_node + +from .artifacts import stage_1_contract_artifact_specs + + +@dataclass(frozen=True, kw_only=True) +class ValidationTarget: + """One logical artifact expectation for Stage 1 validation.""" + + target_id: str + substage_id: str + logical_name: str + required: bool = True + warning_only: bool = False + metadata: Mapping[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + for value, name in ( + (self.target_id, "target_id"), + (self.substage_id, "substage_id"), + (self.logical_name, "logical_name"), + ): + if not isinstance(value, str) or not value.strip(): + raise ValueError(f"{name} must be a non-empty string") + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-compatible target payload.""" + + return { + "target_id": self.target_id, + "substage_id": self.substage_id, + "logical_name": self.logical_name, + "required": self.required, + "warning_only": self.warning_only, + "metadata": dict(self.metadata), + } + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "ValidationTarget": + """Build a validation target from a mapping.""" + + return cls( + target_id=str(data["target_id"]), + substage_id=str(data["substage_id"]), + logical_name=str(data["logical_name"]), + required=bool(data.get("required", True)), + warning_only=bool(data.get("warning_only", False)), + metadata=dict(data.get("metadata", {})), + ) + + +@pipeline_node( + id="stage_1_validation_target_catalog", + label="Stage 1 Validation Target Catalog", + node_type="library", + description="Deterministic catalog of active Stage 1 validation artifact targets.", + source_file="policyengine_us_data/build_datasets/validation_targets.py", + status="current", + stability="stable", + pathways=["data_build", "stage_contracts", "cross_stage_validation"], + validation_commands=["uv run pytest tests/unit/test_build_dataset_validation.py"], +) +@dataclass(frozen=True, kw_only=True) +class ValidationTargetCatalog: + """Deterministic lookup for active Stage 1 validation targets.""" + + targets: tuple[ValidationTarget, ...] + + def __post_init__(self) -> None: + targets = tuple(self.targets) + seen: set[str] = set() + for target in targets: + if not isinstance(target, ValidationTarget): + raise TypeError("targets must contain ValidationTarget instances") + if target.target_id in seen: + raise ValueError(f"Duplicate validation target: {target.target_id}") + seen.add(target.target_id) + object.__setattr__( + self, + "targets", + tuple(sorted(targets, key=lambda item: item.target_id)), + ) + + @classmethod + def from_stage_1_specs( + cls, + *, + skip_enhanced_cps: bool = False, + skip_stage_5: bool = False, + ) -> "ValidationTargetCatalog": + """Build the active target catalog from Stage 1 artifact specs.""" + + targets: list[ValidationTarget] = [] + for spec in stage_1_contract_artifact_specs(): + if skip_enhanced_cps and spec.skip_when_enhanced_cps_skipped: + continue + if skip_stage_5 and spec.skip_when_stage_5_skipped: + continue + targets.append( + ValidationTarget( + target_id=f"{spec.substage_id}.{spec.logical_name}", + substage_id=spec.substage_id, + logical_name=spec.logical_name, + required=spec.required, + metadata={ + "artifact_family": spec.artifact_family, + "filename": spec.filename, + "period": spec.period, + }, + ) + ) + return cls(targets=tuple(targets)) + + @classmethod + def load(cls, path: str | Path) -> "ValidationTargetCatalog": + """Load a target catalog from JSON, CSV, or SQLite.""" + + path = Path(path) + suffix = path.suffix.lower() + if suffix == ".json": + rows = json.loads(path.read_text(encoding="utf-8")) + elif suffix == ".csv": + with path.open(newline="", encoding="utf-8") as file: + rows = list(csv.DictReader(file)) + elif suffix in {".db", ".sqlite", ".sqlite3"}: + rows = _load_sqlite_targets(path) + else: + raise ValueError(f"Unsupported validation target catalog: {path}") + return cls.from_rows(rows) + + @classmethod + def from_rows( + cls, + rows: Iterable[Mapping[str, Any]], + ) -> "ValidationTargetCatalog": + """Build a catalog from row dictionaries.""" + + return cls(targets=tuple(ValidationTarget.from_dict(row) for row in rows)) + + def active_for_substage(self, substage_id: str) -> tuple[ValidationTarget, ...]: + """Return active targets for one Stage 1 substage.""" + + return tuple( + target for target in self.targets if target.substage_id == substage_id + ) + + def required_logical_names(self, substage_id: str) -> tuple[str, ...]: + """Return required logical artifacts for one substage.""" + + return tuple( + target.logical_name + for target in self.active_for_substage(substage_id) + if target.required + ) + + +def _load_sqlite_targets(path: Path) -> list[dict[str, Any]]: + with sqlite3.connect(path) as connection: + rows = connection.execute( + """ + SELECT target_id, substage_id, logical_name, required, warning_only + FROM validation_targets + ORDER BY target_id + """ + ).fetchall() + return [ + { + "target_id": target_id, + "substage_id": substage_id, + "logical_name": logical_name, + "required": bool(required), + "warning_only": bool(warning_only), + } + for target_id, substage_id, logical_name, required, warning_only in rows + ] + + +__all__ = [ + "ValidationTarget", + "ValidationTargetCatalog", +] diff --git a/policyengine_us_data/stage_contracts/dataset_build.py b/policyengine_us_data/stage_contracts/dataset_build.py index 7e76a4baa..71ec82972 100644 --- a/policyengine_us_data/stage_contracts/dataset_build.py +++ b/policyengine_us_data/stage_contracts/dataset_build.py @@ -6,18 +6,20 @@ from pathlib import Path from typing import Any -from policyengine_us_data.build_datasets import ( - STAGE_1_BUILD_STEP_SPECS, +from policyengine_us_data.build_datasets.artifacts import ( stage_1_contract_artifact_specs, ) +from policyengine_us_data.build_datasets.specs import STAGE_1_BUILD_STEP_SPECS from policyengine_us_data.utils.step_manifest import sha256_file from .artifacts import ArtifactRef from .contracts import StageContract +from .diagnostics import DiagnosticRef from .execution import ExecutionRecord, ReuseSummary from .fingerprints import fingerprint_material from .stages import STAGE_1_BUILD_DATASETS, contract_type_for_stage from .substages import SubstageRecord +from .validation import ValidationReport DATASET_BUILD_OUTPUT_CONTRACT_FILENAME = "dataset_build_output.json" DATASET_BUILD_OUTPUT_CONTRACT_TYPE = contract_type_for_stage(STAGE_1_BUILD_DATASETS) @@ -37,6 +39,10 @@ def build_dataset_build_output_contract( stage_only: bool = False, skip_enhanced_cps: bool = False, skip_stage_5: bool = False, + diagnostics: tuple[DiagnosticRef, ...] = (), + validation: ValidationReport | None = None, + substage_validation: Mapping[str, ValidationReport] | None = None, + stage_1_status_metadata: Mapping[str, Any] | None = None, ) -> StageContract: """Build the Stage 1 handoff contract from copied pipeline artifacts.""" @@ -79,12 +85,17 @@ def build_dataset_build_output_contract( outputs=outputs, skip_enhanced_cps=skip_enhanced_cps, skip_stage_5=skip_stage_5, + substage_validation=substage_validation or {}, ), + validation=validation, + diagnostics=diagnostics, execution=execution, metadata={ "artifact_count": len(outputs), "artifact_directory": str(artifacts_dir), "contract_file": DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, + "diagnostic_count": len(diagnostics), + "stage_1_status": stage_1_status_metadata or {}, }, ) @@ -166,6 +177,7 @@ def _stage_1_substages( outputs: tuple[ArtifactRef, ...], skip_enhanced_cps: bool, skip_stage_5: bool, + substage_validation: Mapping[str, ValidationReport], ) -> tuple[SubstageRecord, ...]: output_by_substage: dict[str, list[ArtifactRef]] = { spec.id: [] for spec in STAGE_1_BUILD_STEP_SPECS @@ -189,6 +201,7 @@ def _stage_1_substages( status=status, outputs=tuple(output_by_substage[substage_id]), reuse_mode=spec.reuse_mode, + validation=substage_validation.get(substage_id), ) ) return tuple(records) diff --git a/tests/unit/test_build_dataset_checkpoints.py b/tests/unit/test_build_dataset_checkpoints.py new file mode 100644 index 000000000..0700f6bb1 --- /dev/null +++ b/tests/unit/test_build_dataset_checkpoints.py @@ -0,0 +1,105 @@ +from pathlib import Path + +from policyengine_us_data.build_datasets import ( + CheckpointReuseSummary, + CheckpointStore, +) + + +def _store(tmp_path: Path) -> CheckpointStore: + return CheckpointStore(root=tmp_path, branch="stage-1", commit_sha="abc123") + + +def test_checkpoint_store_scopes_paths_by_branch_and_commit(tmp_path): + store = _store(tmp_path) + + assert store.checkpoint_path("policyengine_us_data/storage/cps_2024.h5") == ( + tmp_path / "stage-1" / "abc123" / "cps_2024.h5" + ) + + +def test_checkpoint_store_requires_all_outputs_for_restore(tmp_path, monkeypatch): + store = _store(tmp_path) + checkpoint = store.checkpoint_path("a.txt") + checkpoint.parent.mkdir(parents=True) + checkpoint.write_text("cached") + monkeypatch.chdir(tmp_path) + + assert store.restore_all_outputs(("a.txt", "b.txt")) is False + assert not (tmp_path / "a.txt").exists() + + store.checkpoint_path("b.txt").write_text("cached-b") + assert store.restore_all_outputs(("a.txt", "b.txt")) is True + assert (tmp_path / "a.txt").read_text() == "cached" + assert (tmp_path / "b.txt").read_text() == "cached-b" + + +def test_missing_or_empty_checkpoint_invalidates_reuse(tmp_path): + store = _store(tmp_path) + empty = store.checkpoint_path("empty.txt") + empty.parent.mkdir(parents=True) + empty.write_text("") + + missing_decision = store.decision_for("missing.txt") + empty_decision = store.decision_for("empty.txt") + + assert missing_decision.action == "recompute" + assert missing_decision.reason == "missing" + assert empty_decision.action == "recompute" + assert empty_decision.reason == "empty" + assert store.all_outputs_reusable(("missing.txt", "empty.txt")) is False + + +def test_checkpoint_summary_matches_prior_counters(tmp_path): + store = _store(tmp_path) + checkpoint = store.checkpoint_path("valid.txt") + checkpoint.parent.mkdir(parents=True) + checkpoint.write_text("cached") + decisions = store.decisions_for(("valid.txt", "missing.txt")) + + reused_summary = CheckpointReuseSummary.from_decisions( + decisions, + recomputed=False, + ) + recomputed_summary = CheckpointReuseSummary.from_decisions( + decisions, + recomputed=True, + ) + + assert reused_summary.to_dict() == { + "expected_outputs": 2, + "valid_reused_outputs": 1, + "recomputed_outputs": 0, + "invalid_outputs": 1, + } + assert recomputed_summary.to_dict() == { + "expected_outputs": 2, + "valid_reused_outputs": 0, + "recomputed_outputs": 2, + "invalid_outputs": 1, + } + + +def test_checkpoint_cleanup_removes_only_branch_scope(tmp_path): + store = _store(tmp_path) + (tmp_path / "stage-1" / "abc123").mkdir(parents=True) + (tmp_path / "other-branch" / "abc123").mkdir(parents=True) + + assert store.cleanup_branch() is True + + assert not (tmp_path / "stage-1").exists() + assert (tmp_path / "other-branch" / "abc123").exists() + + +def test_checkpoint_cleanup_removes_only_other_commits(tmp_path): + store = _store(tmp_path) + current = tmp_path / "stage-1" / "abc123" + stale = tmp_path / "stage-1" / "old" + current.mkdir(parents=True) + stale.mkdir(parents=True) + + removed = store.cleanup_other_commits() + + assert removed == (stale,) + assert current.exists() + assert not stale.exists() diff --git a/tests/unit/test_build_dataset_commands.py b/tests/unit/test_build_dataset_commands.py new file mode 100644 index 000000000..3ea879acf --- /dev/null +++ b/tests/unit/test_build_dataset_commands.py @@ -0,0 +1,82 @@ +import sys + +import pytest + +from policyengine_us_data.build_datasets import ( + CommandRunner, + DatasetCommand, + DatasetCommandError, +) + + +def test_dataset_command_builds_python_module_command(): + command = DatasetCommand.from_script( + "policyengine_us_data/datasets/cps/cps.py", + python_executable="/python", + ) + + assert command.argv == ( + "/python", + "-u", + "-m", + "policyengine_us_data.datasets.cps.cps", + ) + assert command.name == "policyengine_us_data/datasets/cps/cps.py" + assert command.metadata["script_path"] == command.name + + +def test_dataset_command_keeps_external_python_script_path(): + command = DatasetCommand.from_script( + "scripts/example.py", + args=("--flag",), + python_executable="/python", + ) + + assert command.argv == ("/python", "-u", "scripts/example.py", "--flag") + assert command.kind == "python_script" + + +def test_dataset_command_represents_side_effecting_make_command(): + command = DatasetCommand( + name="make database", + argv=("make", "database"), + kind="side_effect", + ) + + assert command.side_effecting is True + assert command.argv == ("make", "database") + + +def test_command_runner_raises_structured_failure(): + command = DatasetCommand( + name="failing command", + argv=( + sys.executable, + "-c", + "import sys; print('structured failure'); sys.exit(7)", + ), + ) + + with pytest.raises(DatasetCommandError) as exc_info: + CommandRunner(output_tail_lines=5).run(command) + + result = exc_info.value.result + assert result.status == "failed" + assert result.returncode == 7 + assert result.error is not None + assert result.error.command_name == "failing command" + assert result.combined_output_tail == ("structured failure\n",) + + +def test_command_runner_can_return_structured_failure_without_raising(): + command = DatasetCommand( + name="nonraising command", + argv=(sys.executable, "-c", "import sys; sys.exit(3)"), + ) + + result = CommandRunner().run(command, check=False) + + assert result.status == "failed" + assert result.returncode == 3 + assert result.error is not None + assert result.error.returncode == 3 diff --git a/tests/unit/test_build_dataset_coordinator.py b/tests/unit/test_build_dataset_coordinator.py new file mode 100644 index 000000000..0e3d80796 --- /dev/null +++ b/tests/unit/test_build_dataset_coordinator.py @@ -0,0 +1,183 @@ +from pathlib import Path + +import pytest + +from policyengine_us_data.build_datasets import ( + Stage1Coordinator, + stage_1_substep_id_for_script, + stage_1_substep_title, +) +from policyengine_us_data.stage_contracts import ValidationFinding, ValidationReport + + +def test_stage_1_substep_mapping_uses_artifact_specs(): + assert ( + stage_1_substep_id_for_script("policyengine_us_data/datasets/cps/cps.py") + == "1b_base_dataset_construction" + ) + assert stage_1_substep_title("1b_base_dataset_construction") == ( + "Base dataset construction" + ) + + +def test_coordinator_records_completed_substep_and_artifacts(tmp_path): + coordinator = Stage1Coordinator() + artifact = tmp_path / "artifact.h5" + + def action(): + artifact.write_text("ok") + return "done" + + result = coordinator.run_substep( + "1b_base_dataset_construction", + "Base dataset construction", + action, + command_names=("build-cps",), + artifact_paths=(artifact,), + ) + + assert result == "done" + [substep_result] = coordinator.results + assert substep_result.status == "completed" + assert substep_result.started_at is not None + assert substep_result.completed_at is not None + assert substep_result.duration_s is not None + assert substep_result.command_names == ("build-cps",) + assert substep_result.artifact_paths == (str(artifact),) + assert [event.status for event in coordinator.status_events] == [ + "started", + "completed", + ] + + +def test_coordinator_records_skipped_substep_not_completed(): + coordinator = Stage1Coordinator() + + coordinator.run_substep( + "1f_source_imputation", + "Source imputation", + lambda: None, + command_names=("source-impute",), + skip=True, + skip_reason="--skip-stage-5", + ) + + [result] = coordinator.results + assert result.status == "skipped" + assert result.started_at is None + assert result.metadata["skip_reason"] == "--skip-stage-5" + assert coordinator.status_events[-1].status == "skipped" + + +def test_coordinator_records_failure_without_parsing_terminal_text(): + coordinator = Stage1Coordinator() + + def action(): + raise RuntimeError("structured failure") + + with pytest.raises(RuntimeError, match="structured failure"): + coordinator.run_substep( + "1c_extended_cps_puf_clone", + "Extended CPS PUF clone", + action, + command_names=("extended-cps",), + ) + + [result] = coordinator.results + assert result.status == "failed" + assert result.error is not None + assert result.error.error_type == "RuntimeError" + assert result.error.command_name == "extended-cps" + assert coordinator.error_records == [result.error] + + +def test_fake_substep_runner_collects_tiny_artifacts(tmp_path: Path): + coordinator = Stage1Coordinator() + outputs = [tmp_path / "one.txt", tmp_path / "two.txt"] + + def action(): + for path in outputs: + path.write_text(path.stem) + + coordinator.run_substep( + "1g_stage_base_datasets", + "Stage base datasets", + action, + command_names=("fake-stager",), + artifact_paths=outputs, + ) + + [result] = coordinator.results + assert result.status == "completed" + assert result.artifact_paths == tuple(str(path) for path in outputs) + + +class _PassValidationRunner: + def run_for_substep_result(self, result): + return ValidationReport( + status="pass", + metadata={"substage_id": result.substep_id, "check_ids": ["check.pass"]}, + ) + + def should_stop(self, report): + return report.status == "fail" + + +class _FailValidationRunner: + def run_for_substep_result(self, result): + return ValidationReport( + status="fail", + findings=( + ValidationFinding( + check_id="check.fail", + status="fail", + message="validator failed", + ), + ), + metadata={"substage_id": result.substep_id, "check_ids": ["check.fail"]}, + ) + + def should_stop(self, report): + return report.status == "fail" + + +def test_coordinator_attaches_validation_report(tmp_path: Path): + coordinator = Stage1Coordinator(validation_runner=_PassValidationRunner()) + artifact = tmp_path / "artifact.h5" + + def action(): + artifact.write_text("ok") + + coordinator.run_substep( + "1b_base_dataset_construction", + "Base dataset construction", + action, + artifact_paths=(artifact,), + ) + + [result] = coordinator.results + assert result.validation_report["status"] == "pass" + assert coordinator.status_events[-1].metadata["validation_report"]["status"] == ( + "pass" + ) + + +def test_coordinator_stops_after_error_level_validation_failure(tmp_path: Path): + coordinator = Stage1Coordinator(validation_runner=_FailValidationRunner()) + artifact = tmp_path / "artifact.h5" + + def action(): + artifact.write_text("ok") + + with pytest.raises(RuntimeError, match="Stage 1 validation failed"): + coordinator.run_substep( + "1b_base_dataset_construction", + "Base dataset construction", + action, + artifact_paths=(artifact,), + ) + + [result] = coordinator.results + assert result.status == "failed" + assert result.validation_report["status"] == "fail" + assert coordinator.error_records == [result.error] diff --git a/tests/unit/test_build_dataset_rerun.py b/tests/unit/test_build_dataset_rerun.py new file mode 100644 index 000000000..1e4dfcfcc --- /dev/null +++ b/tests/unit/test_build_dataset_rerun.py @@ -0,0 +1,101 @@ +from policyengine_us_data.build_datasets import ( + Stage1Coordinator, + Stage1IdentityMaterial, + Stage1RerunPlanner, +) + + +def _material(**overrides) -> Stage1IdentityMaterial: + values = { + "substep_id": "1b_base_dataset_construction", + "code_sha": "abc123", + "schema_version": "stage-1-rerun-v1", + "inputs": {"dataset": "cps"}, + "parameters": {"period": 2024}, + "artifact_specs": ({"filename": "cps_2024.h5"},), + "upstream_contract_fingerprints": ("sha256:upstream",), + "randomness": {"seed": 1}, + } + values.update(overrides) + return Stage1IdentityMaterial(**values) + + +def test_rerun_planner_reuses_matching_identity(): + material = _material() + planner = Stage1RerunPlanner( + previous_identities={material.substep_id: material.fingerprint()} + ) + + decision = planner.decide(material, run_id="run-a", rerun_id="attempt-2") + + assert decision.action == "reuse" + assert decision.reason == "identity_match" + assert decision.rerun_id == "attempt-2" + + +def test_rerun_planner_recomputes_mismatched_parameters(): + previous = _material() + current = _material(parameters={"period": 2025}) + planner = Stage1RerunPlanner( + previous_identities={previous.substep_id: previous.fingerprint()} + ) + + decision = planner.decide(current, run_id="run-a") + + assert decision.action == "recompute" + assert decision.reason == "identity_mismatch" + + +def test_rerun_planner_recomputes_mismatched_schema_or_upstream_fingerprint(): + previous = _material() + planner = Stage1RerunPlanner( + previous_identities={previous.substep_id: previous.fingerprint()} + ) + + schema_decision = planner.decide( + _material(schema_version="stage-1-rerun-v2"), + run_id="run-a", + ) + upstream_decision = planner.decide( + _material(upstream_contract_fingerprints=("sha256:changed",)), + run_id="run-a", + ) + + assert schema_decision.action == "recompute" + assert upstream_decision.action == "recompute" + + +def test_rerun_id_does_not_change_artifact_namespace(): + material = _material() + + decision = Stage1RerunPlanner().decide( + material, + run_id="canonical-run", + rerun_id="attempt-2", + ) + + assert decision.artifact_namespace == "canonical-run" + assert decision.rerun_id == "attempt-2" + + +def test_blocked_decision_serializes_into_substep_status(): + material = _material(blocked_reason="missing upstream contract") + decision = Stage1RerunPlanner().decide(material, run_id="run-a") + coordinator = Stage1Coordinator() + + coordinator.run_substep( + material.substep_id, + "Base dataset construction", + lambda: None, + reuse_decision=decision.to_dict(), + checkpoint_decisions=(), + skip=True, + skip_reason="blocked", + ) + + [result] = coordinator.results + assert decision.action == "blocked" + assert result.reuse_decision == decision.to_dict() + assert coordinator.status_events[-1].metadata["reuse_decision"] == ( + decision.to_dict() + ) diff --git a/tests/unit/test_build_dataset_specs.py b/tests/unit/test_build_dataset_specs.py index 5f13764fb..d616367af 100644 --- a/tests/unit/test_build_dataset_specs.py +++ b/tests/unit/test_build_dataset_specs.py @@ -128,6 +128,7 @@ def test_stage_1_skip_flags_identify_expected_artifacts(): } <= enhanced_cps_skipped assert { "small_enhanced_cps_2024.h5", + "source_dataset_schema_summary.json", "source_imputed_stratified_extended_cps_2024.h5", "source_imputed_stratified_extended_cps.h5", } == stage_5_skipped diff --git a/tests/unit/test_build_dataset_staging.py b/tests/unit/test_build_dataset_staging.py new file mode 100644 index 000000000..b69759573 --- /dev/null +++ b/tests/unit/test_build_dataset_staging.py @@ -0,0 +1,216 @@ +import json +import sqlite3 +from pathlib import Path + +import h5py +import pytest + +from policyengine_us_data.build_datasets import ( + DatasetBuildContext, + DatasetBuildOutputContractBuilder, + DatasetInventoryWriter, + PipelineArtifactStager, + SourceDatasetSchemaSummaryWriter, + TargetDatabaseSchemaSummaryWriter, + stage_1_pipeline_artifact_specs, + write_stage_1_diagnostics, +) + + +def _context(tmp_path: Path) -> DatasetBuildContext: + return DatasetBuildContext( + run_id="run-123", + branch="main", + code_sha="abc123", + package_version="1.98.2", + artifacts_dir=tmp_path / "artifacts", + storage_dir=tmp_path / "storage", + work_dir=tmp_path / "work", + ) + + +def _write_h5(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with h5py.File(path, "w") as h5_file: + person = h5_file.create_group("person") + person.create_dataset("age/2024", data=[1, 2, 3]) + household = h5_file.create_group("household") + household.create_dataset("weight/2024", data=[10.0, 20.0]) + + +def _write_sqlite(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with sqlite3.connect(path) as conn: + conn.execute("CREATE TABLE targets (id INTEGER PRIMARY KEY, value REAL)") + conn.execute("CREATE TABLE notes (id INTEGER PRIMARY KEY, label TEXT)") + conn.execute("INSERT INTO targets (value) VALUES (1.5), (2.5)") + conn.execute("INSERT INTO notes (label) VALUES ('a')") + + +def _write_text(path: Path, payload: str = "x") -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(payload) + + +def _write_required_storage_artifacts( + context: DatasetBuildContext, + *, + include_enhanced_cps: bool = True, + include_stage_5: bool = True, + include_optional_weights: bool = False, +) -> None: + for spec in stage_1_pipeline_artifact_specs(): + if spec.diagnostic_output or spec.yearless_alias or spec.storage_path is None: + continue + if not include_enhanced_cps and spec.skip_when_enhanced_cps_skipped: + continue + if not include_stage_5 and spec.skip_when_stage_5_skipped: + continue + if not spec.required and not include_optional_weights: + continue + path = context.source_path(spec.storage_path) + if path.suffix == ".h5": + _write_h5(path) + elif path.suffix == ".db": + _write_sqlite(path) + else: + _write_text(path, spec.logical_name) + + +def test_stager_copies_only_declared_artifacts(tmp_path): + context = _context(tmp_path) + _write_required_storage_artifacts(context, include_optional_weights=True) + extra = context.storage_dir / "untracked.h5" + _write_h5(extra) + + staged = PipelineArtifactStager(context=context).stage_declared_artifacts() + + staged_names = {path.name for path in staged} + assert "untracked.h5" not in staged_names + assert "acs_2022.h5" in staged_names + assert "policy_data.db" in staged_names + assert "calibration_weights.npy" in staged_names + + +def test_stager_creates_yearless_source_imputed_alias(tmp_path): + context = _context(tmp_path) + _write_required_storage_artifacts(context) + + PipelineArtifactStager(context=context).stage_declared_artifacts() + + assert ( + context.artifacts_dir / "source_imputed_stratified_extended_cps_2024.h5" + ).exists() + alias = context.artifacts_dir / "source_imputed_stratified_extended_cps.h5" + assert alias.exists() + with h5py.File(alias) as h5_file: + assert list(h5_file["person"]["age"]["2024"]) == [1, 2, 3] + + +def test_stager_fails_on_missing_required_artifact(tmp_path): + context = _context(tmp_path) + + with pytest.raises(FileNotFoundError, match="acs_2022.h5"): + PipelineArtifactStager(context=context).stage_declared_artifacts() + + +def test_stager_respects_skip_flags_for_optional_ecps_paths(tmp_path): + context = _context(tmp_path) + _write_required_storage_artifacts( + context, + include_enhanced_cps=False, + include_stage_5=False, + ) + + staged = PipelineArtifactStager(context=context).stage_declared_artifacts( + skip_enhanced_cps=True, + skip_stage_5=True, + ) + + staged_names = {path.name for path in staged} + assert "enhanced_cps_2024.h5" not in staged_names + assert "small_enhanced_cps_2024.h5" not in staged_names + assert "source_imputed_stratified_extended_cps_2024.h5" not in staged_names + assert "source_imputed_stratified_extended_cps.h5" not in staged_names + + +def test_dataset_inventory_contains_each_staged_artifact_once(tmp_path): + context = _context(tmp_path) + _write_required_storage_artifacts(context, include_optional_weights=True) + stager = PipelineArtifactStager(context=context) + stager.stage_declared_artifacts() + stager.write_checkpoint_stats({"expected_outputs": 3}) + + diagnostic = DatasetInventoryWriter(context=context).write() + + inventory_path = context.artifacts_dir / "dataset_inventory.json" + payload = json.loads(inventory_path.read_text()) + logical_names = [artifact["logical_name"] for artifact in payload["artifacts"]] + assert len(logical_names) == len(set(logical_names)) + assert "policy_data_db" in logical_names + assert "data_build_checkpoint_stats" in logical_names + assert diagnostic.artifact.logical_name == "dataset_inventory" + assert diagnostic.summary["artifact_count"] == len(logical_names) + + +def test_source_dataset_schema_summary_is_metadata_only(tmp_path): + context = _context(tmp_path) + context.artifacts_dir.mkdir(parents=True) + _write_h5(context.artifacts_dir / "source_imputed_stratified_extended_cps.h5") + + diagnostic = SourceDatasetSchemaSummaryWriter(context=context).write() + + payload = json.loads( + (context.artifacts_dir / "source_dataset_schema_summary.json").read_text() + ) + assert payload["logical_name"] == "source_imputed_stratified_extended_cps" + assert payload["entities"]["person"]["variables"] == ["age"] + assert payload["entities"]["household"]["row_counts"] == { + "household/weight/2024": 2 + } + assert diagnostic.summary == {"dataset_count": 2, "entity_count": 2} + + +def test_target_database_summary_reports_tables_and_row_counts(tmp_path): + context = _context(tmp_path) + context.artifacts_dir.mkdir(parents=True) + _write_sqlite(context.artifacts_dir / "policy_data.db") + + diagnostic = TargetDatabaseSchemaSummaryWriter(context=context).write() + + payload = json.loads( + (context.artifacts_dir / "target_database_schema_summary.json").read_text() + ) + assert [table["name"] for table in payload["tables"]] == ["notes", "targets"] + row_counts = {table["name"]: table["row_count"] for table in payload["tables"]} + assert row_counts == {"notes": 1, "targets": 2} + assert payload["known_target_tables"] == ["targets"] + assert diagnostic.summary["table_count"] == 2 + assert diagnostic.summary["known_target_tables"] == ("targets",) + + +def test_contract_builder_records_stage_1_diagnostics(tmp_path): + context = _context(tmp_path) + _write_required_storage_artifacts(context) + stager = PipelineArtifactStager(context=context) + stager.stage_declared_artifacts() + stager.write_checkpoint_stats({"expected_outputs": 3}) + diagnostics = write_stage_1_diagnostics(context=context) + + contract = DatasetBuildOutputContractBuilder(context=context).build( + checkpoint_stats={"expected_outputs": 3}, + started_at="2026-05-08T12:00:00Z", + completed_at="2026-05-08T12:01:00Z", + duration_s=60.0, + upload_requested=True, + stage_only=True, + skip_enhanced_cps=False, + diagnostics=diagnostics, + ) + + assert {diagnostic.name for diagnostic in contract.diagnostics} == { + "dataset_inventory", + "source_dataset_schema_summary", + "target_database_schema_summary", + } + assert contract.metadata["diagnostic_count"] == 3 diff --git a/tests/unit/test_build_dataset_validation.py b/tests/unit/test_build_dataset_validation.py new file mode 100644 index 000000000..0dec56183 --- /dev/null +++ b/tests/unit/test_build_dataset_validation.py @@ -0,0 +1,113 @@ +from pathlib import Path + +from policyengine_us_data.build_datasets import ( + DatasetSubstepResult, + Stage1ValidationContext, + Stage1ValidationRunner, + ValidationTargetCatalog, + iter_stage_1_validators, + stage_1_step_specs, + validators_for_substage, +) + + +def _substep_result(path: Path) -> DatasetSubstepResult: + return DatasetSubstepResult( + substep_id="1b_base_dataset_construction", + title="Base dataset construction", + status="completed", + started_at="2026-05-20T12:00:00Z", + completed_at="2026-05-20T12:00:01Z", + duration_s=1.0, + command_names=("cps.py",), + artifact_paths=(str(path),), + ) + + +def test_stage_1_validator_registry_is_wired_from_step_specs(): + validator_ids = {validator.validator_id for validator in iter_stage_1_validators()} + + for spec in stage_1_step_specs(): + assert set(spec.validation_ids) <= validator_ids + + [validator] = validators_for_substage("1c_extended_cps_puf_clone") + assert validator.validator_id == ( + "stage_1.1c_extended_cps_puf_clone.artifact_contract" + ) + + +def test_stage_1_validation_runner_validates_substep_artifact(tmp_path): + artifact = tmp_path / "cps_2024.h5" + artifact.write_bytes(b"tiny") + + report = Stage1ValidationRunner(run_id="run-a").run_for_substep_result( + _substep_result(artifact) + ) + + assert report.status == "pass" + assert report.metadata["substage_id"] == "1b_base_dataset_construction" + assert report.findings == () + + +def test_stage_1_validation_runner_reports_missing_required_logical_artifact(): + context = Stage1ValidationContext( + run_id="run-a", + substage_id="1b_base_dataset_construction", + artifact_refs={}, + ) + + report = Stage1ValidationRunner(run_id="run-a").run_for_context( + context, + required_artifacts=("cps_2024",), + ) + + assert report.status == "fail" + [finding] = report.findings + assert finding.check_id == "stage_1.1b_base_dataset_construction.artifact_contract" + assert finding.metric == "required_artifact" + assert finding.value == "cps_2024" + + +def test_validation_target_catalog_loads_active_targets_deterministically(): + catalog = ValidationTargetCatalog.from_stage_1_specs( + skip_enhanced_cps=True, + skip_stage_5=True, + ) + + ids = [target.target_id for target in catalog.targets] + assert ids == sorted(ids) + assert "small_enhanced_cps_2024" not in catalog.required_logical_names( + "1d_enhanced_cps_reweighting" + ) + assert catalog.required_logical_names("1g_stage_base_datasets") == ( + "build_log", + "data_build_checkpoint_stats", + "policy_data_db", + ) + + +def test_stage_1_validation_runner_rejects_empty_artifacts(tmp_path): + artifact = tmp_path / "cps_2024.h5" + artifact.touch() + + report = Stage1ValidationRunner(run_id="run-a").run_for_substep_result( + _substep_result(artifact) + ) + + assert report.status == "fail" + assert report.findings[0].metric == "artifact_size_bytes" + + +def test_stage_1_validation_runner_rejects_missing_declared_artifacts(tmp_path): + artifact = tmp_path / "cps_2024.h5" + + report = Stage1ValidationRunner(run_id="run-a").run_for_substep_result( + _substep_result(artifact) + ) + + assert report.status == "fail" + assert report.findings[0].metric == "artifact_exists" + + +def test_validators_for_unknown_substage_returns_empty_tuple(): + assert validators_for_substage("unknown") == () diff --git a/tests/unit/test_build_dataset_validation_results.py b/tests/unit/test_build_dataset_validation_results.py new file mode 100644 index 000000000..60ce248d1 --- /dev/null +++ b/tests/unit/test_build_dataset_validation_results.py @@ -0,0 +1,75 @@ +import json +import sqlite3 + +from policyengine_us_data.build_datasets import Stage1ValidationResultWriter +from policyengine_us_data.stage_contracts import ValidationFinding, ValidationReport + + +def _report(*, substage_id: str, status: str = "fail") -> ValidationReport: + findings = () + if status == "fail": + findings = ( + ValidationFinding( + check_id=f"stage_1.{substage_id}.artifact_contract", + status="fail", + message="missing artifact", + metric="artifact_exists", + value="missing", + ), + ) + return ValidationReport( + status=status, + findings=findings, + metadata={ + "stage_id": "1_build_datasets", + "substage_id": substage_id, + "check_ids": [f"stage_1.{substage_id}.artifact_contract"], + }, + ) + + +def test_stage_1_validation_result_writer_writes_queryable_outputs(tmp_path): + summary = Stage1ValidationResultWriter(output_dir=tmp_path).write( + [ + _report(substage_id="1b_base_dataset_construction"), + _report(substage_id="1c_extended_cps_puf_clone", status="pass"), + ] + ) + + assert summary.report.status == "fail" + assert (tmp_path / "1b_base_dataset_construction.json").exists() + assert (tmp_path / "1c_extended_cps_puf_clone.json").exists() + assert json.loads((tmp_path / "summary.json").read_text())["status"] == "fail" + assert json.loads((tmp_path / "findings.jsonl").read_text())["metric"] == ( + "artifact_exists" + ) + metrics = [ + json.loads(line) + for line in (tmp_path / "metrics.jsonl").read_text().splitlines() + ] + assert {row["substage_id"] for row in metrics} == { + "1b_base_dataset_construction", + "1c_extended_cps_puf_clone", + } + + with sqlite3.connect(tmp_path / "validation_results.sqlite") as connection: + rows = connection.execute( + "SELECT substage_id, status FROM reports ORDER BY substage_id" + ).fetchall() + assert rows == [ + ("1b_base_dataset_construction", "fail"), + ("1c_extended_cps_puf_clone", "pass"), + ] + assert {diagnostic.kind for diagnostic in summary.diagnostics} >= { + "json", + "jsonl", + "sqlite", + } + + +def test_stage_1_validation_result_writer_handles_no_reports(tmp_path): + summary = Stage1ValidationResultWriter(output_dir=tmp_path).write([]) + + assert summary.report.status == "not_run" + assert json.loads((tmp_path / "summary.json").read_text())["status"] == "not_run" + assert (tmp_path / "validation_results.sqlite").exists() diff --git a/tests/unit/test_dataset_build_stage_contract.py b/tests/unit/test_dataset_build_stage_contract.py index ae36424c3..31de4b102 100644 --- a/tests/unit/test_dataset_build_stage_contract.py +++ b/tests/unit/test_dataset_build_stage_contract.py @@ -1,7 +1,10 @@ from pathlib import Path from policyengine_us_data.stage_contracts import ( + DiagnosticRef, StageContract, + ValidationFinding, + ValidationReport, contract_from_json, contract_to_json, ) @@ -205,3 +208,90 @@ def test_dataset_build_contract_fingerprint_excludes_run_id(tmp_path): def test_dataset_build_contract_filename_is_stable(): assert DATASET_BUILD_OUTPUT_CONTRACT_FILENAME == "dataset_build_output.json" + + +def test_dataset_build_contract_records_diagnostic_refs(tmp_path): + _write_artifacts(tmp_path) + diagnostic = DiagnosticRef( + name="dataset_inventory", + kind="dataset_inventory", + summary={"artifact_count": 13}, + ) + + contract = build_dataset_build_output_contract( + artifacts_dir=tmp_path, + run_id="run-a", + code_sha="abc123", + package_version="1.98.2", + checkpoint_stats={"expected_outputs": 4}, + started_at="2026-05-08T12:00:00Z", + completed_at="2026-05-08T12:01:00Z", + diagnostics=(diagnostic,), + ) + + assert contract.diagnostics == (diagnostic,) + assert contract.metadata["diagnostic_count"] == 1 + + +def test_dataset_build_contract_records_stage_1_status_metadata(tmp_path): + _write_artifacts(tmp_path) + + contract = build_dataset_build_output_contract( + artifacts_dir=tmp_path, + run_id="run-a", + code_sha="abc123", + package_version="1.98.2", + checkpoint_stats={"expected_outputs": 4}, + started_at="2026-05-08T12:00:00Z", + completed_at="2026-05-08T12:01:00Z", + stage_1_status_metadata={ + "substep_results": [ + { + "substep_id": "1b_base_dataset_construction", + "reuse_decision": {"action": "reuse"}, + } + ] + }, + ) + + assert contract.metadata["stage_1_status"]["substep_results"][0][ + "reuse_decision" + ] == {"action": "reuse"} + + +def test_dataset_build_contract_records_validation_reports(tmp_path): + _write_artifacts(tmp_path) + substage_report = ValidationReport( + status="fail", + findings=( + ValidationFinding( + check_id="stage_1.1b_base_dataset_construction.artifact_contract", + status="fail", + message="missing CPS", + ), + ), + metadata={"substage_id": "1b_base_dataset_construction"}, + ) + aggregate_report = ValidationReport( + status="fail", + findings=substage_report.findings, + metadata={"stage_id": "1_build_datasets"}, + ) + + contract = build_dataset_build_output_contract( + artifacts_dir=tmp_path, + run_id="run-a", + code_sha="abc123", + package_version="1.98.2", + checkpoint_stats={"expected_outputs": 4}, + started_at="2026-05-08T12:00:00Z", + completed_at="2026-05-08T12:01:00Z", + validation=aggregate_report, + substage_validation={ + "1b_base_dataset_construction": substage_report, + }, + ) + + assert contract.validation == aggregate_report + records = {record.substage_id: record for record in contract.substages} + assert records["1b_base_dataset_construction"].validation == substage_report diff --git a/tests/unit/test_modal_data_build.py b/tests/unit/test_modal_data_build.py index eed66aec7..19d809031 100644 --- a/tests/unit/test_modal_data_build.py +++ b/tests/unit/test_modal_data_build.py @@ -3,7 +3,11 @@ from datetime import datetime, timedelta, timezone from types import ModuleType, SimpleNamespace -from policyengine_us_data.build_datasets import stage_1_script_outputs +from policyengine_us_data.build_datasets import ( + CheckpointStore, + Stage1ReuseDecision, + stage_1_script_outputs, +) from policyengine_us_data.stage_contracts import read_contract @@ -309,6 +313,7 @@ def test_write_dataset_build_contract_writes_stage_1_handoff(tmp_path): upload_requested=False, stage_only=True, skip_enhanced_cps=True, + branch="stage-1", ) contract_path = tmp_path / "dataset_build_output.json" @@ -318,6 +323,45 @@ def test_write_dataset_build_contract_writes_stage_1_handoff(tmp_path): assert contract.parameters["stage_only"] is True +def test_run_script_with_checkpoint_rejects_blocked_reuse_decision( + tmp_path, + monkeypatch, +): + data_build = _load_data_build_module() + + def fail_run_script(*args, **kwargs): + raise AssertionError("blocked decisions must not run commands") + + monkeypatch.setattr(data_build, "run_script", fail_run_script) + decision = Stage1ReuseDecision( + run_id="run-a", + rerun_id="attempt-2", + artifact_namespace="run-a", + substep_id="1b_base_dataset_construction", + action="blocked", + reason="identity_mismatch", + identity_fingerprint="sha256:abc", + ) + + try: + data_build.run_script_with_checkpoint( + "policyengine_us_data/datasets/cps/cps.py", + "cps_2024.h5", + "branch", + object(), + checkpoint_store=CheckpointStore( + root=tmp_path, + branch="branch", + commit_sha="abc123", + ), + reuse_decision=decision, + ) + except RuntimeError as exc: + assert "identity_mismatch" in str(exc) + else: + raise AssertionError("blocked reuse decision should fail") + + def test_utc_timestamp_renders_zulu_time_for_build_log(): data_build = _load_data_build_module() budapest_summer = timezone(timedelta(hours=2))