From 7351ec1c697b673931dbdfabe9f04528a7a34408 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 22 May 2026 23:44:14 +0200 Subject: [PATCH] Add scoped fitted weights contracts --- changelog.d/1118.added | 1 + docs/engineering/pipeline-map.md | 10 + docs/engineering/stages/fit_weights.md | 14 + docs/generated/pipeline_api.json | 8 +- docs/generated/pipeline_map.json | 62 +++- docs/pipeline_map.yaml | 38 ++ modal_app/pipeline.py | 39 ++- policyengine_us_data/fit_weights/__init__.py | 10 + policyengine_us_data/fit_weights/contracts.py | 331 ++++++++++++++++++ tests/unit/fit_weights/conftest.py | 65 ++++ tests/unit/fit_weights/test_contracts.py | 164 +++++++++ tests/unit/test_pipeline_source_contracts.py | 9 +- 12 files changed, 738 insertions(+), 13 deletions(-) create mode 100644 changelog.d/1118.added create mode 100644 policyengine_us_data/fit_weights/contracts.py create mode 100644 tests/unit/fit_weights/test_contracts.py diff --git a/changelog.d/1118.added b/changelog.d/1118.added new file mode 100644 index 000000000..b2f1de2a1 --- /dev/null +++ b/changelog.d/1118.added @@ -0,0 +1 @@ +Add scoped Stage 3 fitted-weight contract artifacts for regional and national fits. diff --git a/docs/engineering/pipeline-map.md b/docs/engineering/pipeline-map.md index b1f630718..92d86c4ab 100644 --- a/docs/engineering/pipeline-map.md +++ b/docs/engineering/pipeline-map.md @@ -460,12 +460,14 @@ Fit regional log-weights using L0 HardConcrete gates on GPU | `modal_gpu` Modal GPU Container | `external` | `unknown` | `unknown` | | | `fit_spec_regional` FittedWeightsSpec regional | `library` | `unknown` | `unknown` | | | `fit_artifacts_regional` ScopedFitArtifacts regional | `library` | `unknown` | `unknown` | | +| `fit_contract_builder_regional` FittedWeightsContractBuilder regional | `library` | `unknown` | `unknown` | | | `create_model` Create SparseCalibrationWeights | `process` | `unknown` | `unknown` | | | `extract_weights` Extract Weights | `process` | `unknown` | `unknown` | | | `out_weights` calibration_weights.npy | `artifact` | `unknown` | `unknown` | | | `out_geo_s6` geography_assignment.npz | `artifact` | `unknown` | `unknown` | | | `out_diag` unified_diagnostics.csv | `artifact` | `unknown` | `unknown` | | | `out_config_s6` unified_run_config.json | `artifact` | `unknown` | `unknown` | | +| `out_fit_contract_regional` fitted_weights_regional_contract.json | `artifact` | `unknown` | `unknown` | | | `util_l0` l0-python | `utility` | `unknown` | `unknown` | | | `util_pytorch` PyTorch | `utility` | `unknown` | `unknown` | | | `init_weights` Compute Initial Weights | `library` | `current` | `moving` | `policyengine_us_data.calibration.unified_calibration.compute_initial_weights` | @@ -479,6 +481,9 @@ Fit regional log-weights using L0 HardConcrete gates on GPU - `fit_artifacts_regional` -> `out_geo_s6` `documents` - `fit_artifacts_regional` -> `out_diag` `documents` - `fit_artifacts_regional` -> `out_config_s6` `documents` +- `fit_artifacts_regional` -> `out_fit_contract_regional` `documents` +- `fit_model` -> `fit_contract_builder_regional` `data_flow` +- `fit_contract_builder_regional` -> `out_fit_contract_regional` `produces_artifact` - `init_weights` -> `create_model` `data_flow` - `create_model` -> `fit_model` `data_flow` - `modal_gpu` -> `fit_model` `runs_on_infra` (runs on) @@ -507,12 +512,14 @@ Fit national log-weights for the national H5 output using the same L0 calibratio | `modal_gpu_national` Modal GPU Container | `external` | `unknown` | `unknown` | | | `fit_spec_national` FittedWeightsSpec national | `library` | `unknown` | `unknown` | | | `fit_artifacts_national` ScopedFitArtifacts national | `library` | `unknown` | `unknown` | | +| `fit_contract_builder_national` FittedWeightsContractBuilder national | `library` | `unknown` | `unknown` | | | `create_model_national` Create National SparseCalibrationWeights | `process` | `unknown` | `unknown` | | | `extract_national_weights` Extract National Weights | `process` | `unknown` | `unknown` | | | `out_national_weights` national_calibration_weights.npy | `artifact` | `unknown` | `unknown` | | | `out_national_geo_s6` national_geography_assignment.npz | `artifact` | `unknown` | `unknown` | | | `out_national_diag` national_unified_diagnostics.csv | `artifact` | `unknown` | `unknown` | | | `out_national_config_s6` national_unified_run_config.json | `artifact` | `unknown` | `unknown` | | +| `out_fit_contract_national` fitted_weights_national_contract.json | `artifact` | `unknown` | `unknown` | | | `util_l0_national` l0-python | `utility` | `unknown` | `unknown` | | | `util_pytorch_national` PyTorch | `utility` | `unknown` | `unknown` | | | `init_weights` Compute Initial Weights | `library` | `current` | `moving` | `policyengine_us_data.calibration.unified_calibration.compute_initial_weights` | @@ -526,6 +533,9 @@ Fit national log-weights for the national H5 output using the same L0 calibratio - `fit_artifacts_national` -> `out_national_geo_s6` `documents` - `fit_artifacts_national` -> `out_national_diag` `documents` - `fit_artifacts_national` -> `out_national_config_s6` `documents` +- `fit_artifacts_national` -> `out_fit_contract_national` `documents` +- `fit_model` -> `fit_contract_builder_national` `data_flow` +- `fit_contract_builder_national` -> `out_fit_contract_national` `produces_artifact` - `init_weights` -> `create_model_national` `data_flow` - `create_model_national` -> `fit_model` `data_flow` - `modal_gpu_national` -> `fit_model` `runs_on_infra` (runs on) diff --git a/docs/engineering/stages/fit_weights.md b/docs/engineering/stages/fit_weights.md index 67d48b417..d71687934 100644 --- a/docs/engineering/stages/fit_weights.md +++ b/docs/engineering/stages/fit_weights.md @@ -21,6 +21,20 @@ step parameters. Manual legacy package runs may proceed without the contract only through the explicit no-contract fallback, which emits a warning and records that only the package checksum was available. +Each successful scoped fit writes a semantic Stage 3 handoff contract next to +the primary fitted-weight artifacts: + +- regional: `fitted_weights_regional_contract.json`; +- national: `fitted_weights_national_contract.json`. + +These contracts use the canonical `fitted_weights` stage-contract type, include +the matching Stage 2 package and contract inputs, list the scoped weights, +geography, run config, legacy diagnostics, and epoch log outputs, and embed the +solver parameters plus package, geography, weights, and diagnostics summaries. +The fit step manifests record these contract JSON files as normal outputs so +Stage 4 can validate a scoped semantic handoff without relying only on filename +conventions. + The current artifact names remain behavior-compatible: - regional: `calibration_weights.npy`, `geography_assignment.npz`, diff --git a/docs/generated/pipeline_api.json b/docs/generated/pipeline_api.json index e3fcb1abc..4f76da921 100644 --- a/docs/generated/pipeline_api.json +++ b/docs/generated/pipeline_api.json @@ -1164,7 +1164,7 @@ "docstring": "Scoped output bundle created before Stage 3 bytes become files.", "id": "fitted_weights_output_bundle", "kind": "class", - "line": 113, + "line": 302, "metadata": { "api_refs": [ "policyengine_us_data.fit_weights.bundles.FittedWeightsOutputBundle" @@ -3086,7 +3086,7 @@ "docstring": "Promote a completed pipeline run to production.\n\n1. Verify run status is \"completed\"\n2. Promote every staged artifact in one Hugging Face commit\n3. Upload/copy every artifact to GCS\n4. Finalize release_manifest.json, tag the release, and update\n version_manifest.json\n5. Update run status to \"promoted\"\n\nArgs:\n run_id: The run ID to promote.\n candidate_version: Candidate staging scope used for staged source files.\n release_version: Stable version used for final release metadata.\n\nReturns:\n Summary message.", "id": "promote_pipeline_run", "kind": "function", - "line": 2091, + "line": 2133, "metadata": { "api_refs": [ "modal_app.pipeline.promote_run" @@ -3541,7 +3541,7 @@ "docstring": "Run the full pipeline end-to-end.\n\nArgs:\n branch: Git branch to build from.\n gpu: GPU type for regional calibration.\n epochs: Training epochs for regional calibration.\n national_gpu: GPU type for national calibration.\n national_epochs: Training epochs for national.\n num_workers: Number of parallel H5 workers.\n n_clones: Number of clones for H5 building.\n skip_national: Skip national calibration/H5.\n resume_run_id: Resume a previously failed run.\n clear_checkpoints: Wipe ALL checkpoints before building\n (default False). Normally not needed \u2014 checkpoints are\n scoped by commit SHA, so stale ones from other commits\n are cleaned automatically. Use True only to force a\n full rebuild of the current commit.\n candidate_version: Candidate staging scope used for HF staging.\n release_version: Final stable release version. Usually empty until\n promotion.\n base_release_version: Stable release current when this candidate was\n built.\n release_bump: Intended SemVer bump for this candidate.\n sha_override: Exact source SHA deployed by GitHub Actions. When\n provided, this is recorded instead of reading the current\n branch tip.\n run_id: Cross-system run ID created by GitHub.\n run_context: Serialized run context from the launcher workflow.\n modal_app_name: Deployed Modal app name for this run.\n modal_environment: Modal environment used for this run.\n chunked_matrix: Build the calibration matrix in clone-household\n chunks instead of the non-chunked path. Opt-in; default off.\n chunk_size: Clone-household columns per chunk when\n ``chunked_matrix`` is True.\n parallel_matrix: Fan chunked matrix building across Modal\n workers via ``build_matrix_chunk_worker``. Only meaningful\n when ``chunked_matrix`` is True; ignored otherwise.\n num_matrix_workers: Number of Modal workers when\n ``parallel_matrix`` is True.\n\nReturns:\n The run ID for use with promote.", "id": "run_modal_pipeline", "kind": "function", - "line": 1113, + "line": 1115, "metadata": { "api_refs": [ "modal_app.pipeline.run_pipeline" @@ -4479,7 +4479,7 @@ "docstring": "Verify deployed-image imports and subprocess seams.", "id": "verify_runtime_seams", "kind": "function", - "line": 739, + "line": 741, "metadata": { "api_refs": [ "modal_app.pipeline.verify_runtime_seams" diff --git a/docs/generated/pipeline_map.json b/docs/generated/pipeline_map.json index ea6f4fb0e..1ccadf427 100644 --- a/docs/generated/pipeline_map.json +++ b/docs/generated/pipeline_map.json @@ -4486,6 +4486,21 @@ "source": "fit_artifacts_regional", "target": "out_config_s6" }, + { + "edge_type": "documents", + "source": "fit_artifacts_regional", + "target": "out_fit_contract_regional" + }, + { + "edge_type": "data_flow", + "source": "fit_model", + "target": "fit_contract_builder_regional" + }, + { + "edge_type": "produces_artifact", + "source": "fit_contract_builder_regional", + "target": "out_fit_contract_regional" + }, { "edge_type": "data_flow", "source": "init_weights", @@ -4546,6 +4561,7 @@ "node_ids": [ "fit_spec_regional", "fit_artifacts_regional", + "fit_contract_builder_regional", "init_weights", "create_model", "fit_model", @@ -4553,7 +4569,8 @@ "out_weights", "out_geo_s6", "out_diag", - "out_config_s6" + "out_config_s6", + "out_fit_contract_regional" ] } ], @@ -4588,6 +4605,12 @@ "label": "ScopedFitArtifacts regional", "node_type": "library" }, + { + "description": "Builds the regional fitted_weights stage contract from Stage 2 package identity, fit parameters, artifacts, and diagnostics", + "id": "fit_contract_builder_regional", + "label": "FittedWeightsContractBuilder regional", + "node_type": "library" + }, { "description": "n_features = 5.16M, init_keep_prob = 0.999", "id": "create_model", @@ -4624,6 +4647,12 @@ "label": "unified_run_config.json", "node_type": "artifact" }, + { + "description": "Scoped Stage 3 contract for regional fitted weights, geography, run config, and legacy diagnostics", + "id": "out_fit_contract_regional", + "label": "fitted_weights_regional_contract.json", + "node_type": "artifact" + }, { "description": "SparseCalibrationWeights - HardConcrete gates", "id": "util_l0", @@ -4722,6 +4751,21 @@ "source": "fit_artifacts_national", "target": "out_national_config_s6" }, + { + "edge_type": "documents", + "source": "fit_artifacts_national", + "target": "out_fit_contract_national" + }, + { + "edge_type": "data_flow", + "source": "fit_model", + "target": "fit_contract_builder_national" + }, + { + "edge_type": "produces_artifact", + "source": "fit_contract_builder_national", + "target": "out_fit_contract_national" + }, { "edge_type": "data_flow", "source": "init_weights", @@ -4782,6 +4826,7 @@ "node_ids": [ "fit_spec_national", "fit_artifacts_national", + "fit_contract_builder_national", "init_weights", "create_model_national", "fit_model", @@ -4789,7 +4834,8 @@ "out_national_weights", "out_national_geo_s6", "out_national_diag", - "out_national_config_s6" + "out_national_config_s6", + "out_fit_contract_national" ] } ], @@ -4824,6 +4870,12 @@ "label": "ScopedFitArtifacts national", "node_type": "library" }, + { + "description": "Builds the national fitted_weights stage contract from Stage 2 package identity, fit parameters, artifacts, and diagnostics", + "id": "fit_contract_builder_national", + "label": "FittedWeightsContractBuilder national", + "node_type": "library" + }, { "description": "National target scope with L0 HardConcrete gates", "id": "create_model_national", @@ -4860,6 +4912,12 @@ "label": "national_unified_run_config.json", "node_type": "artifact" }, + { + "description": "Scoped Stage 3 contract for national fitted weights, geography, run config, and legacy diagnostics", + "id": "out_fit_contract_national", + "label": "fitted_weights_national_contract.json", + "node_type": "artifact" + }, { "description": "SparseCalibrationWeights - HardConcrete gates", "id": "util_l0_national", diff --git a/docs/pipeline_map.yaml b/docs/pipeline_map.yaml index 02d5581c5..ae76e3559 100644 --- a/docs/pipeline_map.yaml +++ b/docs/pipeline_map.yaml @@ -1047,6 +1047,7 @@ stages: node_ids: - fit_spec_regional - fit_artifacts_regional + - fit_contract_builder_regional - init_weights - create_model - fit_model @@ -1055,6 +1056,7 @@ stages: - out_geo_s6 - out_diag - out_config_s6 + - out_fit_contract_regional extra_nodes: - id: in_pkg_s6 label: calibration_package.pkl @@ -1072,6 +1074,10 @@ stages: label: ScopedFitArtifacts regional node_type: library description: Regional fitted-weight artifact filenames and remote result mapping + - id: fit_contract_builder_regional + label: FittedWeightsContractBuilder regional + node_type: library + description: Builds the regional fitted_weights stage contract from Stage 2 package identity, fit parameters, artifacts, and diagnostics - id: create_model label: Create SparseCalibrationWeights node_type: process @@ -1096,6 +1102,10 @@ stages: label: unified_run_config.json node_type: artifact description: Hyperparameters + SHA256 checksums + - id: out_fit_contract_regional + label: fitted_weights_regional_contract.json + node_type: artifact + description: Scoped Stage 3 contract for regional fitted weights, geography, run config, and legacy diagnostics - id: util_l0 label: l0-python node_type: utility @@ -1123,6 +1133,15 @@ stages: - source: fit_artifacts_regional target: out_config_s6 edge_type: documents + - source: fit_artifacts_regional + target: out_fit_contract_regional + edge_type: documents + - source: fit_model + target: fit_contract_builder_regional + edge_type: data_flow + - source: fit_contract_builder_regional + target: out_fit_contract_regional + edge_type: produces_artifact - source: init_weights target: create_model edge_type: data_flow @@ -1172,6 +1191,7 @@ stages: node_ids: - fit_spec_national - fit_artifacts_national + - fit_contract_builder_national - init_weights - create_model_national - fit_model @@ -1180,6 +1200,7 @@ stages: - out_national_geo_s6 - out_national_diag - out_national_config_s6 + - out_fit_contract_national extra_nodes: - id: in_pkg_national_s6 label: calibration_package.pkl @@ -1197,6 +1218,10 @@ stages: label: ScopedFitArtifacts national node_type: library description: National fitted-weight artifact filenames and remote result mapping + - id: fit_contract_builder_national + label: FittedWeightsContractBuilder national + node_type: library + description: Builds the national fitted_weights stage contract from Stage 2 package identity, fit parameters, artifacts, and diagnostics - id: create_model_national label: Create National SparseCalibrationWeights node_type: process @@ -1221,6 +1246,10 @@ stages: label: national_unified_run_config.json node_type: artifact description: National hyperparameters + SHA256 checksums + - id: out_fit_contract_national + label: fitted_weights_national_contract.json + node_type: artifact + description: Scoped Stage 3 contract for national fitted weights, geography, run config, and legacy diagnostics - id: util_l0_national label: l0-python node_type: utility @@ -1248,6 +1277,15 @@ stages: - source: fit_artifacts_national target: out_national_config_s6 edge_type: documents + - source: fit_artifacts_national + target: out_fit_contract_national + edge_type: documents + - source: fit_model + target: fit_contract_builder_national + edge_type: data_flow + - source: fit_contract_builder_national + target: out_fit_contract_national + edge_type: produces_artifact - source: init_weights target: create_model_national edge_type: data_flow diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py index a1e128e97..8a0d2752e 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -116,7 +116,9 @@ from policyengine_us_data.pipeline_metadata import pipeline_node # noqa: E402 from policyengine_us_data.pipeline_schema import PipelineNode # noqa: E402 from policyengine_us_data.fit_weights import ( # noqa: E402 + FITTED_WEIGHTS_CONTRACT_SCHEMA_VERSION, FitScope, + FittedWeightsContractBuilder, FittedWeightsInputBundle, FittedWeightsOutputBundle, NATIONAL_FIT_LAMBDA_L0 as _NATIONAL_FIT_LAMBDA_L0, @@ -1534,7 +1536,12 @@ def run_pipeline( scope=FitScope.REGIONAL, calibration_package_path=_artifacts_dir(run_id) / "calibration_package.pkl", ) - fit_stage2_identity = regional_fit_input.stage2_identity_parameters() + fit_stage2_identity = { + **regional_fit_input.stage2_identity_parameters(), + "fitted_weights_contract_schema_version": ( + FITTED_WEIGHTS_CONTRACT_SCHEMA_VERSION + ), + } fit_inputs = _artifact_identities(regional_fit_input.artifact_identity_paths()) regional_fit_spec = fitted_weights_spec_for_scope(FitScope.REGIONAL) national_fit_spec = fitted_weights_spec_for_scope(FitScope.NATIONAL) @@ -1666,8 +1673,21 @@ def run_pipeline( pipeline_volume, scope=regional_output.scope, ) + regional_contract_path = FittedWeightsContractBuilder( + scope=regional_output.scope, + input_bundle=regional_fit_input, + parameters=regional_fit_parameters, + artifacts_root=_artifacts_dir(run_id), + diagnostics_root=Path(RUNS_DIR) / run_id / "diagnostics", + run_id=run_id, + started_at=regional_fit_manifest.started_at, + modal_call_id=regional_handle.object_id, + ).write() regional_outputs = collect_artifacts( - regional_output.artifact_paths(_artifacts_dir(run_id)), + [ + *regional_output.artifact_paths(_artifacts_dir(run_id)), + regional_contract_path, + ], missing_ok=True, ) regional_fit_reuse_measurement = ReuseMeasurement( @@ -1704,8 +1724,21 @@ def run_pipeline( pipeline_volume, scope=national_output.scope, ) + national_contract_path = FittedWeightsContractBuilder( + scope=national_output.scope, + input_bundle=regional_fit_input, + parameters=national_fit_parameters, + artifacts_root=_artifacts_dir(run_id), + diagnostics_root=Path(RUNS_DIR) / run_id / "diagnostics", + run_id=run_id, + started_at=national_fit_manifest.started_at, + modal_call_id=national_handle.object_id, + ).write() national_outputs = collect_artifacts( - national_output.artifact_paths(_artifacts_dir(run_id)), + [ + *national_output.artifact_paths(_artifacts_dir(run_id)), + national_contract_path, + ], missing_ok=True, ) _complete_step_manifest( diff --git a/policyengine_us_data/fit_weights/__init__.py b/policyengine_us_data/fit_weights/__init__.py index 5524fca27..9ed4a1864 100644 --- a/policyengine_us_data/fit_weights/__init__.py +++ b/policyengine_us_data/fit_weights/__init__.py @@ -16,6 +16,12 @@ FittedWeightsOutputBundle, MissingFitWeightsOutputError, ) +from policyengine_us_data.fit_weights.contracts import ( + FITTED_WEIGHTS_CONTRACT_SCHEMA_VERSION, + FittedWeightsContractBuilder, + fitted_weights_contract_filename, + fitted_weights_contract_path, +) from policyengine_us_data.fit_weights.specs import ( FIT_BETA, FIT_LOG_FREQ, @@ -35,6 +41,7 @@ "FIT_BETA", "FIT_LOG_FREQ", "FIT_TARGET_CONFIG_PATH", + "FITTED_WEIGHTS_CONTRACT_SCHEMA_VERSION", "FIT_WEIGHTS_SPEC_SCHEMA_VERSION", "NATIONAL_FIT_LAMBDA_L0", "NATIONAL_FIT_LAMBDA_L2", @@ -51,9 +58,12 @@ "FittedWeightsInputBundle", "FittedWeightsInputIdentity", "FittedWeightsOutputBundle", + "FittedWeightsContractBuilder", "FittedWeightsSpec", "MissingFitWeightsOutputError", "ScopedFitArtifacts", "fit_artifacts_for_scope", + "fitted_weights_contract_filename", + "fitted_weights_contract_path", "fitted_weights_spec_for_scope", ] diff --git a/policyengine_us_data/fit_weights/contracts.py b/policyengine_us_data/fit_weights/contracts.py new file mode 100644 index 000000000..f06db0841 --- /dev/null +++ b/policyengine_us_data/fit_weights/contracts.py @@ -0,0 +1,331 @@ +"""Scoped Stage 3 fitted-weight contract builders.""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import numpy as np + +from policyengine_us_data.fit_weights.artifacts import ( + FitArtifactRole, + fit_artifacts_for_scope, +) +from policyengine_us_data.fit_weights.bundles import FittedWeightsInputBundle +from policyengine_us_data.fit_weights.specs import FitScope +from policyengine_us_data.stage_contracts import ArtifactRef, StageContract +from policyengine_us_data.stage_contracts.execution import ( + ExecutionRecord, + ReuseSummary, +) +from policyengine_us_data.stage_contracts.fingerprints import fingerprint_material +from policyengine_us_data.stage_contracts.io import write_contract +from policyengine_us_data.stage_contracts.stages import ( + STAGE_3_FIT_WEIGHTS, + contract_type_for_stage, +) +from policyengine_us_data.stage_contracts.substages import SubstageRecord +from policyengine_us_data.utils.step_manifest import sha256_file, utc_now + +FITTED_WEIGHTS_CONTRACT_SCHEMA_VERSION = "1" +FITTED_WEIGHTS_CONTRACT_TYPE = contract_type_for_stage(STAGE_3_FIT_WEIGHTS) +FITTED_WEIGHTS_CONTRACT_FILENAMES = { + FitScope.REGIONAL: "fitted_weights_regional_contract.json", + FitScope.NATIONAL: "fitted_weights_national_contract.json", +} +FITTED_WEIGHTS_SUBSTAGE_IDS = { + FitScope.REGIONAL: "3a_weight_fitting_regional", + FitScope.NATIONAL: "3b_weight_fitting_national", +} + + +def fitted_weights_contract_filename(scope: FitScope | str) -> str: + """Return the scoped Stage 3 contract filename.""" + + return FITTED_WEIGHTS_CONTRACT_FILENAMES[FitScope.parse(scope)] + + +def fitted_weights_contract_path( + *, + scope: FitScope | str, + artifacts_root: str | Path, +) -> Path: + """Return the scoped Stage 3 contract path under an artifacts root.""" + + return Path(artifacts_root) / fitted_weights_contract_filename(scope) + + +@dataclass(frozen=True) +class FittedWeightsContractBuilder: + """Build a semantic contract for one scoped Stage 3 fit.""" + + scope: FitScope | str + input_bundle: FittedWeightsInputBundle + parameters: Mapping[str, Any] + artifacts_root: Path + diagnostics_root: Path + run_id: str | None = None + started_at: str | None = None + completed_at: str | None = None + duration_s: float | None = None + modal_call_id: str | None = None + code_sha: str | None = None + package_version: str | None = None + target_metadata_paths: Mapping[str, Path] = field(default_factory=dict) + + def __post_init__(self) -> None: + object.__setattr__(self, "scope", FitScope.parse(self.scope)) + object.__setattr__(self, "artifacts_root", Path(self.artifacts_root)) + object.__setattr__(self, "diagnostics_root", Path(self.diagnostics_root)) + + @property + def contract_path(self) -> Path: + """Return the default contract file path for this scoped fit.""" + + return fitted_weights_contract_path( + scope=self.scope, + artifacts_root=self.artifacts_root, + ) + + def build(self) -> StageContract: + """Build the Stage 3 contract from existing fit artifacts.""" + + inputs = tuple(self._input_artifacts()) + outputs = tuple(self._output_artifacts()) + metadata = self._metadata() + fingerprint = fingerprint_material( + { + "stage_id": STAGE_3_FIT_WEIGHTS, + "contract_type": FITTED_WEIGHTS_CONTRACT_TYPE, + "schema_version": FITTED_WEIGHTS_CONTRACT_SCHEMA_VERSION, + "scope": self.scope.value, + "inputs": inputs, + "outputs": outputs, + "parameters": dict(self.parameters), + "metadata": metadata, + } + ) + execution = ExecutionRecord( + status="completed", + started_at=self.started_at, + completed_at=self.completed_at or utc_now(), + duration_s=self.duration_s, + modal_call_id=self.modal_call_id, + reuse_decision="computed", + reuse_summary=ReuseSummary( + expected_outputs=len(outputs), + recomputed_outputs=len(outputs), + ), + ) + substage = SubstageRecord( + substage_id=FITTED_WEIGHTS_SUBSTAGE_IDS[self.scope], + status="completed", + inputs=inputs, + outputs=outputs, + parameters=dict(self.parameters), + fingerprint=fingerprint, + reuse_mode="handoff", + metadata={"scope": self.scope.value}, + ) + return StageContract( + contract_type=FITTED_WEIGHTS_CONTRACT_TYPE, + stage_id=STAGE_3_FIT_WEIGHTS, + run_id=self.run_id, + created_at=execution.completed_at or utc_now(), + code_sha=self.code_sha, + package_version=self.package_version, + inputs=inputs, + outputs=outputs, + parameters=dict(self.parameters), + fingerprint=fingerprint, + substages=(substage,), + execution=execution, + metadata=metadata, + ) + + def write(self, path: str | Path | None = None) -> Path: + """Write the scoped Stage 3 contract and return its path.""" + + contract_path = Path(path) if path is not None else self.contract_path + write_contract(self.build(), contract_path) + return contract_path + + def _input_artifacts(self) -> list[ArtifactRef]: + artifacts = [ + _artifact_ref( + logical_name="calibration_package", + path=self.input_bundle.calibration_package_path, + artifact_family="calibration_package", + role="input", + scope=self.scope.value, + ) + ] + contract_path = self.input_bundle.calibration_package_contract_path + if contract_path is not None and Path(contract_path).exists(): + artifacts.append( + _artifact_ref( + logical_name="calibration_package_contract", + path=contract_path, + artifact_family="stage_contract", + role="input", + scope=self.scope.value, + ) + ) + for logical_name, path in sorted(self.target_metadata_paths.items()): + if Path(path).exists(): + artifacts.append( + _artifact_ref( + logical_name=logical_name, + path=path, + artifact_family="calibration_target_metadata", + role="input", + scope=self.scope.value, + ) + ) + return artifacts + + def _output_artifacts(self) -> list[ArtifactRef]: + scoped_artifacts = fit_artifacts_for_scope(self.scope) + artifacts: list[ArtifactRef] = [] + for spec in scoped_artifacts.artifact_specs(): + artifacts.append( + _artifact_ref( + logical_name=_logical_output_name(self.scope, spec.role), + path=spec.path_under(self.artifacts_root), + artifact_family="fitted_weights", + role=spec.role.value, + scope=self.scope.value, + location=spec.location.value, + ) + ) + for spec in scoped_artifacts.diagnostic_specs(): + if spec.role == FitArtifactRole.RUN_CONFIG: + continue + path = spec.path_under(self.diagnostics_root) + if not path.exists(): + continue + artifacts.append( + _artifact_ref( + logical_name=_logical_output_name(self.scope, spec.role), + path=path, + artifact_family="fitted_weights", + role=spec.role.value, + scope=self.scope.value, + location=spec.location.value, + ) + ) + return artifacts + + def _metadata(self) -> dict[str, Any]: + scoped_artifacts = fit_artifacts_for_scope(self.scope) + identity = self.input_bundle.stage2_identity() + weights_path = scoped_artifacts.weights.path_under(self.artifacts_root) + geography_path = scoped_artifacts.geography.path_under(self.artifacts_root) + diagnostics = {} + for spec in scoped_artifacts.diagnostic_specs(): + if spec.role == FitArtifactRole.RUN_CONFIG: + continue + path = spec.path_under(self.diagnostics_root) + if path.exists(): + diagnostics[spec.role.value] = _csv_summary(path) + return { + "schema_version": FITTED_WEIGHTS_CONTRACT_SCHEMA_VERSION, + "scope": self.scope.value, + "package_checksum": identity.calibration_package_sha256, + "package_contract_checksum": identity.calibration_package_contract_sha256, + "package_contract_fingerprint": ( + identity.calibration_package_contract_fingerprint + ), + "weight_summary": _npy_summary(weights_path), + "geography_checksum": f"sha256:{sha256_file(geography_path)}", + "geography_size_bytes": geography_path.stat().st_size, + "diagnostics_summary": diagnostics, + "target_metadata_available": any( + Path(path).exists() for path in self.target_metadata_paths.values() + ), + } + + +def _logical_output_name(scope: FitScope, role: FitArtifactRole) -> str: + return f"fitted_weights_{scope.value}_{role.value}" + + +def _artifact_ref( + *, + logical_name: str, + path: str | Path, + artifact_family: str, + role: str, + scope: str, + location: str | None = None, +) -> ArtifactRef: + artifact_path = Path(path) + metadata = { + "artifact_family": artifact_family, + "scope": scope, + "role": role, + } + if location is not None: + metadata["location"] = location + return ArtifactRef( + logical_name=logical_name, + uri=artifact_path.resolve().as_uri(), + sha256=f"sha256:{sha256_file(artifact_path)}", + size_bytes=artifact_path.stat().st_size, + media_type=_media_type_for_path(artifact_path), + metadata=metadata, + ) + + +def _media_type_for_path(path: Path) -> str: + suffix = path.suffix.lower() + if suffix == ".json": + return "application/json" + if suffix == ".csv": + return "text/csv" + if suffix == ".npy": + return "application/x-numpy" + if suffix == ".npz": + return "application/x-numpy-zip" + if suffix == ".pkl": + return "application/python-pickle" + return "application/octet-stream" + + +def _npy_summary(path: Path) -> dict[str, Any]: + array = np.load(path, mmap_mode="r") + summary: dict[str, Any] = { + "shape": list(array.shape), + "dtype": str(array.dtype), + "count": int(array.size), + "sha256": f"sha256:{sha256_file(path)}", + } + if array.size: + summary.update( + { + "min": float(np.min(array)), + "max": float(np.max(array)), + "sum": float(np.sum(array)), + } + ) + return summary + + +def _csv_summary(path: Path) -> dict[str, Any]: + with path.open(encoding="utf-8") as handle: + line_count = sum(1 for _ in handle) + return { + "sha256": f"sha256:{sha256_file(path)}", + "size_bytes": path.stat().st_size, + "row_count": max(line_count - 1, 0), + } + + +__all__ = [ + "FITTED_WEIGHTS_CONTRACT_SCHEMA_VERSION", + "FittedWeightsContractBuilder", + "fitted_weights_contract_filename", + "fitted_weights_contract_path", +] diff --git a/tests/unit/fit_weights/conftest.py b/tests/unit/fit_weights/conftest.py index e6c9fcc26..4900c9fd1 100644 --- a/tests/unit/fit_weights/conftest.py +++ b/tests/unit/fit_weights/conftest.py @@ -1,13 +1,16 @@ from collections.abc import Callable from dataclasses import dataclass +import json from pathlib import Path +import numpy as np import pytest import yaml from policyengine_us_data.fit_weights import ( FitScope, FittedWeightsOutputBundle, + fit_artifacts_for_scope, ) from policyengine_us_data.stage_contracts import StageContract from policyengine_us_data.stage_contracts.calibration_package import ( @@ -33,6 +36,13 @@ class Stage2ContractFixture: contract: StageContract +@dataclass(frozen=True) +class ScopedFitFiles: + scope: FitScope + artifacts_root: Path + diagnostics_root: Path + + class FakeBatch: def __init__(self) -> None: self.files: dict[str, bytes] = {} @@ -77,6 +87,61 @@ def stage2_contract_fixture(tmp_path: Path) -> Stage2ContractFixture: ) +@pytest.fixture +def fitted_weights_parameters() -> dict: + return { + "scope": "regional", + "gpu": "T4", + "epochs": 2, + "target_config": "policyengine_us_data/calibration/target_config.yaml", + "beta": 0.65, + "lambda_l0": 1e-7, + "lambda_l2": 1e-8, + "log_freq": 100, + "fit_parameter_identity": "sha256:" + "1" * 64, + "calibration_package_sha256": "sha256:" + "2" * 64, + "calibration_package_contract_sha256": "sha256:" + "3" * 64, + "fitted_weights_contract_schema_version": "1", + } + + +@pytest.fixture +def scoped_fit_files(tmp_path: Path) -> Callable[[FitScope | str], ScopedFitFiles]: + def write_files(scope: FitScope | str) -> ScopedFitFiles: + parsed_scope = FitScope.parse(scope) + artifacts_root = tmp_path / parsed_scope.value / "artifacts" + diagnostics_root = tmp_path / parsed_scope.value / "diagnostics" + artifacts_root.mkdir(parents=True) + diagnostics_root.mkdir(parents=True) + artifacts = fit_artifacts_for_scope(parsed_scope) + + np.save( + artifacts.weights.path_under(artifacts_root), + np.array([1.0, 2.5, 3.5]), + ) + np.savez( + artifacts.geography.path_under(artifacts_root), + block_geoid=np.array(["010010001", "010010002"]), + cd_geoid=np.array(["0101", "0102"]), + ) + artifacts.run_config.path_under(artifacts_root).write_text( + json.dumps({"scope": parsed_scope.value}) + "\n" + ) + artifacts.diagnostics.path_under(diagnostics_root).write_text( + "target_id,error\nincome_tax,0.1\nsnap,0.2\n" + ) + artifacts.epoch_log.path_under(diagnostics_root).write_text( + "epoch,loss\n0,1.0\n1,0.5\n" + ) + return ScopedFitFiles( + scope=parsed_scope, + artifacts_root=artifacts_root, + diagnostics_root=diagnostics_root, + ) + + return write_files + + @pytest.fixture def fake_batch() -> FakeBatch: return FakeBatch() diff --git a/tests/unit/fit_weights/test_contracts.py b/tests/unit/fit_weights/test_contracts.py new file mode 100644 index 000000000..c8691615a --- /dev/null +++ b/tests/unit/fit_weights/test_contracts.py @@ -0,0 +1,164 @@ +from pathlib import Path + +from policyengine_us_data.fit_weights import ( + FITTED_WEIGHTS_CONTRACT_SCHEMA_VERSION, + FitScope, + FittedWeightsContractBuilder, + FittedWeightsInputBundle, + fitted_weights_contract_filename, +) +from policyengine_us_data.stage_contracts import contract_from_json, contract_to_json +from policyengine_us_data.stage_contracts.stages import STAGE_3_FIT_WEIGHTS +from policyengine_us_data.utils.step_manifest import sha256_file + + +def test_regional_contract_shape( + stage2_contract_fixture, + scoped_fit_files, + fitted_weights_parameters: dict, +) -> None: + files = scoped_fit_files(FitScope.REGIONAL) + contract = _build_contract( + stage2_contract_fixture, + files, + fitted_weights_parameters, + ) + + assert contract.stage_id == STAGE_3_FIT_WEIGHTS + assert contract.contract_type == "fitted_weights" + assert contract.run_id == "run-a" + assert contract.parameters["scope"] == "regional" + assert contract.metadata["scope"] == "regional" + assert contract.metadata["schema_version"] == FITTED_WEIGHTS_CONTRACT_SCHEMA_VERSION + assert contract.metadata["weight_summary"]["shape"] == (3,) + assert contract.metadata["diagnostics_summary"]["diagnostics"]["row_count"] == 2 + assert contract.substages[0].substage_id == "3a_weight_fitting_regional" + assert {artifact.logical_name for artifact in contract.inputs} == { + "calibration_package", + "calibration_package_contract", + } + assert {artifact.logical_name for artifact in contract.outputs} == { + "fitted_weights_regional_weights", + "fitted_weights_regional_geography", + "fitted_weights_regional_run_config", + "fitted_weights_regional_diagnostics", + "fitted_weights_regional_epoch_log", + } + + +def test_national_contract_shape( + stage2_contract_fixture, + scoped_fit_files, + fitted_weights_parameters: dict, +) -> None: + files = scoped_fit_files(FitScope.NATIONAL) + params = { + **fitted_weights_parameters, + "scope": "national", + "lambda_l0": 1e-4, + } + + contract = _build_contract(stage2_contract_fixture, files, params) + + assert contract.parameters["scope"] == "national" + assert contract.metadata["scope"] == "national" + assert contract.substages[0].substage_id == "3b_weight_fitting_national" + assert {artifact.logical_name for artifact in contract.outputs} == { + "fitted_weights_national_weights", + "fitted_weights_national_geography", + "fitted_weights_national_run_config", + "fitted_weights_national_diagnostics", + "fitted_weights_national_epoch_log", + } + assert fitted_weights_contract_filename(FitScope.NATIONAL) == ( + "fitted_weights_national_contract.json" + ) + + +def test_contract_fingerprint_tracks_solver_parameters( + stage2_contract_fixture, + scoped_fit_files, + fitted_weights_parameters: dict, +) -> None: + files = scoped_fit_files(FitScope.REGIONAL) + first = _build_contract( + stage2_contract_fixture, + files, + fitted_weights_parameters, + ) + second = _build_contract( + stage2_contract_fixture, + files, + {**fitted_weights_parameters, "epochs": 3}, + ) + + assert first.fingerprint.value != second.fingerprint.value + + +def test_contract_references_stage_2_package_contract_checksum( + stage2_contract_fixture, + scoped_fit_files, + fitted_weights_parameters: dict, +) -> None: + files = scoped_fit_files(FitScope.REGIONAL) + contract = _build_contract( + stage2_contract_fixture, + files, + fitted_weights_parameters, + ) + + contract_input = next( + artifact + for artifact in contract.inputs + if artifact.logical_name == "calibration_package_contract" + ) + assert contract_input.sha256 == ( + f"sha256:{sha256_file(stage2_contract_fixture.contract_path)}" + ) + assert contract.metadata["package_contract_checksum"] == contract_input.sha256 + + +def test_contract_round_trips_through_generic_stage_contract( + tmp_path: Path, + stage2_contract_fixture, + scoped_fit_files, + fitted_weights_parameters: dict, +) -> None: + files = scoped_fit_files(FitScope.REGIONAL) + builder = _builder(stage2_contract_fixture, files, fitted_weights_parameters) + contract_path = builder.write(tmp_path / "fitted_weights_regional_contract.json") + + contract = contract_from_json(contract_path.read_text()) + + assert contract == contract_from_json(contract_to_json(contract)) + assert contract.fingerprint.value.startswith("sha256:") + + +def _builder( + stage2_contract_fixture, + files, + parameters: dict, +) -> FittedWeightsContractBuilder: + return FittedWeightsContractBuilder( + scope=files.scope, + input_bundle=FittedWeightsInputBundle( + scope=files.scope, + calibration_package_path=stage2_contract_fixture.package_path, + calibration_package_contract_path=stage2_contract_fixture.contract_path, + ), + parameters=parameters, + artifacts_root=files.artifacts_root, + diagnostics_root=files.diagnostics_root, + run_id="run-a", + started_at="2026-05-08T12:00:00+00:00", + completed_at="2026-05-08T12:01:00+00:00", + modal_call_id="fc-123", + ) + + +def _build_contract( + stage2_contract_fixture, + files, + parameters: dict, +): + return _builder(stage2_contract_fixture, files, parameters).build() diff --git a/tests/unit/test_pipeline_source_contracts.py b/tests/unit/test_pipeline_source_contracts.py index 2d672ed74..9d9bba6a2 100644 --- a/tests/unit/test_pipeline_source_contracts.py +++ b/tests/unit/test_pipeline_source_contracts.py @@ -187,19 +187,20 @@ def test_run_pipeline_uses_stage_3_fit_specs_for_reuse_and_paths() -> None: assert "fitted_weights_spec_for_scope(FitScope.REGIONAL)" in source assert "fitted_weights_spec_for_scope(FitScope.NATIONAL)" in source + assert "FITTED_WEIGHTS_CONTRACT_SCHEMA_VERSION" in source + assert "FittedWeightsContractBuilder(" in source assert "fit_artifacts_for_scope(FitScope.REGIONAL)" in source assert "fit_artifacts_for_scope(FitScope.NATIONAL)" in source assert "regional_fit_spec.manifest_parameters(" in source assert "national_fit_spec.manifest_parameters(" in source - assert ( - "fit_stage2_identity = regional_fit_input.stage2_identity_parameters()" - in source - ) + assert "regional_fit_input.stage2_identity_parameters()" in source assert "regional_fit_spec.runtime_kwargs()" in source assert "national_fit_spec.runtime_kwargs()" in source assert "volume_package_contract_path=vol_contract_path" in source assert "fit_scope=FitScope.REGIONAL.value" in source assert "fit_scope=FitScope.NATIONAL.value" in source + assert "regional_contract_path" in source + assert "national_contract_path" in source assert "regional_output.artifact_paths(_artifacts_dir(run_id))" in source assert "national_output.artifact_paths(_artifacts_dir(run_id))" in source assert "diagnostic_result_filenames()" in archive_source