Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
**/*.h5
**/*.npy
**/*.csv
**/*.csv.gz
**/_build
**/*.pkl
**/*.db
Expand Down
1 change: 1 addition & 0 deletions changelog.d/708.added
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Save calibration geography as a pipeline artifact, add ``--resume-from`` and checkpoint support for long-running calibration fits, and fix resume/artifact handling in the remote calibration pipeline. This also adds conservative CPS taxpayer-ID outputs (``has_tin``, ``has_valid_ssn``, and a temporary ``has_itin`` compatibility alias), plus string-valued constraint handling needed for ID-target calibration.
30 changes: 30 additions & 0 deletions docs/calibration.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ python -m policyengine_us_data.calibration.unified_calibration \
--package-path storage/calibration/calibration_package.pkl \
--epochs 500 --device cuda

# Resume a previous fit for 500 more epochs:
python -m policyengine_us_data.calibration.unified_calibration \
--package-path storage/calibration/calibration_package.pkl \
--resume-from storage/calibration/calibration_weights.npy \
--epochs 500 --device cuda

# Full pipeline with PUF (build + fit in one shot):
make calibrate
```
Expand Down Expand Up @@ -88,6 +94,30 @@ python -m policyengine_us_data.calibration.unified_calibration \
You can re-run Step 2 as many times as you want with different hyperparameters. The expensive matrix
build only happens once.

Every fit now also writes a checkpoint next to the weights output
(`calibration_weights.checkpoint.pt` by default). To continue the same fit,
pass `--resume-from` with the weights file or checkpoint path. If a sibling
checkpoint exists next to the weights file, it is used automatically so the
L0 gate state is restored as well.

```bash
python -m policyengine_us_data.calibration.unified_calibration \
--package-path storage/calibration/calibration_package.pkl \
--epochs 2000 \
--beta 0.65 \
--lambda-l0 1e-4 \
--lambda-l2 1e-12 \
--log-freq 500 \
--target-config policyengine_us_data/calibration/target_config.yaml \
--device cpu \
--output policyengine_us_data/storage/calibration/national/weights.npy \
--resume-from policyengine_us_data/storage/calibration/national/weights.npy
```

When `--resume-from` points to a checkpoint, `--epochs` means additional epochs
to run beyond the saved checkpoint epoch count. If only a `.npy` weights file
exists, the run warm-starts from those weights.

### 2. Full pipeline with PUF

Adding `--puf-dataset` doubles the record count (~24K base records x 430 clones = ~10.3M columns) by
Expand Down
7 changes: 7 additions & 0 deletions modal_app/local_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ def build_areas_worker(
"--output-dir",
str(output_dir),
]
if "geography" in calibration_inputs:
worker_cmd.extend(["--geography-path", calibration_inputs["geography"]])
if "n_clones" in calibration_inputs:
worker_cmd.extend(["--n-clones", str(calibration_inputs["n_clones"])])
if "seed" in calibration_inputs:
Expand Down Expand Up @@ -659,6 +661,7 @@ def coordinate_publish(
Path(f"/pipeline/artifacts/{run_id}") if run_id else Path("/pipeline/artifacts")
)
weights_path = artifacts / "calibration_weights.npy"
geography_path = artifacts / "geography_assignment.npz"
db_path = artifacts / "policy_data.db"
dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5"
config_json_path = artifacts / "unified_run_config.json"
Expand All @@ -678,6 +681,7 @@ def coordinate_publish(

calibration_inputs = {
"weights": str(weights_path),
"geography": str(geography_path),
"dataset": str(dataset_path),
"database": str(db_path),
"n_clones": n_clones,
Expand Down Expand Up @@ -943,6 +947,7 @@ def coordinate_national_publish(
Path(f"/pipeline/artifacts/{run_id}") if run_id else Path("/pipeline/artifacts")
)
weights_path = artifacts / "national_calibration_weights.npy"
geography_path = artifacts / "national_geography_assignment.npz"
db_path = artifacts / "policy_data.db"
dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5"
config_json_path = artifacts / "national_unified_run_config.json"
Expand All @@ -962,6 +967,7 @@ def coordinate_national_publish(

calibration_inputs = {
"weights": str(weights_path),
"geography": str(geography_path),
"dataset": str(dataset_path),
"database": str(db_path),
"n_clones": n_clones,
Expand All @@ -972,6 +978,7 @@ def coordinate_national_publish(
artifacts,
filename_remap={
"calibration_weights.npy": "national_calibration_weights.npy",
"geography_assignment.npz": "national_geography_assignment.npz",
},
)
run_dir = staging_dir / run_id
Expand Down
10 changes: 10 additions & 0 deletions modal_app/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,11 @@ def run_pipeline(
BytesIO(regional_result["weights"]),
f"{artifacts_rel}/calibration_weights.npy",
)
if regional_result.get("geography"):
batch.put_file(
BytesIO(regional_result["geography"]),
f"{artifacts_rel}/geography_assignment.npz",
)
if regional_result.get("config"):
batch.put_file(
BytesIO(regional_result["config"]),
Expand All @@ -856,6 +861,11 @@ def run_pipeline(
BytesIO(national_result["weights"]),
f"{artifacts_rel}/national_calibration_weights.npy",
)
if national_result.get("geography"):
batch.put_file(
BytesIO(national_result["geography"]),
f"{artifacts_rel}/national_geography_assignment.npz",
)
if national_result.get("config"):
batch.put_file(
BytesIO(national_result["config"]),
Expand Down
Loading
Loading