diff --git a/changelog.d/1116.changed b/changelog.d/1116.changed new file mode 100644 index 000000000..25a191ac2 --- /dev/null +++ b/changelog.d/1116.changed @@ -0,0 +1 @@ +Require Stage 3 fitted-weight runs to verify the Stage 2 calibration package contract before fitting. diff --git a/docs/engineering/stages/fit_weights.md b/docs/engineering/stages/fit_weights.md index 21424b455..67d48b417 100644 --- a/docs/engineering/stages/fit_weights.md +++ b/docs/engineering/stages/fit_weights.md @@ -12,6 +12,15 @@ builds. The public identity boundary lives in `policyengine_us_data.fit_weights` `FittedWeightsOutputBundle` keep Stage 3 package inputs and remote result bytes typed before they become files. +Normal pipeline runs must fit from a Stage 2 package that has a matching +`calibration_package_contract.json` sidecar. `FittedWeightsInputBundle` reads +that contract before GPU fitting starts, checks the contract-declared +`calibration_package.pkl` checksum and size against the package on the pipeline +volume, and records both the package checksum and contract checksum in the fit +step parameters. Manual legacy package runs may proceed without the contract +only through the explicit no-contract fallback, which emits a warning and +records that only the package checksum was available. 
+ The current artifact names remain behavior-compatible: - regional: `calibration_weights.npy`, `geography_assignment.npz`, diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py index 266892ac4..a1e128e97 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -1534,6 +1534,7 @@ def run_pipeline( scope=FitScope.REGIONAL, calibration_package_path=_artifacts_dir(run_id) / "calibration_package.pkl", ) + fit_stage2_identity = regional_fit_input.stage2_identity_parameters() fit_inputs = _artifact_identities(regional_fit_input.artifact_identity_paths()) regional_fit_spec = fitted_weights_spec_for_scope(FitScope.REGIONAL) national_fit_spec = fitted_weights_spec_for_scope(FitScope.NATIONAL) @@ -1542,11 +1543,12 @@ def run_pipeline( regional_fit_parameters = regional_fit_spec.manifest_parameters( gpu=gpu, epochs=epochs, + extra=fit_stage2_identity, ) national_fit_parameters = national_fit_spec.manifest_parameters( gpu=national_gpu, epochs=national_epochs, - extra={"skip_national": skip_national}, + extra={**fit_stage2_identity, "skip_national": skip_national}, ) regional_fit_reuse = _step_reusable( meta, @@ -1587,6 +1589,9 @@ def run_pipeline( step_start = time.time() vol_path = f"{artifacts_dir_for_run(run_id)}/calibration_package.pkl" + vol_contract_path = str( + regional_fit_input.calibration_package_contract_path + ) # Spawn regional fit regional_func = PACKAGE_GPU_FUNCTIONS[gpu] @@ -1595,6 +1600,8 @@ def run_pipeline( branch=branch, epochs=epochs, volume_package_path=vol_path, + volume_package_contract_path=vol_contract_path, + fit_scope=FitScope.REGIONAL.value, **regional_fit_spec.runtime_kwargs(), ) print(f" → regional fit fc: {regional_handle.object_id}") @@ -1623,6 +1630,8 @@ def run_pipeline( branch=branch, epochs=national_epochs, volume_package_path=vol_path, + volume_package_contract_path=vol_contract_path, + fit_scope=FitScope.NATIONAL.value, **national_fit_spec.runtime_kwargs(), ) print(f" → national fit fc: {national_handle.object_id}") diff 
--git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index 969cf405d..068b6e7b2 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -13,12 +13,14 @@ from modal_app.images import gpu_image as image # noqa: E402 from policyengine_us_data.calibration_package.specs import ( # noqa: E402 + CALIBRATION_PACKAGE_CONTRACT_FILENAME, calibration_package_artifact_paths, stage2_build_context_for_run, ) from policyengine_us_data.fit_weights import ( # noqa: E402 FitResultBytes, FitScope, + FittedWeightsInputBundle, NATIONAL_FIT_LAMBDA_L0, fit_artifacts_for_scope, ) @@ -288,6 +290,9 @@ def _fit_from_package_impl( branch: str, epochs: int, volume_package_path: str = None, + volume_package_contract_path: str = None, + allow_legacy_no_contract: bool = False, + fit_scope: str = FitScope.REGIONAL.value, target_config: str = None, beta: float = None, lambda_l0: float = None, @@ -300,6 +305,21 @@ def _fit_from_package_impl( raise ValueError("volume_package_path is required") _setup_repo() + input_bundle = FittedWeightsInputBundle( + scope=fit_scope, + calibration_package_path=Path(volume_package_path), + calibration_package_contract_path=( + Path(volume_package_contract_path) if volume_package_contract_path else None + ), + allow_legacy_no_contract=allow_legacy_no_contract, + ) + stage2_identity = input_bundle.stage2_identity() + if stage2_identity.stage2_contract_mode == "stage2_contract": + print( + "Validated Stage 2 calibration package contract " + f"{stage2_identity.calibration_package_contract_fingerprint}", + flush=True, + ) pkg_path = "/root/calibration_package.pkl" import shutil @@ -816,11 +836,17 @@ def fit_from_package_t4( learning_rate: float = None, log_freq: int = None, volume_package_path: str = None, + volume_package_contract_path: str = None, + allow_legacy_no_contract: bool = False, + fit_scope: str = FitScope.REGIONAL.value, ) -> dict: return _fit_from_package_impl( branch, epochs, 
volume_package_path=volume_package_path, + volume_package_contract_path=volume_package_contract_path, + allow_legacy_no_contract=allow_legacy_no_contract, + fit_scope=fit_scope, target_config=target_config, beta=beta, lambda_l0=lambda_l0, @@ -848,11 +874,17 @@ def fit_from_package_a10( learning_rate: float = None, log_freq: int = None, volume_package_path: str = None, + volume_package_contract_path: str = None, + allow_legacy_no_contract: bool = False, + fit_scope: str = FitScope.REGIONAL.value, ) -> dict: return _fit_from_package_impl( branch, epochs, volume_package_path=volume_package_path, + volume_package_contract_path=volume_package_contract_path, + allow_legacy_no_contract=allow_legacy_no_contract, + fit_scope=fit_scope, target_config=target_config, beta=beta, lambda_l0=lambda_l0, @@ -880,11 +912,17 @@ def fit_from_package_a100_40( learning_rate: float = None, log_freq: int = None, volume_package_path: str = None, + volume_package_contract_path: str = None, + allow_legacy_no_contract: bool = False, + fit_scope: str = FitScope.REGIONAL.value, ) -> dict: return _fit_from_package_impl( branch, epochs, volume_package_path=volume_package_path, + volume_package_contract_path=volume_package_contract_path, + allow_legacy_no_contract=allow_legacy_no_contract, + fit_scope=fit_scope, target_config=target_config, beta=beta, lambda_l0=lambda_l0, @@ -912,11 +950,17 @@ def fit_from_package_a100_80( learning_rate: float = None, log_freq: int = None, volume_package_path: str = None, + volume_package_contract_path: str = None, + allow_legacy_no_contract: bool = False, + fit_scope: str = FitScope.REGIONAL.value, ) -> dict: return _fit_from_package_impl( branch, epochs, volume_package_path=volume_package_path, + volume_package_contract_path=volume_package_contract_path, + allow_legacy_no_contract=allow_legacy_no_contract, + fit_scope=fit_scope, target_config=target_config, beta=beta, lambda_l0=lambda_l0, @@ -944,11 +988,17 @@ def fit_from_package_h100( learning_rate: float = 
None, log_freq: int = None, volume_package_path: str = None, + volume_package_contract_path: str = None, + allow_legacy_no_contract: bool = False, + fit_scope: str = FitScope.REGIONAL.value, ) -> dict: return _fit_from_package_impl( branch, epochs, volume_package_path=volume_package_path, + volume_package_contract_path=volume_package_contract_path, + allow_legacy_no_contract=allow_legacy_no_contract, + fit_scope=fit_scope, target_config=target_config, beta=beta, lambda_l0=lambda_l0, @@ -1008,12 +1058,23 @@ def main( if package_path: vol_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package.pkl" + local_contract_path = Path(package_path).with_name( + CALIBRATION_PACKAGE_CONTRACT_FILENAME + ) + vol_contract_path = ( + f"{PIPELINE_MOUNT}/artifacts/{CALIBRATION_PACKAGE_CONTRACT_FILENAME}" + if local_contract_path.exists() + else None + ) print(f"Reading package from {package_path}...", flush=True) import json as _json import pickle as _pkl with open(package_path, "rb") as f: package_bytes = f.read() + contract_bytes = ( + local_contract_path.read_bytes() if local_contract_path.exists() else None + ) size = len(package_bytes) pkg_meta = _pkl.loads(package_bytes).get("metadata", {}) sidecar_bytes = _json.dumps(pkg_meta, indent=2).encode() @@ -1032,6 +1093,11 @@ def main( BytesIO(sidecar_bytes), "artifacts/calibration_package_meta.json", ) + if contract_bytes is not None: + batch.put_file( + BytesIO(contract_bytes), + f"artifacts/{CALIBRATION_PACKAGE_CONTRACT_FILENAME}", + ) pipeline_vol.commit() del package_bytes print("Upload complete.", flush=True) @@ -1047,6 +1113,9 @@ def main( learning_rate=learning_rate, log_freq=log_freq, volume_package_path=vol_path, + volume_package_contract_path=vol_contract_path, + allow_legacy_no_contract=True, + fit_scope=scope.value, ) elif full_pipeline: print( @@ -1080,6 +1149,9 @@ def main( ) else: vol_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package.pkl" + vol_contract_path = ( + 
f"{PIPELINE_MOUNT}/artifacts/{CALIBRATION_PACKAGE_CONTRACT_FILENAME}" + ) vol_info = check_volume_package.remote() if not vol_info["exists"]: raise SystemExit( @@ -1134,6 +1206,9 @@ def main( learning_rate=learning_rate, log_freq=log_freq, volume_package_path=vol_path, + volume_package_contract_path=vol_contract_path, + allow_legacy_no_contract=True, + fit_scope=scope.value, ) with open(output, "wb") as f: diff --git a/policyengine_us_data/fit_weights/__init__.py b/policyengine_us_data/fit_weights/__init__.py index fb135a374..5524fca27 100644 --- a/policyengine_us_data/fit_weights/__init__.py +++ b/policyengine_us_data/fit_weights/__init__.py @@ -10,7 +10,9 @@ from policyengine_us_data.fit_weights.bundles import ( FitResultBytes, FitWeightsBuildContext, + FittedWeightsInputContractError, FittedWeightsInputBundle, + FittedWeightsInputIdentity, FittedWeightsOutputBundle, MissingFitWeightsOutputError, ) @@ -45,7 +47,9 @@ "FitResultBytes", "FitScope", "FitWeightsBuildContext", + "FittedWeightsInputContractError", "FittedWeightsInputBundle", + "FittedWeightsInputIdentity", "FittedWeightsOutputBundle", "FittedWeightsSpec", "MissingFitWeightsOutputError", diff --git a/policyengine_us_data/fit_weights/bundles.py b/policyengine_us_data/fit_weights/bundles.py index 46a2df7e9..5e11773bb 100644 --- a/policyengine_us_data/fit_weights/bundles.py +++ b/policyengine_us_data/fit_weights/bundles.py @@ -5,8 +5,12 @@ from dataclasses import dataclass from io import BytesIO from pathlib import Path -from typing import Mapping +from typing import Any, Mapping +import warnings +from policyengine_us_data.calibration_package.specs import ( + CALIBRATION_PACKAGE_CONTRACT_FILENAME, +) from policyengine_us_data.fit_weights.artifacts import ( ScopedFitArtifacts, fit_artifacts_for_scope, @@ -14,12 +18,31 @@ from policyengine_us_data.fit_weights.specs import FitScope from policyengine_us_data.pipeline_metadata import pipeline_node from policyengine_us_data.pipeline_schema import PipelineNode 
class FittedWeightsInputContractError(ValueError):
    """Raised when Stage 3 cannot establish Stage 2 package identity.

    Attributes:
        code: Machine-readable failure code supplied by the raiser (for
            example ``"missing_stage2_contract"``), so callers can branch on
            the failure without parsing the message text.
    """

    def __init__(self, message: str, *, code: str) -> None:
        """Store the human-readable ``message`` and the failure ``code``."""
        super().__init__(message)
        self.code = code

    @classmethod
    def _rebuild(cls, message: str, code: str) -> "FittedWeightsInputContractError":
        """Recreate an instance from the values emitted by ``__reduce__``."""
        return cls(message, code=code)

    def __reduce__(self) -> tuple[object, ...]:
        """Support ``pickle`` and ``copy`` despite the keyword-only ``code``.

        The default ``BaseException`` reduction rebuilds the object with
        ``cls(*self.args)``, which raises ``TypeError`` here because ``code``
        is a required keyword-only argument. Routing reconstruction through
        ``_rebuild`` keeps the error transferable across process boundaries.
        """
        message = self.args[0] if self.args else ""
        # The instance ``__dict__`` is passed as state so extra attributes
        # (including ``code``) are restored on the rebuilt exception.
        return (type(self)._rebuild, (message, self.code), dict(vars(self)))
@dataclass(frozen=True)
class FittedWeightsInputBundle:
    """Scoped Stage 3 input paths and Stage 2 package identity.

    Attributes:
        scope: Fit scope; normalized through ``FitScope.parse`` on
            construction.
        calibration_package_path: Location of the calibration package;
            coerced to ``Path`` on construction.
        calibration_package_contract_path: Location of the Stage 2 contract.
            When omitted it resolves to
            ``CALIBRATION_PACKAGE_CONTRACT_FILENAME`` next to the package, so
            after construction this attribute is always a ``Path``.
        allow_legacy_no_contract: When true, a missing contract file is
            tolerated (with a ``RuntimeWarning``) instead of raising.
    """

    scope: FitScope | str
    calibration_package_path: Path
    calibration_package_contract_path: Path | None = None
    allow_legacy_no_contract: bool = False

    def __post_init__(self) -> None:
        """Normalize the scope and resolve both paths to ``Path`` objects."""
        # The dataclass is frozen, so normalized values are written through
        # ``object.__setattr__``.
        object.__setattr__(self, "scope", FitScope.parse(self.scope))
        package_path = Path(self.calibration_package_path)
        contract_path = (
            Path(self.calibration_package_contract_path)
            if self.calibration_package_contract_path is not None
            else package_path.with_name(CALIBRATION_PACKAGE_CONTRACT_FILENAME)
        )
        object.__setattr__(self, "calibration_package_path", package_path)
        object.__setattr__(self, "calibration_package_contract_path", contract_path)

    def artifact_identity_paths(self) -> dict[str, Path]:
        """Return paths used for Stage 3 input identity calculation.

        The contract path is always included in strict mode. In legacy
        no-contract mode it is included only when the file exists.
        """

        paths = {"calibration_package": self.calibration_package_path}
        # ``__post_init__`` always resolves the contract path, so no ``None``
        # handling is needed here.
        contract_path = self.calibration_package_contract_path
        if not self.allow_legacy_no_contract or contract_path.exists():
            paths["calibration_package_contract"] = contract_path
        return paths

    def stage2_identity(self) -> FittedWeightsInputIdentity:
        """Validate and return the Stage 2 package identity for fitting.

        Returns:
            The package checksum and size, plus the contract checksum, size,
            fingerprint and run id when a contract file was validated.

        Raises:
            FittedWeightsInputContractError: If the package is missing or not
                a file, if the contract is missing while the legacy fallback
                is disabled, if the contract path is not a file, or if
                reading or matching the contract fails.
        """

        package_path = self.calibration_package_path
        if not package_path.exists():
            raise FittedWeightsInputContractError(
                f"Missing calibration package artifact: {package_path}",
                code="missing_calibration_package",
            )
        if not package_path.is_file():
            raise FittedWeightsInputContractError(
                f"Calibration package artifact is not a file: {package_path}",
                code="invalid_calibration_package_path",
            )

        package_sha256 = f"sha256:{sha256_file(package_path)}"
        package_size_bytes = package_path.stat().st_size
        contract_path = self.calibration_package_contract_path
        if not contract_path.exists():
            if self.allow_legacy_no_contract:
                # Explicit manual fallback: only the package checksum and
                # size are recorded, and the caller is warned.
                warnings.warn(
                    "Proceeding with Stage 3 fitting without "
                    f"{CALIBRATION_PACKAGE_CONTRACT_FILENAME}; this legacy "
                    "manual fallback records only the package checksum.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                return FittedWeightsInputIdentity(
                    calibration_package_sha256=package_sha256,
                    calibration_package_size_bytes=package_size_bytes,
                    stage2_contract_mode="legacy_no_contract",
                )
            raise FittedWeightsInputContractError(
                f"Missing Stage 2 calibration package contract: {contract_path}",
                code="missing_stage2_contract",
            )
        if not contract_path.is_file():
            raise FittedWeightsInputContractError(
                f"Stage 2 calibration package contract is not a file: {contract_path}",
                code="invalid_stage2_contract_path",
            )

        contract = _read_stage2_contract(contract_path)
        _assert_stage2_contract_matches_package(
            contract=contract,
            package_path=package_path,
            package_sha256=package_sha256,
            package_size_bytes=package_size_bytes,
        )
        return FittedWeightsInputIdentity(
            calibration_package_sha256=package_sha256,
            calibration_package_size_bytes=package_size_bytes,
            stage2_contract_mode="stage2_contract",
            calibration_package_contract_sha256=f"sha256:{sha256_file(contract_path)}",
            calibration_package_contract_size_bytes=contract_path.stat().st_size,
            calibration_package_contract_fingerprint=contract.fingerprint.value,
            calibration_package_contract_run_id=contract.run_id,
        )

    def stage2_identity_parameters(self) -> dict[str, Any]:
        """Return manifest parameters for the validated Stage 2 identity."""

        return self.stage2_identity().to_manifest_parameters()
@pytest.fixture
def stage2_contract_fixture(tmp_path: Path) -> Stage2ContractFixture:
    """Write a package payload and its Stage 2 contract under ``tmp_path``.

    Returns a ``Stage2ContractFixture`` carrying the input paths, the written
    contract path and the contract object returned by the writer.
    """
    dataset_path, db_path, package_path = contract_input_paths(tmp_path)
    payload = write_calibration_package_payload(package_path)
    # These four locations are shared by the contract writer and the
    # returned fixture record.
    locations = {
        "dataset_path": dataset_path,
        "db_path": db_path,
        "package_path": package_path,
        "contract_path": tmp_path / "calibration_package_contract.json",
    }
    written_contract = write_calibration_package_contract(
        package=payload,
        parameters=calibration_package_parameters(),
        run_id=CALIBRATION_RUN_ID,
        started_at=CALIBRATION_STARTED_AT,
        completed_at=CALIBRATION_COMPLETED_AT,
        duration_s=CALIBRATION_DURATION_S,
        **locations,
    )
    return Stage2ContractFixture(contract=written_contract, **locations)
def test_input_bundle_rejects_package_contract_checksum_mismatch(
    stage2_contract_fixture,
) -> None:
    """A package rewritten after the contract was issued must be rejected."""
    fixture = stage2_contract_fixture
    # Overwrite the package with a different payload so it no longer matches
    # the checksum recorded in the contract.
    replacement_payload = calibration_package_payload_with_block_geoids()
    write_calibration_package_payload(fixture.package_path, replacement_payload)

    bundle = FittedWeightsInputBundle(
        scope=FitScope.REGIONAL,
        calibration_package_path=fixture.package_path,
        calibration_package_contract_path=fixture.contract_path,
    )

    with pytest.raises(
        FittedWeightsInputContractError, match="checksum mismatch"
    ) as raised:
        bundle.stage2_identity_parameters()

    assert raised.value.code == "stage2_contract_package_mismatch"
def test_input_bundle_legacy_no_contract_fallback_warns(
    stage2_contract_fixture,
) -> None:
    """The legacy fallback warns and records only the package identity."""
    fixture = stage2_contract_fixture
    fixture.contract_path.unlink()
    legacy_bundle = FittedWeightsInputBundle(
        scope=FitScope.REGIONAL,
        calibration_package_path=fixture.package_path,
        calibration_package_contract_path=fixture.contract_path,
        allow_legacy_no_contract=True,
    )

    with pytest.warns(RuntimeWarning, match="legacy manual fallback"):
        parameters = legacy_bundle.stage2_identity_parameters()

    assert parameters["stage2_contract_mode"] == "legacy_no_contract"
    assert parameters["calibration_package_sha256"].startswith("sha256:")
    assert "calibration_package_contract_sha256" not in parameters
    # With the contract file removed, only the package path contributes to
    # the input identity in legacy mode.
    expected_paths = {"calibration_package": fixture.package_path}
    assert legacy_bundle.artifact_identity_paths() == expected_paths
"fit_scope=FitScope.NATIONAL.value" in source assert "regional_output.artifact_paths(_artifacts_dir(run_id))" in source assert "national_output.artifact_paths(_artifacts_dir(run_id))" in source assert "diagnostic_result_filenames()" in archive_source diff --git a/tests/unit/test_remote_calibration_runner.py b/tests/unit/test_remote_calibration_runner.py index 9b280cad9..8a6909156 100644 --- a/tests/unit/test_remote_calibration_runner.py +++ b/tests/unit/test_remote_calibration_runner.py @@ -4,6 +4,10 @@ from types import ModuleType, SimpleNamespace from unittest.mock import Mock +import pytest + +from policyengine_us_data.fit_weights import FittedWeightsInputContractError + def _load_remote_calibration_runner_module(): fake_modal = ModuleType("modal") @@ -51,6 +55,29 @@ def test_remote_runner_does_not_expose_optimizer_checkpoint_contract(): assert "checkpoint_name" not in inspect.signature(func).parameters +def test_fit_from_package_impl_requires_stage_2_contract_before_calibration( + monkeypatch, + tmp_path, +): + remote_runner = _load_remote_calibration_runner_module() + package_path = tmp_path / "calibration_package.pkl" + package_path.write_bytes(b"package") + run_streaming = Mock() + + monkeypatch.setattr(remote_runner, "_setup_repo", lambda: None) + monkeypatch.setattr(remote_runner, "_run_streaming", run_streaming) + + with pytest.raises(FittedWeightsInputContractError) as exc_info: + remote_runner._fit_from_package_impl( + branch="main", + epochs=1, + volume_package_path=str(package_path), + ) + + assert exc_info.value.code == "missing_stage2_contract" + run_streaming.assert_not_called() + + def test_collect_outputs_returns_pipeline_artifact_bytes(tmp_path): remote_runner = _load_remote_calibration_runner_module() weights = tmp_path / "weights.npy"