From 2b97bdb1a594bc82d28074d9eb0e7faa8be13ee1 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 20 May 2026 16:45:48 +0200 Subject: [PATCH] Add Stage 2 package payload reader --- changelog.d/1073.changed | 1 + docs/engineering/pipeline-map.md | 16 +- docs/generated/pipeline_api.json | 106 +++- docs/generated/pipeline_map.json | 110 +++- docs/pipeline_map.yaml | 33 +- modal_app/remote_calibration_runner.py | 27 +- .../calibration/unified_calibration.py | 60 +-- .../calibration_package/__init__.py | 16 + .../calibration_package/payload.py | 480 ++++++++++++++++++ .../stage_contracts/calibration_package.py | 210 ++------ .../unit/calibration_package/test_payload.py | 124 +++++ ...test_calibration_package_stage_contract.py | 6 +- tests/unit/test_pipeline_docs_extractor.py | 3 + tests/unit/test_remote_calibration_runner.py | 30 ++ 14 files changed, 983 insertions(+), 239 deletions(-) create mode 100644 changelog.d/1073.changed create mode 100644 policyengine_us_data/calibration_package/payload.py create mode 100644 tests/unit/calibration_package/test_payload.py diff --git a/changelog.d/1073.changed b/changelog.d/1073.changed new file mode 100644 index 000000000..2e3a100ab --- /dev/null +++ b/changelog.d/1073.changed @@ -0,0 +1 @@ +Add typed Stage 2 calibration package payload reader and writer helpers. diff --git a/docs/engineering/pipeline-map.md b/docs/engineering/pipeline-map.md index b1f630718..5332d7481 100644 --- a/docs/engineering/pipeline-map.md +++ b/docs/engineering/pipeline-map.md @@ -379,6 +379,7 @@ Build sparse calibration matrix (targets x households x clones) | `takeup_rerand` Block-Level Takeup Re-randomization | `process` | `unknown` | `unknown` | | | `sparse_build` Sparse Matrix Construction | `process` | `unknown` | `unknown` | | | `out_pkg` calibration_package.pkl | `artifact` | `unknown` | `unknown` | | +| `out_metadata` calibration_package_meta.json | `artifact` | `unknown` | `unknown` | | | `out_contract` calibration_package_contract.json | `artifact` | `unknown` | `unknown` | | | `util_sql` sqlalchemy | `utility` | `unknown` | `unknown` | | | `util_pool` ProcessPoolExecutor | `utility` | `unknown` | `unknown` | | @@ -395,6 +396,9 @@ Build sparse calibration matrix (targets x households x clones) | `clone_assembly` Clone Value Assembly | `library` | `current` | `moving` | `policyengine_us_data.calibration.unified_matrix_builder._assemble_clone_values_standalone` | | `build_matrix` Build Calibration Matrix | `library` | `current` | `moving` | `policyengine_us_data.calibration.unified_matrix_builder.UnifiedMatrixBuilder.build_matrix` | | `build_matrix_chunked` Build Calibration Matrix In Chunks | `library` | `current` | `experimental` | `policyengine_us_data.calibration.unified_matrix_builder.UnifiedMatrixBuilder.build_matrix_chunked` | +| `stage2_payload_boundary` Stage 2 Package Payload | `library` | `current` | `moving` | `policyengine_us_data.calibration_package.payload.CalibrationPackagePayload` | +| `stage2_payload_writer` Stage 2 Payload Writer | `library` | `current` | `moving` | `policyengine_us_data.calibration_package.payload.CalibrationPackageWriter` | +| `stage2_payload_reader` Stage 2 Payload Reader | `library` | `current` | `moving` | `policyengine_us_data.calibration_package.payload.CalibrationPackageReader` | | `stage2_calibration_package_contract_writer` Stage 2 Contract Writer | `library` | `current` | `moving` | `policyengine_us_data.stage_contracts.calibration_package.write_calibration_package_contract` | | `stage2_calibration_package_contract_validator` Stage 2 Contract Validator | `validation` | `current` | `moving` | `policyengine_us_data.stage_contracts.calibration_package.validate_calibration_package_contract` | @@ -423,13 +427,19 @@ Build sparse calibration matrix (targets x households x clones) - `takeup_rerand` -> `sparse_build` `data_flow` - `sparse_build` -> `build_matrix` `uses_library` (non-chunked path) - `sparse_build` -> `build_matrix_chunked` `uses_library` (chunked path) -- `build_matrix` -> `stage2_calibration_package_writer` `data_flow` -- `build_matrix_chunked` -> `stage2_calibration_package_writer` `data_flow` +- `build_matrix` -> `stage2_payload_boundary` `data_flow` +- `build_matrix_chunked` -> `stage2_payload_boundary` `data_flow` +- `stage2_payload_boundary` -> `stage2_calibration_package_writer` `data_flow` (typed package payload) - `stage2_artifact_specs` -> `stage2_calibration_package_writer` `uses_utility` (package path) -- `stage2_calibration_package_writer` -> `out_pkg` `produces_artifact` +- `stage2_calibration_package_writer` -> `stage2_payload_writer` `uses_library` (pickle write) +- `stage2_payload_writer` -> `out_pkg` `produces_artifact` +- `out_pkg` -> `stage2_payload_reader` `data_flow` - `out_pkg` -> `stage2_calibration_package_contract_writer` `data_flow` +- `stage2_payload_reader` -> `stage2_calibration_package_contract_writer` `uses_library` (summary and checksum) - `stage2_artifact_specs` -> `stage2_calibration_package_contract_writer` `uses_utility` (contract path) - `stage2_calibration_package_contract_writer` -> `out_contract` `produces_artifact` +- `out_contract` -> `stage2_payload_writer` `data_flow` (sidecar contract material) +- `stage2_payload_writer` -> `out_metadata` `produces_artifact` (sidecar metadata) - `out_pkg` -> `stage2_calibration_package_contract_validator` `validates` - `out_contract` -> `stage2_calibration_package_contract_validator` `validates` - `in_cps_s5` -> `stage2_calibration_package_contract_validator` `validates` diff --git a/docs/generated/pipeline_api.json b/docs/generated/pipeline_api.json index 832ffa407..edbc8ba57 100644 --- a/docs/generated/pipeline_api.json +++ b/docs/generated/pipeline_api.json @@ -727,7 +727,7 @@ "docstring": "", "id": "calibration_diagnostics", "kind": "function", - "line": 1249, + "line": 1245, "metadata": { "api_refs": [ "policyengine_us_data.calibration.unified_calibration.compute_diagnostics" @@ -1091,7 +1091,7 @@ "docstring": "Fit L0-regularized calibration weights.\n\nArgs:\n X_sparse: Sparse matrix (targets x records).\n targets: Target values array.\n lambda_l0: L0 regularization strength.\n epochs: Training epochs.\n device: Torch device.\n verbose_freq: Print frequency. Defaults to 10%.\n beta: L0 gate temperature.\n lambda_l2: L2 regularization strength.\n learning_rate: Optimizer learning rate.\n log_freq: Epochs between per-target CSV logs.\n None disables logging.\n log_path: Path for the per-target calibration log CSV.\n target_names: Human-readable target names for the log.\n initial_weights: Pre-computed initial weights. If None,\n computed from targets_df age targets.\n targets_df: Targets DataFrame, used to compute\n initial_weights when not provided.\n target_groups: Optional group ID per target row for balanced loss.\n resume_from: Path to a `.checkpoint.pt` file or `.npy`\n weights file to continue fitting from.\n checkpoint_path: Where to save resumable fit checkpoints.\n\nReturns:\n Weight array of shape (n_records,).", "id": "fit_model", "kind": "function", - "line": 893, + "line": 889, "metadata": { "api_refs": [ "policyengine_us_data.calibration.unified_calibration.fit_l0_weights" @@ -1410,7 +1410,7 @@ "docstring": "Compute population-based initial weights from age targets.\n\nFor each congressional district, sums person_count targets where\ndomain_variable == \"age\" to get district population, then divides\nby the number of columns (households) active in that district.\n\nArgs:\n X_sparse: Sparse matrix (targets x records).\n targets_df: Targets DataFrame with columns: variable,\n domain_variable, geo_level, geographic_id, value.\n\nReturns:\n Weight array of shape (n_records,).", "id": "init_weights", "kind": "function", - "line": 814, + "line": 810, "metadata": { "api_refs": [ "policyengine_us_data.calibration.unified_calibration.compute_initial_weights" @@ -3472,7 +3472,7 @@ "docstring": "Run unified calibration pipeline.\n\nArgs:\n dataset_path: Path to CPS h5 file.\n db_path: Path to policy_data.db.\n n_clones: Number of dataset clones.\n lambda_l0: L0 regularization strength.\n epochs: Training epochs.\n device: Torch device.\n seed: Random seed.\n domain_variables: Filter targets by domain variable.\n hierarchical_domains: Domains for hierarchical\n uprating + CD reconciliation.\n skip_takeup_rerandomize: Skip takeup step.\n skip_source_impute: Skip ACS/SIPP/SCF imputations.\n target_config: Parsed target config dict.\n target_config_path: Path to target config, for provenance.\n target_config_identity: Resolved target config path/checksum identity.\n build_only: If True, save package and skip fitting.\n package_path: Load pre-built package (skip build).\n package_output_path: Where to save calibration package.\n beta: L0 gate temperature.\n lambda_l2: L2 regularization strength.\n learning_rate: Optimizer learning rate.\n log_freq: Epochs between per-target CSV logs.\n log_path: Path for per-target calibration log CSV.\n resume_from: Path to a checkpoint or weights file to\n continue fitting from.\n checkpoint_path: Where to save resumable fit checkpoints.\n chunked_matrix: Build matrix in clone-household chunks.\n chunk_size: Clone-household columns per chunk.\n chunk_dir: Directory for chunked COO/H5 artifacts.\n keep_chunks: Keep temporary chunk H5 files.\n resume_chunks: Reuse existing chunk COO files.\n\nReturns:\n (weights, targets_df, X_sparse, target_names, geography_info)\n weights is None when build_only=True.\n geography_info is a dict with cd_geoid and base_n_records.", "id": "run_calibration", "kind": "function", - "line": 1375, + "line": 1371, "metadata": { "api_refs": [ "policyengine_us_data.calibration.unified_calibration.run_calibration" @@ -3801,7 +3801,7 @@ "docstring": "Validate that a Stage 2 sidecar describes the calibration package.", "id": "stage2_calibration_package_contract_validator", "kind": "function", - "line": 379, + "line": 252, "metadata": { "api_refs": [ "policyengine_us_data.stage_contracts.calibration_package.validate_calibration_package_contract" @@ -3822,14 +3822,14 @@ ] }, "object_path": "policyengine_us_data.stage_contracts.calibration_package.validate_calibration_package_contract", - "signature": "def validate_calibration_package_contract(*, package_path: Path, contract_path: Path | None = None, package: Mapping[str, Any] | None = None, dataset_path: Path | None = None, db_path: Path | None = None) -> StageContract", + "signature": "def validate_calibration_package_contract(*, package_path: Path, contract_path: Path | None = None, package: CalibrationPackagePayload | Mapping[str, Any] | None = None, dataset_path: Path | None = None, db_path: Path | None = None) -> StageContract", "source_file": "policyengine_us_data/stage_contracts/calibration_package.py" }, "stage2_calibration_package_contract_writer": { "docstring": "Write and return the Stage 2 calibration-package contract.", "id": "stage2_calibration_package_contract_writer", "kind": "function", - "line": 322, + "line": 195, "metadata": { "api_refs": [ "policyengine_us_data.stage_contracts.calibration_package.write_calibration_package_contract" @@ -3853,14 +3853,14 @@ ] }, "object_path": "policyengine_us_data.stage_contracts.calibration_package.write_calibration_package_contract", - "signature": "def write_calibration_package_contract(*, package_path: Path, dataset_path: Path, db_path: Path, package: Mapping[str, Any], parameters: CalibrationPackageParameters | Mapping[str, Any], run_id: str | None, completed_at: str, started_at: str | None = None, duration_s: float | None = None, code_sha: str | None = None, package_version: str | None = None, contract_path: Path | None = None) -> StageContract", + "signature": "def write_calibration_package_contract(*, package_path: Path, dataset_path: Path, db_path: Path, package: CalibrationPackagePayload | Mapping[str, Any], parameters: CalibrationPackageParameters | Mapping[str, Any], run_id: str | None, completed_at: str, started_at: str | None = None, duration_s: float | None = None, code_sha: str | None = None, package_version: str | None = None, contract_path: Path | None = None) -> StageContract", "source_file": "policyengine_us_data/stage_contracts/calibration_package.py" }, "stage2_calibration_package_writer": { "docstring": "Save calibration package to pickle.\n\nArgs:\n path: Output file path.\n X_sparse: Sparse matrix.\n targets_df: Targets DataFrame.\n target_names: Target name list.\n metadata: Run metadata dict.\n initial_weights: Pre-computed initial weight array.\n cd_geoid: CD GEOID array from geography assignment.\n block_geoid: Block GEOID array from geography assignment.", "id": "stage2_calibration_package_writer", "kind": "function", - "line": 661, + "line": 663, "metadata": { "api_refs": [ "policyengine_us_data.calibration.unified_calibration.save_calibration_package" @@ -3914,11 +3914,95 @@ "signature": "def stage2_input_bundle_from_artifacts_dir(artifacts_dir: str | Path) -> Stage2InputBundle", "source_file": "policyengine_us_data/calibration_package/specs.py" }, + "stage2_payload_boundary": { + "docstring": "Typed access to the dictionary persisted in `calibration_package.pkl`.", + "id": "stage2_payload_boundary", + "kind": "class", + "line": 114, + "metadata": { + "api_refs": [ + "policyengine_us_data.calibration_package.payload.CalibrationPackagePayload" + ], + "artifacts_in": "[CALIBRATION_PACKAGE_FILENAME]", + "description": "Typed access to the calibration_package.pkl matrix, targets, metadata, geography arrays, and compatibility warnings.", + "id": "stage2_payload_boundary", + "label": "Stage 2 Package Payload", + "node_type": "library", + "pathways": [ + "calibration_package" + ], + "source_file": "policyengine_us_data/calibration_package/payload.py", + "stability": "moving", + "status": "current", + "validation_commands": [ + "uv run pytest tests/unit/calibration_package/test_payload.py" + ] + }, + "object_path": "policyengine_us_data.calibration_package.payload.CalibrationPackagePayload", + "signature": "class CalibrationPackagePayload", + "source_file": "policyengine_us_data/calibration_package/payload.py" + }, + "stage2_payload_reader": { + "docstring": "Read typed Stage 2 package payloads from disk.", + "id": "stage2_payload_reader", + "kind": "class", + "line": 328, + "metadata": { + "api_refs": [ + "policyengine_us_data.calibration_package.payload.CalibrationPackageReader" + ], + "artifacts_in": "[CALIBRATION_PACKAGE_FILENAME]", + "description": "Load calibration_package.pkl through the typed Stage 2 payload boundary and expose checksum/summary material.", + "id": "stage2_payload_reader", + "label": "Stage 2 Payload Reader", + "node_type": "library", + "pathways": [ + "calibration_package" + ], + "source_file": "policyengine_us_data/calibration_package/payload.py", + "stability": "moving", + "status": "current", + "validation_commands": [ + "uv run pytest tests/unit/calibration_package/test_payload.py" + ] + }, + "object_path": "policyengine_us_data.calibration_package.payload.CalibrationPackageReader", + "signature": "class CalibrationPackageReader", + "source_file": "policyengine_us_data/calibration_package/payload.py" + }, + "stage2_payload_writer": { + "docstring": "Write typed Stage 2 package payloads and metadata sidecars.", + "id": "stage2_payload_writer", + "kind": "class", + "line": 385, + "metadata": { + "api_refs": [ + "policyengine_us_data.calibration_package.payload.CalibrationPackageWriter" + ], + "artifacts_out": "[CALIBRATION_PACKAGE_FILENAME, CALIBRATION_PACKAGE_METADATA_FILENAME]", + "description": "Persist calibration_package.pkl and derive calibration_package_meta.json from typed payload and contract material.", + "id": "stage2_payload_writer", + "label": "Stage 2 Payload Writer", + "node_type": "library", + "pathways": [ + "calibration_package" + ], + "source_file": "policyengine_us_data/calibration_package/payload.py", + "stability": "moving", + "status": "current", + "validation_commands": [ + "uv run pytest tests/unit/calibration_package/test_payload.py" + ] + }, + "object_path": "policyengine_us_data.calibration_package.payload.CalibrationPackageWriter", + "signature": "class CalibrationPackageWriter", + "source_file": "policyengine_us_data/calibration_package/payload.py" + }, "stage2_target_config_apply": { "docstring": "Filter target rows before matrix construction.", "id": "stage2_target_config_apply", "kind": "function", - "line": 631, + "line": 633, "metadata": { "api_refs": [ "policyengine_us_data.calibration.unified_calibration.apply_target_config_to_targets" @@ -3973,7 +4057,7 @@ "docstring": "Load target include/exclude config from YAML.\n\nArgs:\n path: Path to YAML config file.\n\nReturns:\n Parsed config dict with include and exclude lists.", "id": "stage2_target_config_load", "kind": "function", - "line": 525, + "line": 527, "metadata": { "api_refs": [ "policyengine_us_data.calibration.unified_calibration.load_target_config" diff --git a/docs/generated/pipeline_map.json b/docs/generated/pipeline_map.json index ea6f4fb0e..b29f78deb 100644 --- a/docs/generated/pipeline_map.json +++ b/docs/generated/pipeline_map.json @@ -1998,8 +1998,8 @@ "metadata": { "api_node_count": 96, "canonical_stage_count": 5, - "decorated_object_count": 156, - "mapped_decorated_node_count": 60, + "decorated_object_count": 159, + "mapped_decorated_node_count": 63, "stage_count": 17, "substage_count": 17 }, @@ -3970,11 +3970,17 @@ { "edge_type": "data_flow", "source": "build_matrix", - "target": "stage2_calibration_package_writer" + "target": "stage2_payload_boundary" }, { "edge_type": "data_flow", "source": "build_matrix_chunked", + "target": "stage2_payload_boundary" + }, + { + "edge_type": "data_flow", + "label": "typed package payload", + "source": "stage2_payload_boundary", "target": "stage2_calibration_package_writer" }, { @@ -3984,13 +3990,30 @@ "target": "stage2_calibration_package_writer" }, { - "edge_type": "produces_artifact", + "edge_type": "uses_library", + "label": "pickle write", "source": "stage2_calibration_package_writer", + "target": "stage2_payload_writer" + }, + { + "edge_type": "produces_artifact", + "source": "stage2_payload_writer", "target": "out_pkg" }, { "edge_type": "data_flow", "source": "out_pkg", + "target": "stage2_payload_reader" + }, + { + "edge_type": "data_flow", + "source": "out_pkg", + "target": "stage2_calibration_package_contract_writer" + }, + { + "edge_type": "uses_library", + "label": "summary and checksum", + "source": "stage2_payload_reader", "target": "stage2_calibration_package_contract_writer" }, { @@ -4004,6 +4027,18 @@ "source": "stage2_calibration_package_contract_writer", "target": "out_contract" }, + { + "edge_type": "data_flow", + "label": "sidecar contract material", + "source": "out_contract", + "target": "stage2_payload_writer" + }, + { + "edge_type": "produces_artifact", + "label": "sidecar metadata", + "source": "stage2_payload_writer", + "target": "out_metadata" + }, { "edge_type": "validates", "source": "out_pkg", @@ -4067,8 +4102,12 @@ "build_matrix", "build_matrix_chunked", "stage2_artifact_specs", + "stage2_payload_boundary", + "stage2_payload_writer", + "stage2_payload_reader", "stage2_calibration_package_writer", "out_pkg", + "out_metadata", "stage2_calibration_package_contract_writer", "out_contract", "stage2_calibration_package_contract_validator" @@ -4154,6 +4193,12 @@ "label": "calibration_package.pkl", "node_type": "artifact" }, + { + "description": "Metadata sidecar generated from the typed package payload and Stage 2 contract", + "id": "out_metadata", + "label": "calibration_package_meta.json", + "node_type": "artifact" + }, { "description": "Stage 2 package handoff contract written next to calibration_package.pkl", "id": "out_contract", @@ -4405,6 +4450,63 @@ "uv run pytest tests/integration/test_chunked_matrix_builder.py" ] }, + { + "api_refs": [ + "policyengine_us_data.calibration_package.payload.CalibrationPackagePayload" + ], + "artifacts_in": "[CALIBRATION_PACKAGE_FILENAME]", + "description": "Typed access to the calibration_package.pkl matrix, targets, metadata, geography arrays, and compatibility warnings.", + "id": "stage2_payload_boundary", + "label": "Stage 2 Package Payload", + "node_type": "library", + "pathways": [ + "calibration_package" + ], + "source_file": "policyengine_us_data/calibration_package/payload.py", + "stability": "moving", + "status": "current", + "validation_commands": [ + "uv run pytest tests/unit/calibration_package/test_payload.py" + ] + }, + { + "api_refs": [ + "policyengine_us_data.calibration_package.payload.CalibrationPackageWriter" + ], + "artifacts_out": "[CALIBRATION_PACKAGE_FILENAME, CALIBRATION_PACKAGE_METADATA_FILENAME]", + "description": "Persist calibration_package.pkl and derive calibration_package_meta.json from typed payload and contract material.", + "id": "stage2_payload_writer", + "label": "Stage 2 Payload Writer", + "node_type": "library", + "pathways": [ + "calibration_package" + ], + "source_file": "policyengine_us_data/calibration_package/payload.py", + "stability": "moving", + "status": "current", + "validation_commands": [ + "uv run pytest tests/unit/calibration_package/test_payload.py" + ] + }, + { + "api_refs": [ + "policyengine_us_data.calibration_package.payload.CalibrationPackageReader" + ], + "artifacts_in": "[CALIBRATION_PACKAGE_FILENAME]", + "description": "Load calibration_package.pkl through the typed Stage 2 payload boundary and expose checksum/summary material.", + "id": "stage2_payload_reader", + "label": "Stage 2 Payload Reader", + "node_type": "library", + "pathways": [ + "calibration_package" + ], + "source_file": "policyengine_us_data/calibration_package/payload.py", + "stability": "moving", + "status": "current", + "validation_commands": [ + "uv run pytest tests/unit/calibration_package/test_payload.py" + ] + }, { "api_refs": [ "policyengine_us_data.stage_contracts.calibration_package.write_calibration_package_contract" diff --git a/docs/pipeline_map.yaml b/docs/pipeline_map.yaml index 02d5581c5..c4e35fc87 100644 --- a/docs/pipeline_map.yaml +++ b/docs/pipeline_map.yaml @@ -821,8 +821,12 @@ stages: - build_matrix - build_matrix_chunked - stage2_artifact_specs + - stage2_payload_boundary + - stage2_payload_writer + - stage2_payload_reader - stage2_calibration_package_writer - out_pkg + - out_metadata - stage2_calibration_package_contract_writer - out_contract - stage2_calibration_package_contract_validator @@ -875,6 +879,10 @@ stages: label: calibration_package.pkl node_type: artifact description: X_sparse CSR matrix, targets_df, initial_weights, metadata + - id: out_metadata + label: calibration_package_meta.json + node_type: artifact + description: Metadata sidecar generated from the typed package payload and Stage 2 contract - id: out_contract label: calibration_package_contract.json node_type: artifact @@ -983,21 +991,36 @@ stages: edge_type: uses_library label: chunked path - source: build_matrix - target: stage2_calibration_package_writer + target: stage2_payload_boundary edge_type: data_flow - source: build_matrix_chunked + target: stage2_payload_boundary + edge_type: data_flow + - source: stage2_payload_boundary target: stage2_calibration_package_writer edge_type: data_flow + label: typed package payload - source: stage2_artifact_specs target: stage2_calibration_package_writer edge_type: uses_utility label: package path - source: stage2_calibration_package_writer + target: stage2_payload_writer + edge_type: uses_library + label: pickle write + - source: stage2_payload_writer target: out_pkg edge_type: produces_artifact + - source: out_pkg + target: stage2_payload_reader + edge_type: data_flow - source: out_pkg target: stage2_calibration_package_contract_writer edge_type: data_flow + - source: stage2_payload_reader + target: stage2_calibration_package_contract_writer + edge_type: uses_library + label: summary and checksum - source: stage2_artifact_specs target: stage2_calibration_package_contract_writer edge_type: uses_utility @@ -1005,6 +1028,14 @@ stages: - source: stage2_calibration_package_contract_writer target: out_contract edge_type: produces_artifact + - source: out_contract + target: stage2_payload_writer + edge_type: data_flow + label: sidecar contract material + - source: stage2_payload_writer + target: out_metadata + edge_type: produces_artifact + label: sidecar metadata - source: out_pkg target: stage2_calibration_package_contract_validator edge_type: validates diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index 969cf405d..cbf750ce5 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -363,23 +363,30 @@ def _print_provenance_from_meta(meta: dict, current_branch: str = None) -> None: def _write_package_sidecar(pkg_path: str) -> bool: - """Extract metadata from a pickle package and write a JSON sidecar. + """Write package metadata from the typed payload and contract sidecar. Returns: True if sidecar was written successfully, False otherwise. """ - import json import logging - import pickle - sidecar_path = pkg_path.replace(".pkl", "_meta.json") try: - with open(pkg_path, "rb") as f: - package = pickle.load(f) - meta = package.get("metadata", {}) - del package - with open(sidecar_path, "w") as f: - json.dump(meta, f, indent=2) + from policyengine_us_data.calibration_package.payload import ( + CalibrationPackageReader, + CalibrationPackageWriter, + ) + from policyengine_us_data.calibration_package.specs import ( + CALIBRATION_PACKAGE_CONTRACT_FILENAME, + ) + from policyengine_us_data.stage_contracts.io import read_contract + + package_path = Path(pkg_path) + payload = CalibrationPackageReader(package_path=package_path).read() + contract_path = package_path.with_name(CALIBRATION_PACKAGE_CONTRACT_FILENAME) + contract = read_contract(contract_path) if contract_path.exists() else None + sidecar_path = CalibrationPackageWriter( + package_path=package_path, + ).write_metadata_sidecar(payload, contract=contract) print( f"Sidecar metadata written to {sidecar_path}", flush=True, diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index f3943573b..bdc411f77 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -41,8 +41,10 @@ build_checkpoint_signature, checkpoint_signature_mismatches, ) -from policyengine_us_data.calibration.calibration_utils import ( - create_target_groups, +from policyengine_us_data.calibration_package.payload import ( + CalibrationPackagePayload, + CalibrationPackageReader, + CalibrationPackageWriter, ) from policyengine_us_data.calibration_package.specs import ( DEFAULT_TARGET_CONFIG_PATH as DEFAULT_TARGET_CONFIG_RELATIVE_PATH, @@ -680,20 +682,16 @@ def save_calibration_package( cd_geoid: CD GEOID array from geography assignment. block_geoid: Block GEOID array from geography assignment. """ - import pickle - - package = { - "X_sparse": X_sparse, - "targets_df": targets_df, - "target_names": target_names, - "metadata": metadata, - "initial_weights": initial_weights, - "cd_geoid": cd_geoid, - "block_geoid": block_geoid, - } - Path(path).parent.mkdir(parents=True, exist_ok=True) - with open(path, "wb") as f: - pickle.dump(package, f, protocol=pickle.HIGHEST_PROTOCOL) + payload = CalibrationPackagePayload( + X_sparse=X_sparse, + targets_df=targets_df, + target_names=target_names, + metadata=metadata, + initial_weights=initial_weights, + cd_geoid=cd_geoid, + block_geoid=block_geoid, + ) + CalibrationPackageWriter(package_path=Path(path)).write(payload) logger.info("Calibration package saved to %s", path) @@ -706,16 +704,14 @@ def load_calibration_package(path: str) -> dict: Returns: Dict with X_sparse, targets_df, target_names, metadata. """ - import pickle - - with open(path, "rb") as f: - package = pickle.load(f) + payload = CalibrationPackageReader(package_path=Path(path)).read() + package = payload.to_mapping() logger.info( "Loaded package: %d targets, %d records", - package["X_sparse"].shape[0], - package["X_sparse"].shape[1], + payload.X_sparse.shape[0], + payload.X_sparse.shape[1], ) - meta = package.get("metadata", {}) + meta = payload.metadata print_package_provenance(meta) check_package_staleness(meta) return package @@ -1732,15 +1728,15 @@ def run_calibration( initial_weights = compute_initial_weights(X_sparse, targets_df) if package_output_path: - package_payload = { - "X_sparse": X_sparse, - "targets_df": targets_df, - "target_names": target_names, - "metadata": metadata, - "initial_weights": initial_weights, - "cd_geoid": geography.cd_geoid, - "block_geoid": geography.block_geoid, - } + package_payload = CalibrationPackagePayload( + X_sparse=X_sparse, + targets_df=targets_df, + target_names=target_names, + metadata=metadata, + initial_weights=initial_weights, + cd_geoid=geography.cd_geoid, + block_geoid=geography.block_geoid, + ) save_calibration_package( package_output_path, X_sparse, diff --git a/policyengine_us_data/calibration_package/__init__.py b/policyengine_us_data/calibration_package/__init__.py index 26655041a..dd2c02b74 100644 --- a/policyengine_us_data/calibration_package/__init__.py +++ b/policyengine_us_data/calibration_package/__init__.py @@ -26,6 +26,15 @@ stage2_input_bundle_from_stage1_contract, stage2_input_bundle_from_stage1_contract_path, ) +from .payload import ( + LEGACY_MISSING_GEOGRAPHY_WARNING, + REQUIRED_PACKAGE_KEYS, + CalibrationPackagePayloadError, + CalibrationPackagePayload, + CalibrationPackageReader, + CalibrationPackageWriter, + calibration_package_payload_failure_report, +) __all__ = [ "CALIBRATION_PACKAGE_CONTRACT_FILENAME", @@ -41,10 +50,16 @@ "TARGET_DATABASE_FILENAME", "CalibrationPackageArtifactPaths", "CalibrationPackageOutputBundle", + "CalibrationPackagePayload", + "CalibrationPackagePayloadError", + "CalibrationPackageReader", + "CalibrationPackageWriter", + "LEGACY_MISSING_GEOGRAPHY_WARNING", "Stage2BuildContext", "Stage2InputBundle", "Stage2InputBundleError", "Stage2InputSource", + "REQUIRED_PACKAGE_KEYS", "TargetConfigIdentity", "calibration_package_artifact_paths", "resolve_target_config_identity", @@ -52,4 +67,5 @@ "stage2_input_bundle_from_artifacts_dir", "stage2_input_bundle_from_stage1_contract", "stage2_input_bundle_from_stage1_contract_path", + "calibration_package_payload_failure_report", ] diff --git a/policyengine_us_data/calibration_package/payload.py b/policyengine_us_data/calibration_package/payload.py new file mode 100644 index 000000000..7c06ce35f --- /dev/null +++ b/policyengine_us_data/calibration_package/payload.py @@ -0,0 +1,480 @@ +"""Typed reader and writer boundary for Stage 2 package payloads.""" + +from __future__ import annotations + +import json +import pickle +from collections.abc import Mapping +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from policyengine_us_data.pipeline_metadata import pipeline_node +from policyengine_us_data.pipeline_schema import PipelineNode +from policyengine_us_data.utils.geography_checksum import ( + canonical_geography_checksum, + hash_string_array, +) +from policyengine_us_data.utils.step_manifest import sha256_file + +from .specs import CALIBRATION_PACKAGE_FILENAME, CALIBRATION_PACKAGE_METADATA_FILENAME + +if TYPE_CHECKING: + from policyengine_us_data.stage_contracts.calibration_package_schema import ( + CalibrationPackageSummary, + GeographyAssignmentSummary, + ) + from policyengine_us_data.stage_contracts.validation import ValidationReport + +REQUIRED_PACKAGE_KEYS: frozenset[str] = frozenset( + {"X_sparse", "targets_df", "target_names", "metadata"} +) +LEGACY_MISSING_GEOGRAPHY_WARNING = ( + "legacy packages without block_geoid/cd_geoid cannot prove geography assignment" +) + + +class CalibrationPackagePayloadError(RuntimeError): + """Read/write failure with a canonical Stage 2 validation report.""" + + def __init__( + self, + *, + operation: str, + package_path: str | Path, + error: Exception, + ) -> None: + report = calibration_package_payload_failure_report( + operation=operation, + package_path=package_path, + error=error, + ) + details = "; ".join(finding.message for finding in report.findings) + super().__init__(details or "Stage 2 package payload operation failed") + self.operation = operation + self.package_path = Path(package_path) + self.validation_report = report + + +def calibration_package_payload_failure_report( + *, + operation: str, + package_path: str | Path, + error: Exception, +) -> "ValidationReport": + """Return a canonical validation report for package payload failures.""" + + from policyengine_us_data.stage_contracts.validation import ( + ValidationFinding, + ValidationReport, + ) + + path = Path(package_path) + return ValidationReport( + status="fail", + findings=( + ValidationFinding( + check_id=f"stage2_payload_{operation}", + status="fail", + message=( + f"Stage 2 calibration package payload {operation} failed " + f"for {path}: {error}" + ), + metadata={ + "operation": operation, + "package_path": str(path), + "error_type": error.__class__.__name__, + }, + ), + ), + metadata={ + "artifact": CALIBRATION_PACKAGE_FILENAME, + "package_path": str(path), + }, + ) + + +@pipeline_node( + PipelineNode( + id="stage2_payload_boundary", + label="Stage 2 Package Payload", + node_type="library", + description="Typed access to the calibration_package.pkl matrix, targets, metadata, geography arrays, and compatibility warnings.", + source_file="policyengine_us_data/calibration_package/payload.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_in=[CALIBRATION_PACKAGE_FILENAME], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_payload.py" + ], + ) +) +@dataclass(frozen=True, kw_only=True) +class CalibrationPackagePayload: + """Typed access to the dictionary persisted in `calibration_package.pkl`.""" + + X_sparse: Any + targets_df: Any + target_names: Any + metadata: Mapping[str, Any] + initial_weights: Any | None = None + cd_geoid: Any | None = None + block_geoid: Any | None = None + compatibility_warnings: tuple[str, ...] = () + + @classmethod + def from_mapping( + cls, + package: Mapping[str, Any], + *, + require_required_keys: bool = True, + ) -> "CalibrationPackagePayload": + """Validate and wrap a legacy package mapping.""" + + if not isinstance(package, Mapping): + raise ValueError("Calibration package pickle must contain a mapping") + missing = sorted(REQUIRED_PACKAGE_KEYS - set(package)) + if missing and require_required_keys: + raise ValueError(f"Calibration package missing required key: {missing[0]}") + metadata = package.get("metadata", {}) + if metadata is None: + metadata = {} + if not isinstance(metadata, Mapping): + raise ValueError("Calibration package metadata must be a mapping") + cd_geoid = package.get("cd_geoid") + block_geoid = package.get("block_geoid") + warnings: list[str] = [] + if cd_geoid is None and block_geoid is None: + warnings.append(LEGACY_MISSING_GEOGRAPHY_WARNING) + return cls( + X_sparse=package.get("X_sparse"), + targets_df=package.get("targets_df"), + target_names=package.get("target_names"), + metadata=dict(metadata), + initial_weights=package.get("initial_weights"), + cd_geoid=cd_geoid, + block_geoid=block_geoid, + compatibility_warnings=tuple(warnings), + ) + + def to_mapping(self) -> dict[str, Any]: + """Return the pickle-compatible package dictionary.""" + + return { + "X_sparse": self.X_sparse, + "targets_df": self.targets_df, + "target_names": self.target_names, + "metadata": dict(self.metadata), + "initial_weights": self.initial_weights, + "cd_geoid": self.cd_geoid, + "block_geoid": self.block_geoid, + } + + def summary(self) -> CalibrationPackageSummary: + """Return the contract-safe package summary.""" + + from policyengine_us_data.stage_contracts.calibration_package_schema import ( + CalibrationPackageSummary, + ) + + try: + n_targets, n_columns = self.X_sparse.shape + except (AttributeError, ValueError) as exc: + raise ValueError("X_sparse must expose a two-dimensional shape") from exc + if not hasattr(self.X_sparse, "nnz"): + raise ValueError("X_sparse must expose nnz") + + n_targets = int(n_targets) + n_columns = int(n_columns) + nnz = int(self.X_sparse.nnz) + density = nnz / (n_targets * n_columns) if n_targets * n_columns else 0.0 + + return CalibrationPackageSummary( + matrix_shape=(n_targets, n_columns), + matrix_nnz=nnz, + matrix_density=float(density), + n_targets=int(len(self.targets_df)), + n_columns=n_columns, + target_name_count=int(len(self.target_names)), + dataset_sha256=self.metadata_string("dataset_sha256"), + db_sha256=self.metadata_string("db_sha256"), + target_config_path=self.metadata_string("target_config_path"), + target_config_sha256=self.metadata_string("target_config_sha256"), + n_clones=self.metadata_int("n_clones"), + seed=self.metadata_int("seed"), + base_n_records=self.metadata_int("base_n_records"), + package_scope=self.metadata_string("package_scope"), + matrix_builder=self.metadata_string("matrix_builder"), + chunk_size=self.metadata_int("chunk_size"), + chunk_dir=self.metadata_string("chunk_dir"), + has_initial_weights=self.initial_weights is not None, + has_cd_geoid=self.cd_geoid is not None, + has_block_geoid=self.block_geoid is not None, + cd_geoid_length=_optional_len(self.cd_geoid), + block_geoid_length=_optional_len(self.block_geoid), + ) + + def geography_summary(self) -> GeographyAssignmentSummary: + """Return the contract-safe geography assignment summary.""" + + from policyengine_us_data.stage_contracts.calibration_package_schema import ( + GeographyAssignmentSummary, + ) + + n_records = self.metadata_int("base_n_records") + n_clones = self.metadata_int("n_clones") + has_blocks = self.block_geoid is not None + has_cds = self.cd_geoid is not None + + if not has_blocks and not has_cds: + return GeographyAssignmentSummary( + source_kind="unavailable", + n_records=n_records, + n_clones=n_clones, + n_rows=None, + has_block_geoid=False, + has_cd_geoid=False, + block_geoid_length=None, + cd_geoid_length=None, + block_geoid_sha256=None, + cd_geoid_sha256=None, + canonical_geography_sha256=None, + ) + if not has_blocks or not has_cds: + raise ValueError( + "Calibration package geography requires both block_geoid and cd_geoid" + ) + if n_records is None or n_clones is None: + raise ValueError( + "Calibration package geography requires metadata base_n_records and n_clones" + ) + if n_records <= 0 or n_clones <= 0: + raise ValueError( + "Calibration package geography requires positive base_n_records and n_clones" + ) + + block_geoids = _one_dimensional_string_array(self.block_geoid, "block_geoid") + cd_geoids = _one_dimensional_string_array(self.cd_geoid, "cd_geoid") + n_rows = int(len(block_geoids)) + if n_rows == 0: + raise ValueError("Calibration package geography arrays must be non-empty") + if len(cd_geoids) != n_rows: + raise ValueError( + "Calibration package geography has mismatched block_geoid and cd_geoid " + f"lengths: {n_rows} != {len(cd_geoids)}" + ) + if n_records * n_clones != n_rows: + raise ValueError( + "Calibration package geography length does not match metadata: " + f"{n_rows} rows for {n_records} records x {n_clones} clones" + ) + + return GeographyAssignmentSummary( + source_kind="calibration_package", + n_records=n_records, + n_clones=n_clones, + n_rows=n_rows, + has_block_geoid=True, + has_cd_geoid=True, + block_geoid_length=n_rows, + cd_geoid_length=int(len(cd_geoids)), + block_geoid_sha256=hash_string_array(block_geoids), + cd_geoid_sha256=hash_string_array(cd_geoids), + canonical_geography_sha256=canonical_geography_checksum( + block_geoid=block_geoids, + cd_geoid=cd_geoids, + n_records=n_records, + n_clones=n_clones, + ), + ) + + def metadata_string(self, key: str) -> str | None: + """Return a metadata value coerced to a string, preserving nulls.""" + + value = self.metadata.get(key) + if value is None: + return None + return str(value) + + def metadata_int(self, key: str) -> int | None: + """Return a metadata value coerced to an integer, preserving nulls.""" + + value = self.metadata.get(key) + if value is None: + return None + if isinstance(value, bool): + raise ValueError(f"Calibration package metadata {key!r} must be an integer") + return int(value) + + +@pipeline_node( + PipelineNode( + id="stage2_payload_reader", + label="Stage 2 Payload Reader", + node_type="library", + description="Load calibration_package.pkl through the typed Stage 2 payload boundary and expose checksum/summary material.", + source_file="policyengine_us_data/calibration_package/payload.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_in=[CALIBRATION_PACKAGE_FILENAME], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_payload.py" + ], + ) +) +@dataclass(frozen=True, kw_only=True) +class CalibrationPackageReader: + """Read typed Stage 2 package payloads from disk.""" + + package_path: Path + + def read(self) -> CalibrationPackagePayload: + """Load and validate the persisted package payload.""" + + try: + with Path(self.package_path).open("rb") as handle: + package = pickle.load(handle) + return CalibrationPackagePayload.from_mapping(package) + except Exception as exc: + raise CalibrationPackagePayloadError( + operation="read", + package_path=self.package_path, + error=exc, + ) from exc + + def checksum(self) -> str: + """Return the package file checksum used for reuse comparisons.""" + + try: + return f"sha256:{sha256_file(Path(self.package_path))}" + except Exception as exc: + raise CalibrationPackagePayloadError( + operation="checksum", + package_path=self.package_path, + error=exc, + ) from exc + + def summary(self) -> CalibrationPackageSummary: + """Read the package and return its summary.""" + + return self.read().summary() + + +@pipeline_node( + PipelineNode( + id="stage2_payload_writer", + label="Stage 2 Payload Writer", + node_type="library", + description="Persist calibration_package.pkl and derive calibration_package_meta.json from typed payload and contract material.", + source_file="policyengine_us_data/calibration_package/payload.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_out=[ + CALIBRATION_PACKAGE_FILENAME, + CALIBRATION_PACKAGE_METADATA_FILENAME, + ], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_payload.py" + ], + ) +) +@dataclass(frozen=True, kw_only=True) +class CalibrationPackageWriter: + """Write typed Stage 2 package payloads and metadata sidecars.""" + + package_path: Path + + def write(self, payload: CalibrationPackagePayload) -> Path: + """Persist a package payload using the legacy pickle format.""" + + try: + output_path = Path(self.package_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("wb") as handle: + pickle.dump( + payload.to_mapping(), handle, protocol=pickle.HIGHEST_PROTOCOL + ) + return output_path + except Exception as exc: + raise CalibrationPackagePayloadError( + operation="write", + package_path=self.package_path, + error=exc, + ) from exc + + def write_metadata_sidecar( + self, + payload: CalibrationPackagePayload, + *, + contract: Any | None = None, + sidecar_path: str | Path | None = None, + ) -> Path: + """Write `calibration_package_meta.json` from typed payload material.""" + + try: + output_path = ( + Path(sidecar_path) + if sidecar_path is not None + else Path(self.package_path).with_name( + CALIBRATION_PACKAGE_METADATA_FILENAME + ) + ) + sidecar_payload = { + **dict(payload.metadata), + "package_sha256": f"sha256:{sha256_file(Path(self.package_path))}", + "package_summary": payload.summary().to_dict(), + "geography_assignment": payload.geography_summary().to_dict(), + "compatibility_warnings": list(payload.compatibility_warnings), + } + if contract is not None: + sidecar_payload["contract"] = { + "stage_id": getattr(contract, "stage_id", None), + "contract_type": getattr(contract, "contract_type", None), + "fingerprint": ( + contract.fingerprint.to_dict() + if getattr(contract, "fingerprint", None) is not None + else None + ), + } + output_path.write_text( + json.dumps(sidecar_payload, indent=2, sort_keys=True), + encoding="utf-8", + ) + return output_path + except Exception as exc: + raise CalibrationPackagePayloadError( + operation="write_metadata_sidecar", + package_path=self.package_path, + error=exc, + ) from exc + + +def _optional_len(value: Any) -> int | None: + if value is None: + return None + return int(len(value)) + + +def _one_dimensional_string_array(value: Any, key: str) -> Any: + import numpy as np + + array = np.asarray(value, dtype=str) + if array.ndim != 1: + raise ValueError(f"Calibration package geography {key} must be one-dimensional") + if np.any(array == ""): + raise ValueError(f"Calibration package geography {key} contains empty values") + return array + + +__all__ = [ + "LEGACY_MISSING_GEOGRAPHY_WARNING", + "REQUIRED_PACKAGE_KEYS", + "CalibrationPackagePayloadError", + "CalibrationPackagePayload", + "CalibrationPackageReader", + "CalibrationPackageWriter", + "calibration_package_payload_failure_report", +] diff --git a/policyengine_us_data/stage_contracts/calibration_package.py b/policyengine_us_data/stage_contracts/calibration_package.py index 27125c6b7..ccf97a206 100644 --- a/policyengine_us_data/stage_contracts/calibration_package.py +++ b/policyengine_us_data/stage_contracts/calibration_package.py @@ -2,11 +2,14 @@ from __future__ import annotations -import pickle from collections.abc import Mapping from pathlib import Path from typing import Any +from policyengine_us_data.calibration_package.payload import ( + CalibrationPackagePayload, + CalibrationPackageReader, +) from policyengine_us_data.calibration_package.specs import ( CALIBRATION_PACKAGE_CONTRACT_FILENAME, CALIBRATION_PACKAGE_SUBSTAGE_ID, @@ -14,10 +17,6 @@ from policyengine_us_data.pipeline_metadata import pipeline_node from policyengine_us_data.pipeline_schema import PipelineNode from policyengine_us_data.utils.step_manifest import sha256_file -from policyengine_us_data.utils.geography_checksum import ( - canonical_geography_checksum, - hash_string_array, -) from .artifacts import ArtifactRef from .calibration_package_schema import ( @@ -38,146 +37,19 @@ def summarize_geography_assignment( - package: Mapping[str, Any], + package: CalibrationPackagePayload | Mapping[str, Any], ) -> GeographyAssignmentSummary: """Return a contract-safe summary of package-backed geography assignment.""" - metadata = _package_metadata(package) - n_records = _optional_metadata_int(metadata, "base_n_records") - n_clones = _optional_metadata_int(metadata, "n_clones") - raw_blocks = package.get("block_geoid") - raw_cds = package.get("cd_geoid") - has_blocks = raw_blocks is not None - has_cds = raw_cds is not None - - if not has_blocks and not has_cds: - return _unavailable_geography_assignment_summary( - n_records=n_records, - n_clones=n_clones, - ) - if not has_blocks or not has_cds: - raise ValueError( - "Calibration package geography requires both block_geoid and cd_geoid" - ) - if n_records is None or n_clones is None: - raise ValueError( - "Calibration package geography requires metadata base_n_records and n_clones" - ) - if n_records <= 0 or n_clones <= 0: - raise ValueError( - "Calibration package geography requires positive base_n_records and n_clones" - ) - - block_geoids = _one_dimensional_string_array(raw_blocks, "block_geoid") - cd_geoids = _one_dimensional_string_array(raw_cds, "cd_geoid") - n_rows = int(len(block_geoids)) - if n_rows == 0: - raise ValueError("Calibration package geography arrays must be non-empty") - if len(cd_geoids) != n_rows: - raise ValueError( - "Calibration package geography has mismatched block_geoid and cd_geoid " - f"lengths: {n_rows} != {len(cd_geoids)}" - ) - if n_records * n_clones != n_rows: - raise ValueError( - "Calibration package geography length does not match metadata: " - f"{n_rows} rows for {n_records} records x {n_clones} clones" - ) - - return GeographyAssignmentSummary( - source_kind="calibration_package", - n_records=n_records, - n_clones=n_clones, - n_rows=n_rows, - has_block_geoid=True, - has_cd_geoid=True, - block_geoid_length=n_rows, - cd_geoid_length=int(len(cd_geoids)), - block_geoid_sha256=hash_string_array(block_geoids), - cd_geoid_sha256=hash_string_array(cd_geoids), - canonical_geography_sha256=canonical_geography_checksum( - block_geoid=block_geoids, - cd_geoid=cd_geoids, - n_records=n_records, - n_clones=n_clones, - ), - ) - - -def _unavailable_geography_assignment_summary( - *, - n_records: int | None, - n_clones: int | None, -) -> GeographyAssignmentSummary: - """Create a summary for legacy packages without geography assignment arrays.""" - - return GeographyAssignmentSummary( - source_kind="unavailable", - n_records=n_records, - n_clones=n_clones, - n_rows=None, - has_block_geoid=False, - has_cd_geoid=False, - block_geoid_length=None, - cd_geoid_length=None, - block_geoid_sha256=None, - cd_geoid_sha256=None, - canonical_geography_sha256=None, - ) + return _calibration_package_payload(package, require_core=False).geography_summary() def summarize_calibration_package( - package: Mapping[str, Any], + package: CalibrationPackagePayload | Mapping[str, Any], ) -> CalibrationPackageSummary: """Return a contract-safe summary of a calibration package pickle payload.""" - matrix = _required_package_value(package, "X_sparse") - targets_df = _required_package_value(package, "targets_df") - target_names = _required_package_value(package, "target_names") - metadata = _package_metadata(package) - - try: - n_targets, n_columns = matrix.shape - except (AttributeError, ValueError) as exc: - raise ValueError("X_sparse must expose a two-dimensional shape") from exc - if not hasattr(matrix, "nnz"): - raise ValueError("X_sparse must expose nnz") - - n_targets = int(n_targets) - n_columns = int(n_columns) - nnz = int(matrix.nnz) - density = nnz / (n_targets * n_columns) if n_targets * n_columns else 0.0 - - return CalibrationPackageSummary( - matrix_shape=(n_targets, n_columns), - matrix_nnz=nnz, - matrix_density=float(density), - n_targets=int(len(targets_df)), - n_columns=n_columns, - target_name_count=int(len(target_names)), - dataset_sha256=_optional_metadata_string(metadata, "dataset_sha256"), - db_sha256=_optional_metadata_string(metadata, "db_sha256"), - target_config_path=_optional_metadata_string( - metadata, - "target_config_path", - ), - target_config_sha256=_optional_metadata_string( - metadata, - "target_config_sha256", - ), - n_clones=_optional_metadata_int(metadata, "n_clones"), - seed=_optional_metadata_int(metadata, "seed"), - base_n_records=_optional_metadata_int(metadata, "base_n_records"), - package_scope=_optional_metadata_string(metadata, "package_scope"), - matrix_builder=_optional_metadata_string(metadata, "matrix_builder"), - chunk_size=_optional_metadata_int(metadata, "chunk_size"), - chunk_dir=_optional_metadata_string(metadata, "chunk_dir"), - has_initial_weights=package.get("initial_weights") is not None, - has_cd_geoid=package.get("cd_geoid") is not None, - has_block_geoid=package.get("block_geoid") is not None, - cd_geoid_length=_optional_len(package.get("cd_geoid")), - block_geoid_length=_optional_len(package.get("block_geoid")), - ) + return _calibration_package_payload(package).summary() def build_calibration_package_contract( @@ -185,7 +57,7 @@ def build_calibration_package_contract( package_path: Path, dataset_path: Path, db_path: Path, - package: Mapping[str, Any], + package: CalibrationPackagePayload | Mapping[str, Any], parameters: CalibrationPackageParameters | Mapping[str, Any], run_id: str | None, completed_at: str, @@ -204,13 +76,14 @@ def build_calibration_package_contract( _require_existing_file(db_path, "target database") parameter_schema = _calibration_package_parameters(parameters) - metadata = _package_metadata(package) + payload = _calibration_package_payload(package) + metadata = payload.metadata parameter_payload = _parameters_with_package_identity( parameter_schema.to_dict(), metadata, ) - package_summary = summarize_calibration_package(package).to_dict() - geography_summary = summarize_geography_assignment(package).to_dict() + package_summary = payload.summary().to_dict() + geography_summary = payload.geography_summary().to_dict() inputs = ( _artifact_ref_from_path( logical_name="source_imputed_stratified_extended_cps", @@ -324,7 +197,7 @@ def write_calibration_package_contract( package_path: Path, dataset_path: Path, db_path: Path, - package: Mapping[str, Any], + package: CalibrationPackagePayload | Mapping[str, Any], parameters: CalibrationPackageParameters | Mapping[str, Any], run_id: str | None, completed_at: str, @@ -380,7 +253,7 @@ def validate_calibration_package_contract( *, package_path: Path, contract_path: Path | None = None, - package: Mapping[str, Any] | None = None, + package: CalibrationPackagePayload | Mapping[str, Any] | None = None, dataset_path: Path | None = None, db_path: Path | None = None, ) -> StageContract: @@ -462,29 +335,29 @@ def validate_persisted_calibration_package_contract( ) -def load_calibration_package_payload(package_path: Path) -> Mapping[str, Any]: - """Load a calibration package pickle for sidecar validation.""" +def load_calibration_package_payload(package_path: Path) -> CalibrationPackagePayload: + """Load a typed calibration package payload for sidecar validation.""" - with Path(package_path).open("rb") as handle: - package = pickle.load(handle) - if not isinstance(package, Mapping): - raise ValueError("Calibration package pickle must contain a mapping") - return package + return CalibrationPackageReader(package_path=Path(package_path)).read() -def _required_package_value(package: Mapping[str, Any], key: str) -> Any: - if key not in package: - raise ValueError(f"Calibration package missing required key: {key}") - return package[key] +def _calibration_package_payload( + package: CalibrationPackagePayload | Mapping[str, Any], + *, + require_core: bool = True, +) -> CalibrationPackagePayload: + if isinstance(package, CalibrationPackagePayload): + return package + return CalibrationPackagePayload.from_mapping( + package, + require_required_keys=require_core, + ) -def _package_metadata(package: Mapping[str, Any]) -> Mapping[str, Any]: - metadata = package.get("metadata", {}) - if metadata is None: - return {} - if not isinstance(metadata, Mapping): - raise ValueError("Calibration package metadata must be a mapping") - return metadata +def _package_metadata( + package: CalibrationPackagePayload | Mapping[str, Any], +) -> Mapping[str, Any]: + return _calibration_package_payload(package).metadata def _optional_metadata_string( @@ -506,23 +379,6 @@ def _optional_metadata_int(metadata: Mapping[str, Any], key: str) -> int | None: return int(value) -def _optional_len(value: Any) -> int | None: - if value is None: - return None - return int(len(value)) - - -def _one_dimensional_string_array(value: Any, key: str) -> Any: - import numpy as np - - array = np.asarray(value, dtype=str) - if array.ndim != 1: - raise ValueError(f"Calibration package geography {key} must be one-dimensional") - if np.any(array == ""): - raise ValueError(f"Calibration package geography {key} contains empty values") - return array - - def _calibration_package_parameters( parameters: CalibrationPackageParameters | Mapping[str, Any], ) -> CalibrationPackageParameters: diff --git a/tests/unit/calibration_package/test_payload.py b/tests/unit/calibration_package/test_payload.py new file mode 100644 index 000000000..b13d3bfda --- /dev/null +++ b/tests/unit/calibration_package/test_payload.py @@ -0,0 +1,124 @@ +import json +import pickle + +import pytest + +from tests.unit.fixtures.calibration_package_stage_contract import ( + calibration_package_contract, + calibration_package_payload, + calibration_package_payload_without_geography, +) + +from policyengine_us_data.calibration_package.payload import ( + LEGACY_MISSING_GEOGRAPHY_WARNING, + CalibrationPackagePayloadError, + CalibrationPackagePayload, + CalibrationPackageReader, + CalibrationPackageWriter, +) +from policyengine_us_data.stage_contracts.calibration_package import ( + summarize_calibration_package, +) +from policyengine_us_data.stage_contracts.calibration_package_schema import ( + CalibrationPackageSummary, + GeographyAssignmentSummary, +) + + +def test_calibration_package_payload_read_write_round_trip(tmp_path): + package_path = tmp_path / "calibration_package.pkl" + payload = CalibrationPackagePayload.from_mapping(calibration_package_payload()) + + written = CalibrationPackageWriter(package_path=package_path).write(payload) + loaded = CalibrationPackageReader(package_path=package_path).read() + + assert written == package_path + assert loaded.summary() == payload.summary() + assert loaded.geography_summary() == payload.geography_summary() + assert ( + CalibrationPackageReader(package_path=package_path) + .checksum() + .startswith("sha256:") + ) + + +@pytest.mark.parametrize( + "missing_key", + ["X_sparse", "targets_df", "target_names", "metadata"], +) +def test_calibration_package_payload_rejects_missing_required_keys(missing_key): + package = calibration_package_payload() + package.pop(missing_key) + + with pytest.raises(ValueError, match=missing_key): + CalibrationPackagePayload.from_mapping(package) + + +def test_calibration_package_reader_failure_exposes_validation_report(tmp_path): + package_path = tmp_path / "calibration_package.pkl" + with package_path.open("wb") as handle: + pickle.dump(["not", "a", "mapping"], handle) + + with pytest.raises(CalibrationPackagePayloadError) as exc_info: + CalibrationPackageReader(package_path=package_path).read() + + report = exc_info.value.validation_report + assert report.status == "fail" + finding = report.findings[0] + assert finding.check_id == "stage2_payload_read" + assert finding.metadata["operation"] == "read" + assert finding.metadata["package_path"] == str(package_path) + assert finding.metadata["error_type"] == "ValueError" + assert "must contain a mapping" in finding.message + + +def test_calibration_package_writer_failure_exposes_validation_report(tmp_path): + payload = CalibrationPackagePayload.from_mapping(calibration_package_payload()) + + with pytest.raises(CalibrationPackagePayloadError) as exc_info: + CalibrationPackageWriter(package_path=tmp_path).write(payload) + + report = exc_info.value.validation_report + assert report.status == "fail" + finding = report.findings[0] + assert finding.check_id == "stage2_payload_write" + assert finding.metadata["operation"] == "write" + assert finding.metadata["package_path"] == str(tmp_path) + assert finding.metadata["error_type"] == "IsADirectoryError" + + +def test_legacy_package_without_geography_records_compatibility_warning(): + payload = CalibrationPackagePayload.from_mapping( + calibration_package_payload_without_geography() + ) + + assert payload.compatibility_warnings == (LEGACY_MISSING_GEOGRAPHY_WARNING,) + geography = payload.geography_summary() + assert geography.source_kind == "unavailable" + + +def test_payload_summary_matches_existing_contract_summary(): + payload = CalibrationPackagePayload.from_mapping(calibration_package_payload()) + + assert payload.summary() == summarize_calibration_package(payload.to_mapping()) + + +def test_metadata_sidecar_uses_payload_and_contract(tmp_path): + package_path = tmp_path / "calibration_package.pkl" + payload = CalibrationPackagePayload.from_mapping(calibration_package_payload()) + contract = calibration_package_contract(tmp_path) + writer = CalibrationPackageWriter(package_path=package_path) + writer.write(payload) + + sidecar_path = writer.write_metadata_sidecar(payload, contract=contract) + sidecar = json.loads(sidecar_path.read_text(encoding="utf-8")) + + assert sidecar_path == tmp_path / "calibration_package_meta.json" + assert CalibrationPackageSummary.from_dict(sidecar["package_summary"]) == ( + payload.summary() + ) + assert GeographyAssignmentSummary.from_dict(sidecar["geography_assignment"]) == ( + payload.geography_summary() + ) + assert sidecar["contract"]["stage_id"] == "2_build_calibration_package" + assert sidecar["compatibility_warnings"] == [] diff --git a/tests/unit/test_calibration_package_stage_contract.py b/tests/unit/test_calibration_package_stage_contract.py index b99a12c78..9010f4a3a 100644 --- a/tests/unit/test_calibration_package_stage_contract.py +++ b/tests/unit/test_calibration_package_stage_contract.py @@ -31,6 +31,9 @@ validate_persisted_calibration_package_contract, write_calibration_package_contract, ) +from policyengine_us_data.calibration_package.payload import ( + CalibrationPackagePayloadError, +) from policyengine_us_data.utils.geography_checksum import ( canonical_geography_checksum, hash_string_array, @@ -656,7 +659,8 @@ def test_load_calibration_package_payload_rejects_non_mapping(tmp_path): try: load_calibration_package_payload(package_path) - except ValueError as exc: + except CalibrationPackagePayloadError as exc: assert "must contain a mapping" in str(exc) + assert exc.validation_report.status == "fail" else: raise AssertionError("Non-mapping package payload should fail") diff --git a/tests/unit/test_pipeline_docs_extractor.py b/tests/unit/test_pipeline_docs_extractor.py index 79b39fa0d..239456c93 100644 --- a/tests/unit/test_pipeline_docs_extractor.py +++ b/tests/unit/test_pipeline_docs_extractor.py @@ -136,6 +136,9 @@ def test_pipeline_map_manifest_validates(): "stage2_target_config_load", "build_matrix", "build_matrix_chunked", + "stage2_payload_boundary", + "stage2_payload_reader", + "stage2_payload_writer", "stage2_calibration_package_writer", "stage2_calibration_package_contract_writer", "stage2_calibration_package_contract_validator", diff --git a/tests/unit/test_remote_calibration_runner.py b/tests/unit/test_remote_calibration_runner.py index 9b280cad9..4b7d9da23 100644 --- a/tests/unit/test_remote_calibration_runner.py +++ b/tests/unit/test_remote_calibration_runner.py @@ -1,3 +1,4 @@ +import json import inspect import importlib import sys @@ -293,3 +294,32 @@ def fake_run_streaming(cmd, env=None, label=""): ensure_prereqs.assert_called_once() volume.reload.assert_called_once() volume.commit.assert_called_once() + + +def test_write_package_sidecar_reads_payload_and_contract(tmp_path): + remote_runner = _load_remote_calibration_runner_module() + from tests.unit.fixtures.calibration_package_stage_contract import ( + calibration_package_contract, + ) + + from policyengine_us_data.calibration_package.specs import ( + CALIBRATION_PACKAGE_CONTRACT_FILENAME, + CALIBRATION_PACKAGE_METADATA_FILENAME, + ) + from policyengine_us_data.stage_contracts.io import write_contract + + package_path = tmp_path / "calibration_package.pkl" + contract = calibration_package_contract(tmp_path) + write_contract(contract, tmp_path / CALIBRATION_PACKAGE_CONTRACT_FILENAME) + + assert remote_runner._write_package_sidecar(str(package_path)) is True + + sidecar = json.loads( + (tmp_path / CALIBRATION_PACKAGE_METADATA_FILENAME).read_text( + encoding="utf-8", + ) + ) + assert sidecar["package_sha256"].startswith("sha256:") + assert sidecar["package_summary"]["matrix_shape"] == [2, 3] + assert sidecar["geography_assignment"]["source_kind"] == "calibration_package" + assert sidecar["contract"]["stage_id"] == "2_build_calibration_package"