Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/1116.changed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Require Stage 3 fitted-weight runs to verify the Stage 2 calibration package contract before fitting.
9 changes: 9 additions & 0 deletions docs/engineering/stages/fit_weights.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ builds. The public identity boundary lives in `policyengine_us_data.fit_weights`
`FittedWeightsOutputBundle` keep Stage 3 package inputs and remote result
bytes typed before they become files.

Normal pipeline runs must fit from a Stage 2 package that has a matching
`calibration_package_contract.json` sidecar. `FittedWeightsInputBundle` reads
that contract before GPU fitting starts, checks the contract-declared
`calibration_package.pkl` checksum and size against the package on the pipeline
volume, and records both the package checksum and contract checksum in the fit
step parameters. Manual legacy package runs may proceed without the contract
only through the explicit no-contract fallback, which emits a warning and
records that only the package checksum was available.

The current artifact names remain behavior-compatible:

- regional: `calibration_weights.npy`, `geography_assignment.npz`,
Expand Down
11 changes: 10 additions & 1 deletion modal_app/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,7 @@ def run_pipeline(
scope=FitScope.REGIONAL,
calibration_package_path=_artifacts_dir(run_id) / "calibration_package.pkl",
)
fit_stage2_identity = regional_fit_input.stage2_identity_parameters()
fit_inputs = _artifact_identities(regional_fit_input.artifact_identity_paths())
regional_fit_spec = fitted_weights_spec_for_scope(FitScope.REGIONAL)
national_fit_spec = fitted_weights_spec_for_scope(FitScope.NATIONAL)
Expand All @@ -1542,11 +1543,12 @@ def run_pipeline(
regional_fit_parameters = regional_fit_spec.manifest_parameters(
gpu=gpu,
epochs=epochs,
extra=fit_stage2_identity,
)
national_fit_parameters = national_fit_spec.manifest_parameters(
gpu=national_gpu,
epochs=national_epochs,
extra={"skip_national": skip_national},
extra={**fit_stage2_identity, "skip_national": skip_national},
)
regional_fit_reuse = _step_reusable(
meta,
Expand Down Expand Up @@ -1587,6 +1589,9 @@ def run_pipeline(
step_start = time.time()

vol_path = f"{artifacts_dir_for_run(run_id)}/calibration_package.pkl"
vol_contract_path = str(
regional_fit_input.calibration_package_contract_path
)

# Spawn regional fit
regional_func = PACKAGE_GPU_FUNCTIONS[gpu]
Expand All @@ -1595,6 +1600,8 @@ def run_pipeline(
branch=branch,
epochs=epochs,
volume_package_path=vol_path,
volume_package_contract_path=vol_contract_path,
fit_scope=FitScope.REGIONAL.value,
**regional_fit_spec.runtime_kwargs(),
)
print(f" → regional fit fc: {regional_handle.object_id}")
Expand Down Expand Up @@ -1623,6 +1630,8 @@ def run_pipeline(
branch=branch,
epochs=national_epochs,
volume_package_path=vol_path,
volume_package_contract_path=vol_contract_path,
fit_scope=FitScope.NATIONAL.value,
**national_fit_spec.runtime_kwargs(),
)
print(f" → national fit fc: {national_handle.object_id}")
Expand Down
75 changes: 75 additions & 0 deletions modal_app/remote_calibration_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@

from modal_app.images import gpu_image as image # noqa: E402
from policyengine_us_data.calibration_package.specs import ( # noqa: E402
CALIBRATION_PACKAGE_CONTRACT_FILENAME,
calibration_package_artifact_paths,
stage2_build_context_for_run,
)
from policyengine_us_data.fit_weights import ( # noqa: E402
FitResultBytes,
FitScope,
FittedWeightsInputBundle,
NATIONAL_FIT_LAMBDA_L0,
fit_artifacts_for_scope,
)
Expand Down Expand Up @@ -288,6 +290,9 @@ def _fit_from_package_impl(
branch: str,
epochs: int,
volume_package_path: str = None,
volume_package_contract_path: str = None,
allow_legacy_no_contract: bool = False,
fit_scope: str = FitScope.REGIONAL.value,
target_config: str = None,
beta: float = None,
lambda_l0: float = None,
Expand All @@ -300,6 +305,21 @@ def _fit_from_package_impl(
raise ValueError("volume_package_path is required")

_setup_repo()
input_bundle = FittedWeightsInputBundle(
scope=fit_scope,
calibration_package_path=Path(volume_package_path),
calibration_package_contract_path=(
Path(volume_package_contract_path) if volume_package_contract_path else None
),
allow_legacy_no_contract=allow_legacy_no_contract,
)
stage2_identity = input_bundle.stage2_identity()
if stage2_identity.stage2_contract_mode == "stage2_contract":
print(
"Validated Stage 2 calibration package contract "
f"{stage2_identity.calibration_package_contract_fingerprint}",
flush=True,
)

pkg_path = "/root/calibration_package.pkl"
import shutil
Expand Down Expand Up @@ -816,11 +836,17 @@ def fit_from_package_t4(
learning_rate: float = None,
log_freq: int = None,
volume_package_path: str = None,
volume_package_contract_path: str = None,
allow_legacy_no_contract: bool = False,
fit_scope: str = FitScope.REGIONAL.value,
) -> dict:
return _fit_from_package_impl(
branch,
epochs,
volume_package_path=volume_package_path,
volume_package_contract_path=volume_package_contract_path,
allow_legacy_no_contract=allow_legacy_no_contract,
fit_scope=fit_scope,
target_config=target_config,
beta=beta,
lambda_l0=lambda_l0,
Expand Down Expand Up @@ -848,11 +874,17 @@ def fit_from_package_a10(
learning_rate: float = None,
log_freq: int = None,
volume_package_path: str = None,
volume_package_contract_path: str = None,
allow_legacy_no_contract: bool = False,
fit_scope: str = FitScope.REGIONAL.value,
) -> dict:
return _fit_from_package_impl(
branch,
epochs,
volume_package_path=volume_package_path,
volume_package_contract_path=volume_package_contract_path,
allow_legacy_no_contract=allow_legacy_no_contract,
fit_scope=fit_scope,
target_config=target_config,
beta=beta,
lambda_l0=lambda_l0,
Expand Down Expand Up @@ -880,11 +912,17 @@ def fit_from_package_a100_40(
learning_rate: float = None,
log_freq: int = None,
volume_package_path: str = None,
volume_package_contract_path: str = None,
allow_legacy_no_contract: bool = False,
fit_scope: str = FitScope.REGIONAL.value,
) -> dict:
return _fit_from_package_impl(
branch,
epochs,
volume_package_path=volume_package_path,
volume_package_contract_path=volume_package_contract_path,
allow_legacy_no_contract=allow_legacy_no_contract,
fit_scope=fit_scope,
target_config=target_config,
beta=beta,
lambda_l0=lambda_l0,
Expand Down Expand Up @@ -912,11 +950,17 @@ def fit_from_package_a100_80(
learning_rate: float = None,
log_freq: int = None,
volume_package_path: str = None,
volume_package_contract_path: str = None,
allow_legacy_no_contract: bool = False,
fit_scope: str = FitScope.REGIONAL.value,
) -> dict:
return _fit_from_package_impl(
branch,
epochs,
volume_package_path=volume_package_path,
volume_package_contract_path=volume_package_contract_path,
allow_legacy_no_contract=allow_legacy_no_contract,
fit_scope=fit_scope,
target_config=target_config,
beta=beta,
lambda_l0=lambda_l0,
Expand Down Expand Up @@ -944,11 +988,17 @@ def fit_from_package_h100(
learning_rate: float = None,
log_freq: int = None,
volume_package_path: str = None,
volume_package_contract_path: str = None,
allow_legacy_no_contract: bool = False,
fit_scope: str = FitScope.REGIONAL.value,
) -> dict:
return _fit_from_package_impl(
branch,
epochs,
volume_package_path=volume_package_path,
volume_package_contract_path=volume_package_contract_path,
allow_legacy_no_contract=allow_legacy_no_contract,
fit_scope=fit_scope,
target_config=target_config,
beta=beta,
lambda_l0=lambda_l0,
Expand Down Expand Up @@ -1008,12 +1058,23 @@ def main(

if package_path:
vol_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package.pkl"
local_contract_path = Path(package_path).with_name(
CALIBRATION_PACKAGE_CONTRACT_FILENAME
)
vol_contract_path = (
f"{PIPELINE_MOUNT}/artifacts/{CALIBRATION_PACKAGE_CONTRACT_FILENAME}"
if local_contract_path.exists()
else None
)
print(f"Reading package from {package_path}...", flush=True)
import json as _json
import pickle as _pkl

with open(package_path, "rb") as f:
package_bytes = f.read()
contract_bytes = (
local_contract_path.read_bytes() if local_contract_path.exists() else None
)
size = len(package_bytes)
pkg_meta = _pkl.loads(package_bytes).get("metadata", {})
sidecar_bytes = _json.dumps(pkg_meta, indent=2).encode()
Expand All @@ -1032,6 +1093,11 @@ def main(
BytesIO(sidecar_bytes),
"artifacts/calibration_package_meta.json",
)
if contract_bytes is not None:
batch.put_file(
BytesIO(contract_bytes),
f"artifacts/{CALIBRATION_PACKAGE_CONTRACT_FILENAME}",
)
pipeline_vol.commit()
del package_bytes
print("Upload complete.", flush=True)
Expand All @@ -1047,6 +1113,9 @@ def main(
learning_rate=learning_rate,
log_freq=log_freq,
volume_package_path=vol_path,
volume_package_contract_path=vol_contract_path,
allow_legacy_no_contract=True,
fit_scope=scope.value,
)
elif full_pipeline:
print(
Expand Down Expand Up @@ -1080,6 +1149,9 @@ def main(
)
else:
vol_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package.pkl"
vol_contract_path = (
f"{PIPELINE_MOUNT}/artifacts/{CALIBRATION_PACKAGE_CONTRACT_FILENAME}"
)
vol_info = check_volume_package.remote()
if not vol_info["exists"]:
raise SystemExit(
Expand Down Expand Up @@ -1134,6 +1206,9 @@ def main(
learning_rate=learning_rate,
log_freq=log_freq,
volume_package_path=vol_path,
volume_package_contract_path=vol_contract_path,
allow_legacy_no_contract=True,
fit_scope=scope.value,
)

with open(output, "wb") as f:
Expand Down
4 changes: 4 additions & 0 deletions policyengine_us_data/fit_weights/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from policyengine_us_data.fit_weights.bundles import (
FitResultBytes,
FitWeightsBuildContext,
FittedWeightsInputContractError,
FittedWeightsInputBundle,
FittedWeightsInputIdentity,
FittedWeightsOutputBundle,
MissingFitWeightsOutputError,
)
Expand Down Expand Up @@ -45,7 +47,9 @@
"FitResultBytes",
"FitScope",
"FitWeightsBuildContext",
"FittedWeightsInputContractError",
"FittedWeightsInputBundle",
"FittedWeightsInputIdentity",
"FittedWeightsOutputBundle",
"FittedWeightsSpec",
"MissingFitWeightsOutputError",
Expand Down
Loading