From e1ea12618ca261dd22446522e6eaa94eff3f3e91 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 21 May 2026 21:31:19 +0200 Subject: [PATCH] Add Stage 2 calibration package validation runner --- changelog.d/1104.added | 1 + modal_app/remote_calibration_runner.py | 32 + .../calibration/unified_calibration.py | 54 +- .../calibration_package/__init__.py | 6 + .../calibration_package/specs.py | 28 +- .../calibration_package/validation.py | 974 ++++++++++++++++++ .../test_unified_calibration_build_only.py | 50 + tests/unit/calibration_package/test_specs.py | 21 + .../calibration_package/test_validation.py | 352 +++++++ tests/unit/test_remote_calibration_runner.py | 64 ++ 10 files changed, 1568 insertions(+), 14 deletions(-) create mode 100644 changelog.d/1104.added create mode 100644 policyengine_us_data/calibration_package/validation.py create mode 100644 tests/unit/calibration/test_unified_calibration_build_only.py create mode 100644 tests/unit/calibration_package/test_validation.py diff --git a/changelog.d/1104.added b/changelog.d/1104.added new file mode 100644 index 000000000..7edad57b3 --- /dev/null +++ b/changelog.d/1104.added @@ -0,0 +1 @@ +Added canonical Stage 2 calibration-package validation artifacts. diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index a81764fba..5485a8237 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -443,6 +443,12 @@ def _build_package_impl( label="build", ) if build_rc != 0: + if ( + package_artifacts.validation_report.exists() + or package_artifacts.validation_findings.exists() + or package_artifacts.validation_summary.exists() + ): + pipeline_vol.commit() raise RuntimeError(f"Package build failed with code {build_rc}") from policyengine_us_data.stage_contracts.calibration_package import ( @@ -456,6 +462,32 @@ def _build_package_impl( db_path=Path(db_path), ) + from policyengine_us_data.calibration_package.validation import ( # noqa: E402 + CalibrationPackageValidator, + format_validation_report, + ) + + validator = CalibrationPackageValidator() + validation_report = validator.validate_and_write( + package_path=package_artifacts.package, + contract_path=package_artifacts.contract, + dataset_path=Path(dataset_path), + db_path=Path(db_path), + reports_dir=package_artifacts.reports_dir, + targets_path=package_artifacts.targets, + target_facets_path=package_artifacts.target_facets, + geography_summary_path=package_artifacts.geography_summary, + matrix_summary_path=package_artifacts.matrix_summary, + run_id=run_id or None, + ) + print( + format_validation_report(validation_report, package_path=pkg_path), + flush=True, + ) + if validation_report.status == "fail": + pipeline_vol.commit() + validator.raise_for_failure(validation_report) + sidecar_ok = _write_package_sidecar(pkg_path) if not sidecar_ok: print( diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index 87994927f..87fe82f0e 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -56,6 +56,8 @@ ) from policyengine_us_data.calibration_package.specs import ( DEFAULT_TARGET_CONFIG_PATH as DEFAULT_TARGET_CONFIG_RELATIVE_PATH, + CALIBRATION_PACKAGE_CONTRACT_FILENAME, + CALIBRATION_REPORTS_DIRNAME, CALIBRATION_TARGET_FACETS_FILENAME, CALIBRATION_TARGETS_FILENAME, GEOGRAPHY_ASSIGNMENT_SUMMARY_FILENAME, @@ -1837,20 +1839,46 @@ def run_calibration( ) if build_only: - from policyengine_us_data.calibration.validate_package import ( - validate_package, - format_report, - ) + if package_output_path: + from policyengine_us_data.calibration_package.validation import ( + CalibrationPackageValidator, + format_validation_report, + ) - package = { - "X_sparse": X_sparse, - "targets_df": targets_df, - "target_names": target_names, - "metadata": metadata, - "initial_weights": initial_weights, - } - result = validate_package(package) - print(format_report(result)) + validator = CalibrationPackageValidator() + validation_report = validator.validate_and_write( + package_path=package_path, + contract_path=package_path.with_name( + CALIBRATION_PACKAGE_CONTRACT_FILENAME + ), + dataset_path=Path(dataset_path), + db_path=Path(db_path), + reports_dir=package_path.with_name(CALIBRATION_REPORTS_DIRNAME), + targets_path=targets_path, + target_facets_path=target_facets_path, + geography_summary_path=geography_summary_path, + matrix_summary_path=matrix_summary_path, + run_id=run_id, + ) + print( + format_validation_report(validation_report, package_path=package_path) + ) + validator.raise_for_failure(validation_report) + else: + from policyengine_us_data.calibration.validate_package import ( + format_report, + validate_package, + ) + + package = { + "X_sparse": X_sparse, + "targets_df": targets_df, + "target_names": target_names, + "metadata": metadata, + "initial_weights": initial_weights, + } + result = validate_package(package) + print(format_report(result)) geography_info = { "cd_geoid": geography.cd_geoid, "block_geoid": geography.block_geoid, diff --git a/policyengine_us_data/calibration_package/__init__.py b/policyengine_us_data/calibration_package/__init__.py index aa2d2e52e..44b72bafa 100644 --- a/policyengine_us_data/calibration_package/__init__.py +++ b/policyengine_us_data/calibration_package/__init__.py @@ -14,6 +14,9 @@ MATRIX_BUILD_DIRNAME, MATRIX_SUMMARY_FILENAME, SOURCE_DATASET_FILENAME, + STAGE2_VALIDATION_FINDINGS_FILENAME, + STAGE2_VALIDATION_REPORT_FILENAME, + STAGE2_VALIDATION_SUMMARY_FILENAME, TARGET_CONFIG_IDENTITY_MODES, TARGET_DATABASE_FILENAME, CalibrationPackageArtifactPaths, @@ -90,6 +93,9 @@ "MATRIX_BUILDER_MODES", "MATRIX_SUMMARY_FILENAME", "SOURCE_DATASET_FILENAME", + "STAGE2_VALIDATION_FINDINGS_FILENAME", + "STAGE2_VALIDATION_REPORT_FILENAME", + "STAGE2_VALIDATION_SUMMARY_FILENAME", "TARGET_CONFIG_IDENTITY_MODES", "TARGET_DATABASE_FILENAME", "CalibrationPackageArtifactPaths", diff --git a/policyengine_us_data/calibration_package/specs.py b/policyengine_us_data/calibration_package/specs.py index ca1dbe509..d88d0225a 100644 --- a/policyengine_us_data/calibration_package/specs.py +++ b/policyengine_us_data/calibration_package/specs.py @@ -26,6 +26,9 @@ GEOGRAPHY_ASSIGNMENT_SUMMARY_FILENAME = "geography_assignment_summary.json" MATRIX_SUMMARY_FILENAME = "matrix_summary.json" CALIBRATION_REPORTS_DIRNAME = "calibration_reports" +STAGE2_VALIDATION_REPORT_FILENAME = "validation_report.json" +STAGE2_VALIDATION_FINDINGS_FILENAME = "validation_findings.jsonl" +STAGE2_VALIDATION_SUMMARY_FILENAME = "validation_summary.json" MATRIX_BUILD_DIRNAME = "matrix_build" CALIBRATION_PACKAGE_SUBSTAGE_ID = "2a_matrix_build_calibration_target_construction" @@ -194,10 +197,15 @@ class CalibrationPackageOutputBundle: geography_summary: Path matrix_summary: Path reports_dir: Path + validation_report: Path + validation_findings: Path + validation_summary: Path matrix_build_dir: Path @property - def manifest_outputs(self) -> tuple[Path, Path, Path, Path, Path, Path]: + def manifest_outputs( + self, + ) -> tuple[Path, Path, Path, Path, Path, Path, Path, Path, Path]: """Return the durable Stage 2 outputs recorded in step manifests.""" return ( @@ -207,6 +215,9 @@ def manifest_outputs(self) -> tuple[Path, Path, Path, Path, Path, Path]: self.target_facets, self.geography_summary, self.matrix_summary, + self.validation_report, + self.validation_findings, + self.validation_summary, ) @@ -331,6 +342,9 @@ def stage2_input_bundle_from_stage1_contract_path( CALIBRATION_PACKAGE_CONTRACT_FILENAME, GEOGRAPHY_ASSIGNMENT_SUMMARY_FILENAME, MATRIX_SUMMARY_FILENAME, + STAGE2_VALIDATION_REPORT_FILENAME, + STAGE2_VALIDATION_FINDINGS_FILENAME, + STAGE2_VALIDATION_SUMMARY_FILENAME, ], validation_commands=[ "uv run pytest tests/unit/calibration_package/test_specs.py" @@ -412,6 +426,15 @@ def calibration_package_artifact_paths( geography_summary=root / GEOGRAPHY_ASSIGNMENT_SUMMARY_FILENAME, matrix_summary=root / MATRIX_SUMMARY_FILENAME, reports_dir=root / CALIBRATION_REPORTS_DIRNAME, + validation_report=root + / CALIBRATION_REPORTS_DIRNAME + / STAGE2_VALIDATION_REPORT_FILENAME, + validation_findings=root + / CALIBRATION_REPORTS_DIRNAME + / STAGE2_VALIDATION_FINDINGS_FILENAME, + validation_summary=root + / CALIBRATION_REPORTS_DIRNAME + / STAGE2_VALIDATION_SUMMARY_FILENAME, matrix_build_dir=root / MATRIX_BUILD_DIRNAME, ) @@ -528,6 +551,9 @@ def _artifact_uri_to_path(uri: str) -> Path: "MATRIX_BUILD_DIRNAME", "MATRIX_SUMMARY_FILENAME", "SOURCE_DATASET_FILENAME", + "STAGE2_VALIDATION_FINDINGS_FILENAME", + "STAGE2_VALIDATION_REPORT_FILENAME", + "STAGE2_VALIDATION_SUMMARY_FILENAME", "TARGET_CONFIG_IDENTITY_MODES", "TARGET_DATABASE_FILENAME", "CalibrationPackageArtifactPaths", diff --git a/policyengine_us_data/calibration_package/validation.py b/policyengine_us_data/calibration_package/validation.py new file mode 100644 index 000000000..eea4549ca --- /dev/null +++ b/policyengine_us_data/calibration_package/validation.py @@ -0,0 +1,974 @@ +"""Stage 2 calibration-package validation service.""" + +from __future__ import annotations + +import json +from collections.abc import Mapping +from dataclasses import dataclass, field, replace +from pathlib import Path +from typing import Any +from urllib.parse import unquote, urlparse + +import numpy as np + +from policyengine_us_data.calibration_package.payload import ( + CalibrationPackagePayload, + CalibrationPackageReader, +) +from policyengine_us_data.calibration_package.targets import target_facets_from_rows +from policyengine_us_data.calibration_package.specs import ( + CALIBRATION_PACKAGE_SUBSTAGE_ID, + STAGE2_VALIDATION_FINDINGS_FILENAME, + STAGE2_VALIDATION_REPORT_FILENAME, + STAGE2_VALIDATION_SUMMARY_FILENAME, +) +from policyengine_us_data.calibration_package.matrix import ( + ChunkCacheManifest, +) +from policyengine_us_data.pipeline_metadata import pipeline_node +from policyengine_us_data.pipeline_schema import PipelineNode +from policyengine_us_data.stage_contracts import ( + ArtifactRef, + StageContract, + ValidationFinding, + ValidationReport, +) +from policyengine_us_data.stage_contracts.calibration_package import ( + validate_persisted_calibration_package_contract, +) +from policyengine_us_data.stage_contracts.calibration_package_schema import ( + MatrixBuildSummary, +) +from policyengine_us_data.stage_contracts.io import read_contract, write_contract +from policyengine_us_data.stage_contracts.stages import ( + STAGE_2_BUILD_CALIBRATION_PACKAGE, +) +from policyengine_us_data.utils.manifest import compute_file_checksum +from policyengine_us_data.utils.step_manifest import sha256_file +from policyengine_us_data.validation_core import ( + ValidationArtifactResolver, + ValidationCheck, + ValidationContext, + ValidationResultWriter, + ValidationRunner, + ValidationSuite, +) + +__all__ = [ + "CalibrationPackageValidationError", + "CalibrationPackageValidator", + "format_validation_report", +] + +_CHECK_PACKAGE_LOADABLE = "stage2.package.loadable" +_CHECK_CONTRACT_MATCHES = "stage2.contract.matches_package" +_CHECK_TARGET_CONFIG = "stage2.target_config.identity" +_CHECK_MATRIX = "stage2.matrix.consistency" +_CHECK_TARGET_FRAME = "stage2.target_frame.consistency" +_CHECK_TARGET_METADATA = "stage2.target_metadata.consistency" +_CHECK_GEOGRAPHY = "stage2.geography.consistency" +_CHECK_INITIAL_WEIGHTS = "stage2.initial_weights.consistency" +_CHECK_CHUNK_MANIFEST = "stage2.chunk_manifest.consistency" +_TARGET_FRAME_REQUIRED_COLUMNS = frozenset( + {"value", "domain_variable", "variable", "geo_level", "geographic_id"} +) +_TARGET_CONFIG_IDENTITY_MODES = frozenset({"default", "explicit", "all_active_targets"}) + + +class CalibrationPackageValidationError(RuntimeError): + """Raised when Stage 2 calibration-package validation fails.""" + + def __init__(self, report: ValidationReport) -> None: + failing_ids = tuple( + finding.check_id for finding in report.findings if finding.status == "fail" + ) + message = "Stage 2 calibration package validation failed" + if failing_ids: + message += ": " + ", ".join(failing_ids) + super().__init__(message) + self.report = report + self.failing_ids = failing_ids + + +@pipeline_node( + PipelineNode( + id="stage2_calibration_package_validator", + label="Stage 2 Calibration Package Validator", + node_type="validation", + description="Validate Stage 2 package, target, matrix, geography, chunk, and contract artifacts through the shared validation core.", + source_file="policyengine_us_data/calibration_package/validation.py", + status="current", + stability="moving", + pathways=["calibration_package", "cross_stage_validation"], + artifacts_in=[ + "calibration_package.pkl", + "calibration_package_contract.json", + "calibration_targets.jsonl", + "calibration_target_facets.json", + "geography_assignment_summary.json", + "matrix_summary.json", + ], + artifacts_out=[ + STAGE2_VALIDATION_REPORT_FILENAME, + STAGE2_VALIDATION_FINDINGS_FILENAME, + STAGE2_VALIDATION_SUMMARY_FILENAME, + ], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_validation.py" + ], + ) +) +@dataclass(frozen=True, kw_only=True) +class CalibrationPackageValidator: + """Validate Stage 2 calibration package artifacts with canonical reports.""" + + runner: ValidationRunner = field(default_factory=ValidationRunner) + + def validate( + self, + *, + package_path: str | Path, + contract_path: str | Path, + dataset_path: str | Path, + db_path: str | Path, + targets_path: str | Path | None = None, + target_facets_path: str | Path | None = None, + geography_summary_path: str | Path | None = None, + matrix_summary_path: str | Path | None = None, + run_id: str | None = None, + ) -> ValidationReport: + """Return a canonical validation report for Stage 2 artifacts.""" + + paths = { + "calibration_package": Path(package_path), + "calibration_package_contract": Path(contract_path), + "source_dataset": Path(dataset_path), + "target_database": Path(db_path), + } + optional_paths = { + "calibration_targets": targets_path, + "calibration_target_facets": target_facets_path, + "geography_assignment_summary": geography_summary_path, + "matrix_summary": matrix_summary_path, + } + for logical_name, path in optional_paths.items(): + if path is not None: + paths[logical_name] = Path(path) + + artifacts = _artifact_refs(paths) + context = ValidationContext( + run_id=run_id or "stage2-calibration-package", + stage_id=STAGE_2_BUILD_CALIBRATION_PACKAGE, + substage_id=CALIBRATION_PACKAGE_SUBSTAGE_ID, + resolver=ValidationArtifactResolver(artifacts=artifacts), + metadata={ + "package_path": str(package_path), + "contract_path": str(contract_path), + }, + ) + cache: dict[str, Any] = {} + return self.runner.run(_validation_suite(cache), context) + + def validate_and_write( + self, + *, + package_path: str | Path, + contract_path: str | Path, + dataset_path: str | Path, + db_path: str | Path, + reports_dir: str | Path, + targets_path: str | Path | None = None, + target_facets_path: str | Path | None = None, + geography_summary_path: str | Path | None = None, + matrix_summary_path: str | Path | None = None, + run_id: str | None = None, + attach_to_contract: bool = True, + ) -> ValidationReport: + """Validate, write report artifacts, and optionally attach the report.""" + + report = self.validate( + package_path=package_path, + contract_path=contract_path, + dataset_path=dataset_path, + db_path=db_path, + targets_path=targets_path, + target_facets_path=target_facets_path, + geography_summary_path=geography_summary_path, + matrix_summary_path=matrix_summary_path, + run_id=run_id, + ) + paths = ValidationResultWriter( + output_dir=Path(reports_dir), + report_filename=STAGE2_VALIDATION_REPORT_FILENAME, + findings_filename=STAGE2_VALIDATION_FINDINGS_FILENAME, + summary_filename=STAGE2_VALIDATION_SUMMARY_FILENAME, + ).write(report) + if attach_to_contract: + attach_validation_report_to_contract( + contract_path=Path(contract_path), + report=report, + validation_paths=paths, + ) + return report + + def raise_for_failure(self, report: ValidationReport) -> None: + """Raise with failing check IDs when ``report`` failed.""" + + if report.status == "fail": + raise CalibrationPackageValidationError(report) + + +def attach_validation_report_to_contract( + *, + contract_path: Path, + report: ValidationReport, + validation_paths: Mapping[str, Path], +) -> StageContract: + """Attach validation output to a Stage 2 contract and rewrite it.""" + + contract = read_contract(contract_path) + validation_artifacts = { + key: str(path) for key, path in sorted(validation_paths.items()) + } + metadata = { + **dict(contract.metadata), + "validation_artifacts": validation_artifacts, + } + substages = tuple( + replace(substage, validation=report) + if substage.substage_id == CALIBRATION_PACKAGE_SUBSTAGE_ID + else substage + for substage in contract.substages + ) + updated = replace( + contract, + validation=report, + substages=substages, + metadata=metadata, + ) + write_contract(updated, contract_path) + return updated + + +def format_validation_report( + report: ValidationReport, + *, + package_path: str | Path | None = None, +) -> str: + """Return a compact human-readable validation report.""" + + lines = ["", "=== Stage 2 Calibration Package Validation ===", ""] + if package_path is not None: + lines.append(f"Package: {package_path}") + lines.append(f"Status: {report.status.upper()}") + failing = [finding for finding in report.findings if finding.status == "fail"] + warnings = [finding for finding in report.findings if finding.status == "warn"] + lines.append(f"Findings: {len(report.findings)}") + if failing: + lines.append("") + lines.append("Failures:") + for finding in failing: + lines.append(f" {finding.check_id}: {finding.message}") + if warnings: + lines.append("") + lines.append("Warnings:") + for finding in warnings: + lines.append(f" {finding.check_id}: {finding.message}") + return "\n".join(lines) + + +def _validation_suite(cache: dict[str, Any]) -> ValidationSuite: + return ValidationSuite( + suite_id="stage2_calibration_package_validation", + stage_id=STAGE_2_BUILD_CALIBRATION_PACKAGE, + substage_id=CALIBRATION_PACKAGE_SUBSTAGE_ID, + checks=( + ValidationCheck( + check_id=_CHECK_PACKAGE_LOADABLE, + stage_id=STAGE_2_BUILD_CALIBRATION_PACKAGE, + substage_id=CALIBRATION_PACKAGE_SUBSTAGE_ID, + description="Calibration package pickle loads through typed payload reader.", + required_artifacts=("calibration_package",), + run=lambda context: _check_package_loadable(context, cache), + ), + ValidationCheck( + check_id=_CHECK_CONTRACT_MATCHES, + stage_id=STAGE_2_BUILD_CALIBRATION_PACKAGE, + substage_id=CALIBRATION_PACKAGE_SUBSTAGE_ID, + description="Persisted Stage 2 contract matches package and input artifacts.", + required_artifacts=( + "calibration_package", + "calibration_package_contract", + "source_dataset", + "target_database", + ), + run=lambda context: _check_contract_matches(context, cache), + ), + ValidationCheck( + check_id=_CHECK_TARGET_CONFIG, + stage_id=STAGE_2_BUILD_CALIBRATION_PACKAGE, + substage_id=CALIBRATION_PACKAGE_SUBSTAGE_ID, + description="Target config identity is present and checksum-backed.", + required_artifacts=("calibration_package",), + run=lambda context: _check_target_config(context, cache), + ), + ValidationCheck( + check_id=_CHECK_MATRIX, + stage_id=STAGE_2_BUILD_CALIBRATION_PACKAGE, + substage_id=CALIBRATION_PACKAGE_SUBSTAGE_ID, + description="Sparse matrix dimensions match target rows, target names, and summary artifact.", + required_artifacts=("calibration_package", "matrix_summary"), + run=lambda context: _check_matrix(context, cache), + ), + ValidationCheck( + check_id=_CHECK_TARGET_FRAME, + stage_id=STAGE_2_BUILD_CALIBRATION_PACKAGE, + substage_id=CALIBRATION_PACKAGE_SUBSTAGE_ID, + description="Target frame contains required columns and row ordering.", + required_artifacts=("calibration_package",), + run=lambda context: _check_target_frame(context, cache), + ), + ValidationCheck( + check_id=_CHECK_TARGET_METADATA, + stage_id=STAGE_2_BUILD_CALIBRATION_PACKAGE, + substage_id=CALIBRATION_PACKAGE_SUBSTAGE_ID, + description="Target metadata artifacts match package target rows and facets.", + required_artifacts=( + "calibration_package", + "calibration_targets", + "calibration_target_facets", + ), + run=lambda context: _check_target_metadata(context, cache), + ), + ValidationCheck( + check_id=_CHECK_GEOGRAPHY, + stage_id=STAGE_2_BUILD_CALIBRATION_PACKAGE, + substage_id=CALIBRATION_PACKAGE_SUBSTAGE_ID, + description="Geography assignment arrays and summary artifact are consistent.", + required_artifacts=( + "calibration_package", + "geography_assignment_summary", + ), + run=lambda context: _check_geography(context, cache), + ), + ValidationCheck( + check_id=_CHECK_INITIAL_WEIGHTS, + stage_id=STAGE_2_BUILD_CALIBRATION_PACKAGE, + substage_id=CALIBRATION_PACKAGE_SUBSTAGE_ID, + description="Initial weights are present, finite, non-negative, and column-aligned.", + required_artifacts=("calibration_package",), + run=lambda context: _check_initial_weights(context, cache), + ), + ValidationCheck( + check_id=_CHECK_CHUNK_MANIFEST, + stage_id=STAGE_2_BUILD_CALIBRATION_PACKAGE, + substage_id=CALIBRATION_PACKAGE_SUBSTAGE_ID, + description="Declared chunk manifest exists, parses, and matches its checksum.", + required_artifacts=("calibration_package",), + run=lambda context: _check_chunk_manifest(context, cache), + ), + ), + ) + + +def _check_package_loadable( + context: ValidationContext, + cache: dict[str, Any], +) -> ValidationFinding: + payload = _payload(context, cache) + summary = payload.summary() + return _finding( + _CHECK_PACKAGE_LOADABLE, + status="pass", + message="Calibration package payload is loadable.", + metric="package_target_count", + value=summary.n_targets, + ) + + +def _check_contract_matches( + context: ValidationContext, + cache: dict[str, Any], +) -> ValidationFinding: + contract = validate_persisted_calibration_package_contract( + package_path=_path(context, "calibration_package"), + contract_path=_path(context, "calibration_package_contract"), + dataset_path=_path(context, "source_dataset"), + db_path=_path(context, "target_database"), + ) + cache["contract"] = contract + return _finding( + _CHECK_CONTRACT_MATCHES, + status="pass", + message="Stage 2 contract matches the package and input artifacts.", + value=contract.fingerprint.value, + ) + + +def _check_target_config( + context: ValidationContext, + cache: dict[str, Any], +) -> ValidationFinding: + payload = _payload(context, cache) + metadata = payload.metadata + mode = metadata.get("target_config_mode") + config_path = metadata.get("target_config_path") + config_sha = metadata.get("target_config_sha256") + if mode not in _TARGET_CONFIG_IDENTITY_MODES: + return _finding( + _CHECK_TARGET_CONFIG, + status="fail", + message=f"Unknown target config mode: {mode!r}", + metric="target_config_mode", + value=mode, + ) + if mode == "all_active_targets": + if config_path is not None or config_sha is not None: + return _finding( + _CHECK_TARGET_CONFIG, + status="fail", + message="all_active_targets target config must not include path or checksum.", + metric="target_config_identity", + value={"path": config_path, "sha256": config_sha}, + ) + return _finding( + _CHECK_TARGET_CONFIG, + status="pass", + message="All-active-targets package does not require target config identity.", + metric="target_config_mode", + value=mode, + ) + if not config_path or not config_sha: + return _finding( + _CHECK_TARGET_CONFIG, + status="fail", + message=f"{mode} target config requires path and checksum.", + metric="target_config_identity", + value={"path": config_path, "sha256": config_sha}, + ) + resolved_path = _resolve_existing_path(str(config_path)) + if resolved_path is None: + return _finding( + _CHECK_TARGET_CONFIG, + status="fail", + message=f"Target config path does not exist: {config_path}", + metric="target_config_path", + value=str(config_path), + ) + actual_sha = compute_file_checksum(resolved_path) + allowed = {actual_sha, f"sha256:{actual_sha}"} + if str(config_sha) not in allowed: + return _finding( + _CHECK_TARGET_CONFIG, + status="fail", + message="Target config checksum does not match package metadata.", + metric="target_config_sha256", + value=str(config_sha), + threshold=actual_sha, + metadata={"path": str(resolved_path)}, + ) + return _finding( + _CHECK_TARGET_CONFIG, + status="pass", + message="Target config identity is checksum-backed.", + metric="target_config_mode", + value=mode, + metadata={"path": str(resolved_path)}, + ) + + +def _check_matrix( + context: ValidationContext, + cache: dict[str, Any], +) -> ValidationFinding: + payload = _payload(context, cache) + summary = payload.summary() + if summary.matrix_shape[0] != summary.n_targets: + return _finding( + _CHECK_MATRIX, + status="fail", + message="Matrix row count does not match target frame length.", + metric="matrix_shape", + value=summary.matrix_shape, + threshold=summary.n_targets, + ) + if summary.target_name_count != summary.n_targets: + return _finding( + _CHECK_MATRIX, + status="fail", + message="Target name count does not match target frame length.", + metric="target_name_count", + value=summary.target_name_count, + threshold=summary.n_targets, + ) + matrix_summary = _matrix_summary(context, cache) + if matrix_summary is not None: + expected = { + "matrix_shape": tuple(summary.matrix_shape), + "matrix_nnz": summary.matrix_nnz, + "matrix_density": summary.matrix_density, + "n_targets": summary.n_targets, + "n_columns": summary.n_columns, + "target_name_count": summary.target_name_count, + "base_n_records": summary.base_n_records, + "n_clones": summary.n_clones, + "matrix_builder": summary.matrix_builder, + "chunk_size": summary.chunk_size, + "chunk_dir": summary.chunk_dir, + } + for key, expected_value in expected.items(): + actual_value = getattr(matrix_summary, key) + if actual_value != expected_value: + return _finding( + _CHECK_MATRIX, + status="fail", + message=f"Matrix summary artifact does not match package for {key}.", + metric=key, + value=actual_value, + threshold=expected_value, + ) + return _finding( + _CHECK_MATRIX, + status="pass", + message="Matrix dimensions and summary are consistent.", + metric="matrix_shape", + value=summary.matrix_shape, + ) + + +def _check_target_frame( + context: ValidationContext, + cache: dict[str, Any], +) -> ValidationFinding: + payload = _payload(context, cache) + targets_df = payload.targets_df + missing_columns = sorted(_TARGET_FRAME_REQUIRED_COLUMNS - set(targets_df.columns)) + if missing_columns: + return _finding( + _CHECK_TARGET_FRAME, + status="fail", + message="Target frame is missing required columns.", + metric="missing_columns", + value=missing_columns, + ) + if len(targets_df) != len(payload.target_names): + return _finding( + _CHECK_TARGET_FRAME, + status="fail", + message="Target frame row count does not match target_names count.", + metric="target_row_count", + value=len(targets_df), + threshold=len(payload.target_names), + ) + if bool(targets_df["value"].isna().any()): + return _finding( + _CHECK_TARGET_FRAME, + status="fail", + message="Target frame contains null target values.", + metric="null_target_values", + value=True, + ) + return _finding( + _CHECK_TARGET_FRAME, + status="pass", + message="Target frame columns and row ordering are valid.", + metric="target_row_count", + value=len(targets_df), + ) + + +def _check_target_metadata( + context: ValidationContext, + cache: dict[str, Any], +) -> ValidationFinding: + payload = _payload(context, cache) + rows = _target_metadata_rows(context) + if len(rows) != len(payload.target_names): + return _finding( + _CHECK_TARGET_METADATA, + status="fail", + message="Target metadata row count does not match package targets.", + metric="target_metadata_row_count", + value=len(rows), + threshold=len(payload.target_names), + ) + targets_df = payload.targets_df.reset_index(drop=True) + for index, row in enumerate(rows): + if row.get("target_index") != index: + return _target_metadata_mismatch( + "target_index", + row.get("target_index"), + index, + ) + expected_name = str(payload.target_names[index]) + if row.get("target_name") != expected_name: + return _target_metadata_mismatch( + "target_name", + row.get("target_name"), + expected_name, + ) + expected_value = float(targets_df.loc[index, "value"]) + if not np.isclose(float(row.get("target_value")), expected_value): + return _target_metadata_mismatch( + "target_value", + row.get("target_value"), + expected_value, + ) + comparisons = { + "variable": str(targets_df.loc[index, "variable"]), + "geography_level": _optional_string_value( + targets_df.loc[index, "geo_level"] + ), + "geography_id": _optional_string_value( + targets_df.loc[index, "geographic_id"] + ), + "domain_variable": _optional_string_value( + targets_df.loc[index, "domain_variable"] + ), + } + for key, expected in comparisons.items(): + if row.get(key) != expected: + return _target_metadata_mismatch(key, row.get(key), expected) + + facets = _json_artifact(context, "calibration_target_facets") + if not isinstance(facets, Mapping): + return _finding( + _CHECK_TARGET_METADATA, + status="fail", + message="Target facets artifact must contain a JSON object.", + metric="target_facets_type", + value=type(facets).__name__, + ) + expected_facets = target_facets_from_rows(rows) + if dict(facets) != expected_facets: + return _finding( + _CHECK_TARGET_METADATA, + status="fail", + message="Target facets artifact does not match target metadata rows.", + metric="target_facets", + value=dict(facets), + threshold=expected_facets, + ) + return _finding( + _CHECK_TARGET_METADATA, + status="pass", + message="Target metadata rows and facets match package target order.", + metric="target_metadata_row_count", + value=len(rows), + ) + + +def _check_geography( + context: ValidationContext, + cache: dict[str, Any], +) -> ValidationFinding: + payload = _payload(context, cache) + summary = payload.geography_summary() + if summary.source_kind != "calibration_package" or summary.status != "completed": + return _finding( + _CHECK_GEOGRAPHY, + status="fail", + message="Calibration package does not include completed geography assignment arrays.", + metric="geography_status", + value={ + "source_kind": summary.source_kind, + "status": summary.status, + }, + ) + persisted = _json_artifact(context, "geography_assignment_summary") + if persisted is not None and persisted != summary.to_dict(): + return _finding( + _CHECK_GEOGRAPHY, + status="fail", + message="Geography assignment summary artifact does not match package.", + metric="canonical_geography_sha256", + value=persisted.get("canonical_geography_sha256"), + threshold=summary.canonical_geography_sha256, + ) + return _finding( + _CHECK_GEOGRAPHY, + status="pass", + message="Geography assignment arrays and summary are consistent.", + metric="canonical_geography_sha256", + value=summary.canonical_geography_sha256, + ) + + +def _check_initial_weights( + context: ValidationContext, + cache: dict[str, Any], +) -> ValidationFinding: + payload = _payload(context, cache) + if payload.initial_weights is None: + return _finding( + _CHECK_INITIAL_WEIGHTS, + status="fail", + message="Calibration package is missing initial_weights.", + metric="has_initial_weights", + value=False, + ) + weights = np.asarray(payload.initial_weights) + n_columns = int(payload.X_sparse.shape[1]) + if len(weights) != n_columns: + return _finding( + _CHECK_INITIAL_WEIGHTS, + status="fail", + message="Initial weights length does not match matrix columns.", + metric="initial_weights_length", + value=len(weights), + threshold=n_columns, + ) + if not bool(np.isfinite(weights).all()): + return _finding( + _CHECK_INITIAL_WEIGHTS, + status="fail", + message="Initial weights must be finite.", + metric="initial_weights_finite", + value=False, + ) + if bool((weights < 0).any()): + return _finding( + _CHECK_INITIAL_WEIGHTS, + status="fail", + message="Initial weights must be non-negative.", + metric="initial_weights_non_negative", + value=False, + ) + return _finding( + _CHECK_INITIAL_WEIGHTS, + status="pass", + message="Initial weights are finite, non-negative, and column-aligned.", + metric="initial_weights_length", + value=len(weights), + ) + + +def _check_chunk_manifest( + context: ValidationContext, + cache: dict[str, Any], +) -> ValidationFinding: + matrix_summary = _matrix_summary(context, cache) + if matrix_summary is None: + return _finding( + _CHECK_CHUNK_MANIFEST, + status="pass", + message="No matrix summary artifact declared a chunk manifest.", + metric="chunk_manifest_declared", + value=False, + ) + if matrix_summary.matrix_builder != "chunked": + return _finding( + _CHECK_CHUNK_MANIFEST, + status="pass", + message="Non-chunked matrix build does not require a chunk manifest.", + metric="matrix_builder", + value=matrix_summary.matrix_builder, + ) + manifest_path = matrix_summary.chunk_manifest_path + manifest_sha = matrix_summary.chunk_manifest_sha256 + if manifest_path is None and manifest_sha is None: + return _finding( + _CHECK_CHUNK_MANIFEST, + status="pass", + message="Chunked matrix summary does not declare a persisted chunk manifest.", + metric="chunk_manifest_declared", + value=False, + ) + if not manifest_path or not manifest_sha: + return _finding( + _CHECK_CHUNK_MANIFEST, + status="fail", + message="Chunk manifest path and checksum must be declared together.", + metric="chunk_manifest_identity", + value={"path": manifest_path, "sha256": manifest_sha}, + ) + path = Path(manifest_path) + if not path.exists(): + return _finding( + _CHECK_CHUNK_MANIFEST, + status="fail", + message=f"Chunk manifest does not exist: {path}", + metric="chunk_manifest_path", + value=str(path), + ) + actual_sha = compute_file_checksum(path) + if manifest_sha not in {actual_sha, f"sha256:{actual_sha}"}: + return _finding( + _CHECK_CHUNK_MANIFEST, + status="fail", + message="Chunk manifest checksum does not match matrix summary.", + metric="chunk_manifest_sha256", + value=manifest_sha, + threshold=actual_sha, + ) + ChunkCacheManifest.read(path) + return _finding( + _CHECK_CHUNK_MANIFEST, + status="pass", + message="Chunk manifest exists, parses, and matches its checksum.", + metric="chunk_manifest_sha256", + value=manifest_sha, + ) + + +def _payload( + context: ValidationContext, + cache: dict[str, Any], +) -> CalibrationPackagePayload: + if "payload" not in cache: + cache["payload"] = CalibrationPackageReader( + package_path=_path(context, "calibration_package") + ).read() + return cache["payload"] + + +def _matrix_summary( + context: ValidationContext, + cache: dict[str, Any], +) -> MatrixBuildSummary | None: + if "matrix_summary" in cache: + return cache["matrix_summary"] + path = _optional_path(context, "matrix_summary") + if path is None: + cache["matrix_summary"] = None + return None + cache["matrix_summary"] = MatrixBuildSummary.from_dict( + json.loads(path.read_text(encoding="utf-8")) + ) + return cache["matrix_summary"] + + +def _json_artifact( + context: ValidationContext, + logical_name: str, +) -> Mapping[str, Any] | None: + path = _optional_path(context, logical_name) + if path is None: + return None + return json.loads(path.read_text(encoding="utf-8")) + + +def _target_metadata_rows(context: ValidationContext) -> list[dict[str, Any]]: + path = _path(context, "calibration_targets") + rows: list[dict[str, Any]] = [] + with path.open(encoding="utf-8") as handle: + for line in handle: + if line.strip(): + row = json.loads(line) + if not isinstance(row, Mapping): + raise ValueError("Target metadata JSONL rows must be objects") + rows.append(dict(row)) + return rows + + +def _target_metadata_mismatch( + key: str, + actual: Any, + expected: Any, +) -> ValidationFinding: + return _finding( + _CHECK_TARGET_METADATA, + status="fail", + message=f"Target metadata artifact does not match package for {key}.", + metric=key, + value=actual, + threshold=expected, + ) + + +def _finding( + check_id: str, + *, + status: str, + message: str, + metric: str | None = None, + value: Any | None = None, + threshold: Any | None = None, + metadata: Mapping[str, Any] | None = None, +) -> ValidationFinding: + return ValidationFinding( + check_id=check_id, + status=status, + message=message, + metric=metric, + value=value, + threshold=threshold, + metadata=dict(metadata or {}), + ) + + +def _artifact_refs(paths: Mapping[str, Path]) -> dict[str, ArtifactRef]: + refs: dict[str, ArtifactRef] = {} + for logical_name, path in paths.items(): + if path.exists() and path.is_file(): + refs[logical_name] = ArtifactRef( + logical_name=logical_name, + uri=path.resolve().as_uri(), + sha256=f"sha256:{sha256_file(path)}", + size_bytes=path.stat().st_size, + media_type=_media_type_for_path(path), + metadata={ + "stage_id": STAGE_2_BUILD_CALIBRATION_PACKAGE, + "substage_id": CALIBRATION_PACKAGE_SUBSTAGE_ID, + }, + ) + return refs + + +def _path(context: ValidationContext, logical_name: str) -> Path: + return _artifact_uri_to_path(context.resolver.require(logical_name).uri) + + +def _optional_path(context: ValidationContext, logical_name: str) -> Path | None: + artifact = context.resolver.optional(logical_name) + if artifact is None: + return None + return _artifact_uri_to_path(artifact.uri) + + +def _artifact_uri_to_path(uri: str) -> Path: + parsed = urlparse(uri) + if parsed.scheme == "file": + return Path(unquote(parsed.path)) + if not parsed.scheme: + return Path(uri) + raise ValueError(f"Unsupported artifact URI scheme: {uri}") + + +def _resolve_existing_path(path: str) -> Path | None: + candidate = Path(path) + candidates = [candidate] if candidate.is_absolute() else [Path.cwd() / candidate] + repo_candidate = Path(__file__).resolve().parents[2] / candidate + if repo_candidate not in candidates: + candidates.append(repo_candidate) + for item in candidates: + if item.exists() and item.is_file(): + return item + return None + + +def _optional_string_value(value: Any) -> str | None: + if value is None: + return None + if hasattr(np, "isnan"): + try: + if bool(np.isnan(value)): + return None + except TypeError: + pass + return str(value) + + +def _media_type_for_path(path: Path) -> str: + suffix = path.suffix.lower() + if suffix == ".h5": + return "application/x-hdf5" + if suffix == ".db": + return "application/vnd.sqlite3" + if suffix == ".json": + return "application/json" + if suffix == ".jsonl": + return "application/x-ndjson" + if suffix == ".pkl": + return "application/python-pickle" + return "application/octet-stream" diff --git a/tests/unit/calibration/test_unified_calibration_build_only.py b/tests/unit/calibration/test_unified_calibration_build_only.py new file mode 100644 index 000000000..0f919d305 --- /dev/null +++ b/tests/unit/calibration/test_unified_calibration_build_only.py @@ -0,0 +1,50 @@ +"""Focused source guards for build-only unified calibration package behavior.""" + +import ast +from pathlib import Path + + +def _call_name(node: ast.AST) -> str | None: + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + owner = _call_name(node.value) + return f"{owner}.{node.attr}" if owner else node.attr + return None + + +def test_build_only_package_output_returns_before_weight_fitting(): + source = Path("policyengine_us_data/calibration/unified_calibration.py").read_text( + encoding="utf-8" + ) + module = ast.parse(source) + run_calibration = next( + node + for node in module.body + if isinstance(node, ast.FunctionDef) and node.name == "run_calibration" + ) + build_only_block = next( + node + for node in run_calibration.body + if isinstance(node, ast.If) + and isinstance(node.test, ast.Name) + and node.test.id == "build_only" + ) + + build_only_calls = { + _call_name(call.func) + for call in ast.walk(build_only_block) + if isinstance(call, ast.Call) + } + fit_calls_after_build_only = [ + call + for statement in run_calibration.body[ + run_calibration.body.index(build_only_block) + 1 : + ] + for call in ast.walk(statement) + if isinstance(call, ast.Call) and _call_name(call.func) == "fit_l0_weights" + ] + + assert "validator.raise_for_failure" in build_only_calls + assert any(isinstance(node, ast.Return) for node in build_only_block.body) + assert fit_calls_after_build_only, "source guard should cover the L0 fit path" diff --git a/tests/unit/calibration_package/test_specs.py b/tests/unit/calibration_package/test_specs.py index b20238a66..dcd32b807 100644 --- a/tests/unit/calibration_package/test_specs.py +++ b/tests/unit/calibration_package/test_specs.py @@ -16,6 +16,9 @@ MATRIX_BUILD_DIRNAME, MATRIX_SUMMARY_FILENAME, SOURCE_DATASET_FILENAME, + STAGE2_VALIDATION_FINDINGS_FILENAME, + STAGE2_VALIDATION_REPORT_FILENAME, + STAGE2_VALIDATION_SUMMARY_FILENAME, TARGET_DATABASE_FILENAME, TargetConfigIdentity, calibration_package_artifact_paths, @@ -115,6 +118,21 @@ def test_calibration_package_artifact_paths(): assert paths.reports_dir == Path("/pipeline/artifacts/run-a") / ( CALIBRATION_REPORTS_DIRNAME ) + assert paths.validation_report == ( + Path("/pipeline/artifacts/run-a") + / CALIBRATION_REPORTS_DIRNAME + / STAGE2_VALIDATION_REPORT_FILENAME + ) + assert paths.validation_findings == ( + Path("/pipeline/artifacts/run-a") + / CALIBRATION_REPORTS_DIRNAME + / STAGE2_VALIDATION_FINDINGS_FILENAME + ) + assert paths.validation_summary == ( + Path("/pipeline/artifacts/run-a") + / CALIBRATION_REPORTS_DIRNAME + / STAGE2_VALIDATION_SUMMARY_FILENAME + ) assert paths.matrix_build_dir == Path("/pipeline/artifacts/run-a") / ( MATRIX_BUILD_DIRNAME ) @@ -125,6 +143,9 @@ def test_calibration_package_artifact_paths(): paths.target_facets, paths.geography_summary, paths.matrix_summary, + paths.validation_report, + paths.validation_findings, + paths.validation_summary, ) diff --git a/tests/unit/calibration_package/test_validation.py b/tests/unit/calibration_package/test_validation.py new file mode 100644 index 000000000..5476c1619 --- /dev/null +++ b/tests/unit/calibration_package/test_validation.py @@ -0,0 +1,352 @@ +import json +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import pytest + +from tests.unit.fixtures.calibration_package_stage_contract import ( + CALIBRATION_COMPLETED_AT, + CALIBRATION_RUN_ID, + CALIBRATION_STARTED_AT, + calibration_package_parameters, + calibration_package_payload, + calibration_package_payload_without_geography, + contract_input_paths, + write_calibration_package_payload, + write_non_mapping_calibration_package_payload, +) + +from policyengine_us_data.calibration_package.matrix import ( + ChunkCacheManifest, + MatrixBuildResult, + MatrixBuildSpec, +) +from policyengine_us_data.calibration_package.specs import ( + CALIBRATION_REPORTS_DIRNAME, + CALIBRATION_TARGET_FACETS_FILENAME, + CALIBRATION_TARGETS_FILENAME, + GEOGRAPHY_ASSIGNMENT_SUMMARY_FILENAME, + MATRIX_SUMMARY_FILENAME, + STAGE2_VALIDATION_FINDINGS_FILENAME, + STAGE2_VALIDATION_REPORT_FILENAME, + STAGE2_VALIDATION_SUMMARY_FILENAME, +) +from policyengine_us_data.calibration_package.validation import ( + CalibrationPackageValidationError, + CalibrationPackageValidator, +) +from policyengine_us_data.calibration_package.targets import target_facets_from_rows +from policyengine_us_data.stage_contracts import ValidationReport +from policyengine_us_data.stage_contracts.calibration_package import ( + summarize_geography_assignment, + write_calibration_package_contract, +) +from policyengine_us_data.stage_contracts.calibration_package_schema import ( + MatrixBuildSummary, +) +from policyengine_us_data.stage_contracts.io import read_contract +from policyengine_us_data.utils.manifest import compute_file_checksum + + +def _matrix_build_summary_for_package(package: dict) -> MatrixBuildSummary: + metadata = package["metadata"] + return MatrixBuildResult.from_builder_output( + spec=MatrixBuildSpec( + matrix_builder=metadata["matrix_builder"], + base_n_records=metadata["base_n_records"], + n_clones=metadata["n_clones"], + chunk_size=metadata["chunk_size"], + chunk_dir=metadata["chunk_dir"], + ), + targets_df=package["targets_df"], + X_sparse=package["X_sparse"], + target_names=package["target_names"], + ).summary() + + +def _write_artifacts( + tmp_path: Path, + *, + package: dict | None = None, + matrix_summary_updates: dict | None = None, +) -> SimpleNamespace: + dataset_path, db_path, package_path = contract_input_paths(tmp_path) + package = package or calibration_package_payload() + target_config = tmp_path / "target_config.yaml" + target_config.write_text("include: []\n", encoding="utf-8") + package["metadata"]["target_config_path"] = str(target_config) + package["metadata"]["target_config_sha256"] = compute_file_checksum(target_config) + package["metadata"]["target_config_mode"] = "explicit" + write_calibration_package_payload(package_path, package) + + target_rows = _target_rows_for_package(package) + targets_path = tmp_path / CALIBRATION_TARGETS_FILENAME + targets_path.write_text( + "".join(json.dumps(row, sort_keys=True) + "\n" for row in target_rows), + encoding="utf-8", + ) + target_facets_path = tmp_path / CALIBRATION_TARGET_FACETS_FILENAME + target_facets_path.write_text( + json.dumps(target_facets_from_rows(target_rows), sort_keys=True) + "\n", + encoding="utf-8", + ) + geography_summary = summarize_geography_assignment(package) + geography_summary_path = tmp_path / GEOGRAPHY_ASSIGNMENT_SUMMARY_FILENAME + geography_summary_path.write_text( + json.dumps(geography_summary.to_dict(), sort_keys=True) + "\n", + encoding="utf-8", + ) + matrix_summary = _matrix_build_summary_for_package(package).to_dict() + if matrix_summary_updates: + matrix_summary.update(matrix_summary_updates) + matrix_summary_schema = MatrixBuildSummary.from_dict(matrix_summary) + matrix_summary_path = tmp_path / MATRIX_SUMMARY_FILENAME + matrix_summary_path.write_text( + json.dumps(matrix_summary_schema.to_dict(), sort_keys=True) + "\n", + encoding="utf-8", + ) + parameters = calibration_package_parameters() + parameters["target_config"] = str(target_config) + parameters["target_config_sha256"] = compute_file_checksum(target_config) + write_calibration_package_contract( + package_path=package_path, + dataset_path=dataset_path, + db_path=db_path, + package=package, + parameters=parameters, + run_id=CALIBRATION_RUN_ID, + started_at=CALIBRATION_STARTED_AT, + completed_at=CALIBRATION_COMPLETED_AT, + target_metadata_path=targets_path, + target_facets_path=target_facets_path, + target_selection_summary={"target_count": len(package["target_names"])}, + geography_summary_path=geography_summary_path, + geography_assignment_summary=geography_summary, + matrix_summary_path=matrix_summary_path, + matrix_build_summary=matrix_summary_schema, + ) + return SimpleNamespace( + package_path=package_path, + contract_path=tmp_path / "calibration_package_contract.json", + dataset_path=dataset_path, + db_path=db_path, + targets_path=targets_path, + target_facets_path=target_facets_path, + geography_summary_path=geography_summary_path, + matrix_summary_path=matrix_summary_path, + reports_dir=tmp_path / CALIBRATION_REPORTS_DIRNAME, + ) + + +def _target_rows_for_package(package: dict) -> list[dict]: + rows = [] + targets_df = package["targets_df"].reset_index(drop=True) + for target_index, row in targets_df.iterrows(): + rows.append( + { + "constraint_key": "none", + "domain_variable": row.get("domain_variable"), + "geography_id": row.get("geographic_id"), + "geography_level": row.get("geo_level"), + "included_in_package": True, + "period": None, + "source_table": "targets", + "target_components": [row["variable"]], + "target_config_mode": package["metadata"]["target_config_mode"], + "target_config_path": package["metadata"]["target_config_path"], + "target_config_sha256": package["metadata"]["target_config_sha256"], + "target_constraints": [], + "target_expression": None, + "target_id": target_index, + "target_index": target_index, + "target_name": str(package["target_names"][target_index]), + "target_value": float(row["value"]), + "variable": row["variable"], + } + ) + return rows + + +def _validate(paths: SimpleNamespace) -> ValidationReport: + return CalibrationPackageValidator().validate_and_write( + package_path=paths.package_path, + contract_path=paths.contract_path, + dataset_path=paths.dataset_path, + db_path=paths.db_path, + reports_dir=paths.reports_dir, + targets_path=paths.targets_path, + target_facets_path=paths.target_facets_path, + geography_summary_path=paths.geography_summary_path, + matrix_summary_path=paths.matrix_summary_path, + run_id=CALIBRATION_RUN_ID, + ) + + +def _failing_ids(report: ValidationReport) -> set[str]: + return {finding.check_id for finding in report.findings if finding.status == "fail"} + + +def test_validator_writes_canonical_report_and_attaches_contract(tmp_path): + paths = _write_artifacts(tmp_path) + + report = _validate(paths) + + assert report.status == "pass" + assert (paths.reports_dir / STAGE2_VALIDATION_REPORT_FILENAME).exists() + assert (paths.reports_dir / STAGE2_VALIDATION_FINDINGS_FILENAME).exists() + assert (paths.reports_dir / STAGE2_VALIDATION_SUMMARY_FILENAME).exists() + restored = ValidationReport.from_dict( + json.loads( + (paths.reports_dir / STAGE2_VALIDATION_REPORT_FILENAME).read_text( + encoding="utf-8" + ) + ) + ) + assert restored.status == "pass" + contract = read_contract(paths.contract_path) + assert contract.validation == report + assert contract.substages[0].validation == report + assert contract.metadata["validation_artifacts"]["report"].endswith( + STAGE2_VALIDATION_REPORT_FILENAME + ) + + +def test_validator_reports_target_config_checksum_failure(tmp_path): + paths = _write_artifacts(tmp_path) + package = calibration_package_payload() + package["metadata"]["target_config_path"] = str(tmp_path / "target_config.yaml") + package["metadata"]["target_config_sha256"] = "sha256:" + "0" * 64 + write_calibration_package_payload(paths.package_path, package) + + report = _validate(paths) + + assert report.status == "fail" + assert "stage2.target_config.identity" in _failing_ids(report) + + +def test_validator_reports_unloadable_package_payload(tmp_path): + paths = _write_artifacts(tmp_path) + write_non_mapping_calibration_package_payload(paths.package_path) + + report = _validate(paths) + + assert "stage2.package.loadable" in _failing_ids(report) + + +def test_validator_reports_matrix_summary_mismatch(tmp_path): + paths = _write_artifacts(tmp_path) + summary = json.loads(paths.matrix_summary_path.read_text(encoding="utf-8")) + summary["matrix_nnz"] = 1 + paths.matrix_summary_path.write_text( + json.dumps(summary, sort_keys=True) + "\n", + encoding="utf-8", + ) + + report = _validate(paths) + + assert "stage2.matrix.consistency" in _failing_ids(report) + + +def test_validator_reports_missing_matrix_summary_artifact(tmp_path): + paths = _write_artifacts(tmp_path) + paths.matrix_summary_path.unlink() + + report = _validate(paths) + + assert "stage2.matrix.consistency" in _failing_ids(report) + + +def test_validator_reports_target_frame_missing_required_column(tmp_path): + package = calibration_package_payload() + package["targets_df"] = package["targets_df"].drop(columns=["geographic_id"]) + paths = _write_artifacts(tmp_path, package=package) + + report = _validate(paths) + + assert "stage2.target_frame.consistency" in _failing_ids(report) + + +def test_validator_reports_missing_target_metadata_artifact(tmp_path): + paths = _write_artifacts(tmp_path) + paths.targets_path.unlink() + + report = _validate(paths) + + assert "stage2.target_metadata.consistency" in _failing_ids(report) + + +def test_validator_reports_target_facets_mismatch(tmp_path): + paths = _write_artifacts(tmp_path) + facets = json.loads(paths.target_facets_path.read_text(encoding="utf-8")) + facets["target_count"] = 999 + paths.target_facets_path.write_text( + json.dumps(facets, sort_keys=True) + "\n", + encoding="utf-8", + ) + + report = _validate(paths) + + assert "stage2.target_metadata.consistency" in _failing_ids(report) + + +def test_validator_reports_missing_geography_summary_artifact(tmp_path): + paths = _write_artifacts(tmp_path) + paths.geography_summary_path.unlink() + + report = _validate(paths) + + assert "stage2.geography.consistency" in _failing_ids(report) + + +def test_validator_reports_missing_geography_arrays(tmp_path): + paths = _write_artifacts( + tmp_path, + package=calibration_package_payload_without_geography(), + ) + + report = _validate(paths) + + assert "stage2.geography.consistency" in _failing_ids(report) + + +def test_validator_reports_initial_weights_length_mismatch(tmp_path): + package = calibration_package_payload() + package["initial_weights"] = np.array([1.0, 1.0]) + paths = _write_artifacts(tmp_path, package=package) + + report = _validate(paths) + + assert "stage2.initial_weights.consistency" in _failing_ids(report) + + +def test_validator_reports_chunk_manifest_checksum_mismatch(tmp_path): + manifest_path = tmp_path / "matrix_build" / "chunk_manifest.json" + ChunkCacheManifest.from_signature({"run_id": CALIBRATION_RUN_ID}).write( + manifest_path + ) + paths = _write_artifacts( + tmp_path, + matrix_summary_updates={ + "chunk_manifest_path": str(manifest_path), + "chunk_manifest_sha256": "sha256:" + "0" * 64, + }, + ) + + report = _validate(paths) + + assert "stage2.chunk_manifest.consistency" in _failing_ids(report) + + +def test_validation_failure_error_includes_finding_ids(tmp_path): + package = calibration_package_payload() + package["initial_weights"] = np.array([1.0, 1.0]) + paths = _write_artifacts(tmp_path, package=package) + validator = CalibrationPackageValidator() + report = _validate(paths) + + with pytest.raises( + CalibrationPackageValidationError, + match="stage2.initial_weights.consistency", + ): + validator.raise_for_failure(report) diff --git a/tests/unit/test_remote_calibration_runner.py b/tests/unit/test_remote_calibration_runner.py index db295ea6a..f0fc9b0dd 100644 --- a/tests/unit/test_remote_calibration_runner.py +++ b/tests/unit/test_remote_calibration_runner.py @@ -5,6 +5,8 @@ from types import ModuleType, SimpleNamespace from unittest.mock import Mock +import pytest + def _load_remote_calibration_runner_module(): fake_modal = ModuleType("modal") @@ -187,6 +189,22 @@ def fake_validate_persisted_contract(**kwargs): "validate_persisted_calibration_package_contract", fake_validate_persisted_contract, ) + from policyengine_us_data.calibration_package import validation + from policyengine_us_data.stage_contracts import ValidationReport + + class FakePackageValidator: + def validate_and_write(self, **kwargs): + captured["package_validation"] = kwargs + return ValidationReport(status="pass") + + def raise_for_failure(self, report): + captured["validation_status"] = report.status + + monkeypatch.setattr( + validation, + "CalibrationPackageValidator", + lambda: FakePackageValidator(), + ) def fake_run_streaming(cmd, env=None, label=""): captured["cmd"] = cmd @@ -241,11 +259,57 @@ def fake_run_streaming(cmd, env=None, label=""): assert captured["contract_validation"]["db_path"] == ( artifacts_dir / "policy_data.db" ) + assert captured["package_validation"]["package_path"] == ( + artifacts_dir / "calibration_package.pkl" + ) + assert captured["package_validation"]["reports_dir"] == ( + artifacts_dir / "calibration_reports" + ) + assert captured["validation_status"] == "pass" ensure_prereqs.assert_called_once() volume.reload.assert_called_once() volume.commit.assert_called_once() +def test_build_package_impl_commits_validation_report_before_failing( + monkeypatch, + tmp_path, +): + remote_runner = _load_remote_calibration_runner_module() + artifacts_dir = tmp_path / "artifacts" / "bench-run" + artifacts_dir.mkdir(parents=True) + (artifacts_dir / "policy_data.db").write_bytes(b"db") + (artifacts_dir / "source_imputed_stratified_extended_cps.h5").write_bytes(b"h5") + + volume = SimpleNamespace(reload=Mock(), commit=Mock()) + monkeypatch.setattr(remote_runner, "PIPELINE_MOUNT", str(tmp_path)) + monkeypatch.setattr(remote_runner, "pipeline_vol", volume) + monkeypatch.setattr(remote_runner, "_setup_repo", lambda: None) + monkeypatch.setattr(remote_runner, "_ensure_geography_prerequisites", lambda: None) + + def fake_run_streaming(cmd, env=None, label=""): + report_dir = artifacts_dir / "calibration_reports" + report_dir.mkdir() + (report_dir / "validation_report.json").write_text( + '{"status":"fail"}\n', + encoding="utf-8", + ) + return 2, [] + + monkeypatch.setattr(remote_runner, "_run_streaming", fake_run_streaming) + + with pytest.raises(RuntimeError, match="Package build failed"): + remote_runner._build_package_impl( + branch="main", + workers=1, + n_clones=10, + run_id="bench-run", + ) + + volume.reload.assert_called_once() + volume.commit.assert_called_once() + + def test_write_package_sidecar_reads_payload_and_contract(tmp_path): remote_runner = _load_remote_calibration_runner_module() from tests.unit.fixtures.calibration_package_stage_contract import (