From c1980d6425184d9b1639fac0f23dcb47a991cae1 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Wed, 8 Apr 2026 23:05:27 -0400 Subject: [PATCH 01/12] Save geography artifacts and add calibration resume/checkpoint support Calibration now persists geography_assignment.npz alongside weights so that downstream publish and worker steps use the exact same geography instead of regenerating it randomly. Adds --resume-from and --checkpoint-output flags to unified_calibration for continuing fits from a saved checkpoint or warm-starting from weights. Also gitignores *.csv.gz to prevent accidental commits of cached ORG data. Co-Authored-By: Claude Opus 4.6 (1M context) --- .gitignore | 1 + docs/calibration.md | 30 ++ modal_app/local_area.py | 7 + modal_app/pipeline.py | 10 + modal_app/remote_calibration_runner.py | 27 + modal_app/worker_script.py | 33 +- .../calibration/clone_and_assign.py | 99 ++++ .../calibration/publish_local_area.py | 223 ++++++-- .../calibration/unified_calibration.py | 496 ++++++++++++++---- policyengine_us_data/utils/huggingface.py | 13 +- .../unit/calibration/test_clone_and_assign.py | 77 +++ .../calibration/test_unified_calibration.py | 181 +++++++ 12 files changed, 1037 insertions(+), 160 deletions(-) diff --git a/.gitignore b/.gitignore index b53ecb473..9e85d7069 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ **/*.h5 **/*.npy **/*.csv +**/*.csv.gz **/_build **/*.pkl **/*.db diff --git a/docs/calibration.md b/docs/calibration.md index 3ad2fff70..4b67e0ecc 100644 --- a/docs/calibration.md +++ b/docs/calibration.md @@ -23,6 +23,12 @@ python -m policyengine_us_data.calibration.unified_calibration \ --package-path storage/calibration/calibration_package.pkl \ --epochs 500 --device cuda +# Resume a previous fit for 500 more epochs: +python -m policyengine_us_data.calibration.unified_calibration \ + --package-path storage/calibration/calibration_package.pkl \ + --resume-from storage/calibration/calibration_weights.npy \ + --epochs 500 --device cuda + # Full pipeline with PUF (build + fit in one shot): make calibrate ``` @@ -88,6 +94,30 @@ python -m policyengine_us_data.calibration.unified_calibration \ You can re-run Step 2 as many times as you want with different hyperparameters. The expensive matrix build only happens once. +Every fit now also writes a checkpoint next to the weights output +(`calibration_weights.checkpoint.pt` by default). To continue the same fit, +pass `--resume-from` with the weights file or checkpoint path. If a sibling +checkpoint exists next to the weights file, it is used automatically so the +L0 gate state is restored as well. + +```bash +python -m policyengine_us_data.calibration.unified_calibration \ + --package-path storage/calibration/calibration_package.pkl \ + --epochs 2000 \ + --beta 0.65 \ + --lambda-l0 1e-4 \ + --lambda-l2 1e-12 \ + --log-freq 500 \ + --target-config policyengine_us_data/calibration/target_config.yaml \ + --device cpu \ + --output policyengine_us_data/storage/calibration/national/weights.npy \ + --resume-from policyengine_us_data/storage/calibration/national/weights.npy +``` + +When `--resume-from` points to a checkpoint, `--epochs` means additional epochs +to run beyond the saved checkpoint epoch count. If only a `.npy` weights file +exists, the run warm-starts from those weights. + ### 2. Full pipeline with PUF Adding `--puf-dataset` doubles the record count (~24K base records x 430 clones = ~10.3M columns) by diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 036f069a8..9965393d6 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -334,6 +334,8 @@ def build_areas_worker( "--output-dir", str(output_dir), ] + if "geography" in calibration_inputs: + worker_cmd.extend(["--geography-path", calibration_inputs["geography"]]) if "n_clones" in calibration_inputs: worker_cmd.extend(["--n-clones", str(calibration_inputs["n_clones"])]) if "seed" in calibration_inputs: @@ -659,6 +661,7 @@ def coordinate_publish( Path(f"/pipeline/artifacts/{run_id}") if run_id else Path("/pipeline/artifacts") ) weights_path = artifacts / "calibration_weights.npy" + geography_path = artifacts / "geography_assignment.npz" db_path = artifacts / "policy_data.db" dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5" config_json_path = artifacts / "unified_run_config.json" @@ -678,6 +681,7 @@ def coordinate_publish( calibration_inputs = { "weights": str(weights_path), + "geography": str(geography_path), "dataset": str(dataset_path), "database": str(db_path), "n_clones": n_clones, @@ -943,6 +947,7 @@ def coordinate_national_publish( Path(f"/pipeline/artifacts/{run_id}") if run_id else Path("/pipeline/artifacts") ) weights_path = artifacts / "national_calibration_weights.npy" + geography_path = artifacts / "national_geography_assignment.npz" db_path = artifacts / "policy_data.db" dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5" config_json_path = artifacts / "national_unified_run_config.json" @@ -962,6 +967,7 @@ def coordinate_national_publish( calibration_inputs = { "weights": str(weights_path), + "geography": str(geography_path), "dataset": str(dataset_path), "database": str(db_path), "n_clones": n_clones, @@ -972,6 +978,7 @@ def coordinate_national_publish( artifacts, filename_remap={ "calibration_weights.npy": "national_calibration_weights.npy", + "geography_assignment.npz": "national_geography_assignment.npz", }, ) run_dir = staging_dir / run_id diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py index 413a12d18..446a2f669 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -832,6 +832,11 @@ def run_pipeline( BytesIO(regional_result["weights"]), f"{artifacts_rel}/calibration_weights.npy", ) + if regional_result.get("geography"): + batch.put_file( + BytesIO(regional_result["geography"]), + f"{artifacts_rel}/geography_assignment.npz", + ) if regional_result.get("config"): batch.put_file( BytesIO(regional_result["config"]), @@ -856,6 +861,11 @@ def run_pipeline( BytesIO(national_result["weights"]), f"{artifacts_rel}/national_calibration_weights.npy", ) + if national_result.get("geography"): + batch.put_file( + BytesIO(national_result["geography"]), + f"{artifacts_rel}/national_geography_assignment.npz", + ) if national_result.get("config"): batch.put_file( BytesIO(national_result["config"]), diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index 30126e24e..5fac45209 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -68,12 +68,15 @@ def _append_hyperparams(cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq def _collect_outputs(cal_lines): """Extract weights and log bytes from calibration output lines.""" output_path = None + geography_path = None log_path = None cal_log_path = None config_path = None for line in cal_lines: if "OUTPUT_PATH:" in line: output_path = line.split("OUTPUT_PATH:")[1].strip() + elif "GEOGRAPHY_PATH:" in line: + geography_path = line.split("GEOGRAPHY_PATH:")[1].strip() elif "CONFIG_PATH:" in line: config_path = line.split("CONFIG_PATH:")[1].strip() elif "CAL_LOG_PATH:" in line: @@ -84,6 +87,11 @@ def _collect_outputs(cal_lines): with open(output_path, "rb") as f: weights_bytes = f.read() + geography_bytes = None + if geography_path: + with open(geography_path, "rb") as f: + geography_bytes = f.read() + log_bytes = None if log_path: with open(log_path, "rb") as f: @@ -101,6 +109,7 @@ def _collect_outputs(cal_lines): return { "weights": weights_bytes, + "geography": geography_bytes, "log": log_bytes, "cal_log": cal_log_bytes, "config": config_bytes, @@ -975,6 +984,10 @@ def main( f" - calibration/{prefix}calibration_weights.npy", flush=True, ) + print( + f" - calibration/{prefix}geography_assignment.npz", + flush=True, + ) print( f" - calibration/logs/{prefix}* (diagnostics, " "config, calibration log)", @@ -1006,6 +1019,12 @@ def main( f.write(result["log"]) print(f"Diagnostics log saved to: {log_output}") + geography_output = f"{prefix}geography_assignment.npz" + if result.get("geography"): + with open(geography_output, "wb") as f: + f.write(result["geography"]) + print(f"Geography saved to: {geography_output}") + cal_log_output = f"{prefix}calibration_log.csv" if result.get("cal_log"): with open(cal_log_output, "wb") as f: @@ -1027,6 +1046,11 @@ def main( BytesIO(result["weights"]), f"artifacts/{prefix}calibration_weights.npy", ) + if result.get("geography"): + batch.put_file( + BytesIO(result["geography"]), + f"artifacts/{prefix}geography_assignment.npz", + ) if result.get("config"): batch.put_file( BytesIO(result["config"]), @@ -1042,6 +1066,9 @@ def main( upload_calibration_artifacts( weights_path=output, + geography_path=( + geography_output if result.get("geography") else None + ), log_dir=".", prefix=prefix, ) diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index 27dbb8c2a..41b02b651 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -155,6 +155,11 @@ def main(): parser.add_argument("--dataset-path", required=True) parser.add_argument("--db-path", required=True) parser.add_argument("--output-dir", required=True) + parser.add_argument( + "--geography-path", + default=None, + help="Optional explicit path to geography_assignment.npz", + ) parser.add_argument( "--n-clones", type=int, @@ -210,13 +215,11 @@ def main(): build_h5, NYC_COUNTY_FIPS, AT_LARGE_DISTRICTS, + load_calibration_geography, ) from policyengine_us_data.calibration.calibration_utils import ( STATE_CODES, ) - from policyengine_us_data.calibration.clone_and_assign import ( - assign_random_geography, - ) weights = np.load(weights_path) @@ -226,15 +229,18 @@ def main(): n_records = len(_sim.calculate("household_id", map_to="household").values) del _sim - geography = assign_random_geography( + geography = load_calibration_geography( + weights_path=weights_path, n_records=n_records, n_clones=args.n_clones, - seed=args.seed, + geography_path=( + Path(args.geography_path) if args.geography_path is not None else None + ), ) cds_to_calibrate = sorted(set(geography.cd_geoid.astype(str))) geo_labels = cds_to_calibrate print( - f"Generated geography: " + f"Loaded geography: " f"{geography.n_clones} clones x " f"{geography.n_records} records", file=sys.stderr, @@ -403,19 +409,12 @@ def main(): national_dir.mkdir(parents=True, exist_ok=True) n_clones_from_weights = weights.shape[0] // n_records if n_clones_from_weights != geography.n_clones: - print( + raise ValueError( f"National weights have {n_clones_from_weights} clones " - f"but geography has {geography.n_clones}; " - f"regenerating geography", - file=sys.stderr, + f"but geography has {geography.n_clones}. " + "Use the matching saved geography artifact." ) - national_geo = assign_random_geography( - n_records=n_records, - n_clones=n_clones_from_weights, - seed=args.seed, - ) - else: - national_geo = geography + national_geo = geography path = build_h5( weights=weights, geography=national_geo, diff --git a/policyengine_us_data/calibration/clone_and_assign.py b/policyengine_us_data/calibration/clone_and_assign.py index defcea17e..cee43704a 100644 --- a/policyengine_us_data/calibration/clone_and_assign.py +++ b/policyengine_us_data/calibration/clone_and_assign.py @@ -209,6 +209,105 @@ def _sample(size, mask_slice=None): ) +def save_geography(geography: GeographyAssignment, path) -> None: + """Save a GeographyAssignment to a compressed .npz file. + + Args: + geography: The geography assignment to save. + path: Output file path (should end in .npz). + """ + from pathlib import Path + + path = Path(path) + np.savez_compressed( + path, + block_geoid=geography.block_geoid, + cd_geoid=geography.cd_geoid, + county_fips=geography.county_fips, + state_fips=geography.state_fips, + n_records=np.array([geography.n_records]), + n_clones=np.array([geography.n_clones]), + ) + + +def load_geography(path) -> GeographyAssignment: + """Load a GeographyAssignment from a .npz file. + + Args: + path: Path to the .npz file saved by save_geography. + + Returns: + GeographyAssignment with all fields restored. + """ + from pathlib import Path + + path = Path(path) + data = np.load(path, allow_pickle=True) + return GeographyAssignment( + block_geoid=data["block_geoid"], + cd_geoid=data["cd_geoid"], + county_fips=data["county_fips"], + state_fips=data["state_fips"], + n_records=int(data["n_records"][0]), + n_clones=int(data["n_clones"][0]), + ) + + +@lru_cache(maxsize=1) +def load_sorted_block_cd_lookup(): + """Load a sorted block -> CD lookup for legacy block artifacts.""" + blocks, cds, _, _ = load_global_block_distribution() + order = np.argsort(blocks) + return blocks[order], cds[order] + + +def reconstruct_geography_from_blocks( + block_geoids: np.ndarray, + n_records: int, + n_clones: int, +) -> GeographyAssignment: + """Reconstruct a GeographyAssignment from saved block GEOIDs.""" + block_geoids = np.asarray(block_geoids, dtype=str) + expected_len = n_records * n_clones + if len(block_geoids) != expected_len: + raise ValueError( + f"Expected {expected_len} block GEOIDs for " + f"{n_records} records x {n_clones} clones, got {len(block_geoids)}" + ) + + sorted_blocks, sorted_cds = load_sorted_block_cd_lookup() + indices = np.searchsorted(sorted_blocks, block_geoids) + valid = indices < len(sorted_blocks) + matched = np.zeros(len(block_geoids), dtype=bool) + matched[valid] = sorted_blocks[indices[valid]] == block_geoids[valid] + + if not np.all(matched): + missing = np.unique(block_geoids[~matched])[:5] + raise KeyError( + "Could not recover congressional districts for some blocks. " + f"Examples: {missing.tolist()}" + ) + + county_fips = np.fromiter( + (block[:5] for block in block_geoids), + dtype="U5", + count=len(block_geoids), + ) + state_fips = np.fromiter( + (int(block[:2]) for block in block_geoids), + dtype=np.int32, + count=len(block_geoids), + ) + return GeographyAssignment( + block_geoid=block_geoids, + cd_geoid=sorted_cds[indices], + county_fips=county_fips, + state_fips=state_fips, + n_records=n_records, + n_clones=n_clones, + ) + + def double_geography_for_puf( geography: GeographyAssignment, ) -> GeographyAssignment: diff --git a/policyengine_us_data/calibration/publish_local_area.py b/policyengine_us_data/calibration/publish_local_area.py index 2a017668c..f90148e91 100644 --- a/policyengine_us_data/calibration/publish_local_area.py +++ b/policyengine_us_data/calibration/publish_local_area.py @@ -15,7 +15,7 @@ import numpy as np from pathlib import Path -from typing import List +from typing import List, Optional, Tuple from policyengine_us import Microsimulation from policyengine_us_data.utils.huggingface import download_calibration_inputs @@ -32,8 +32,8 @@ derive_geography_from_blocks, ) from policyengine_us_data.calibration.clone_and_assign import ( - GeographyAssignment, - assign_random_geography, + load_geography, + reconstruct_geography_from_blocks, ) from policyengine_us_data.utils.takeup import ( SIMPLE_TAKEUP_VARS, @@ -52,18 +52,140 @@ META_FILE = WORK_DIR / "checkpoint_meta.json" +CALIBRATION_WEIGHTS_SUFFIX = "calibration_weights.npy" +GEOGRAPHY_FILENAME = "geography_assignment.npz" +LEGACY_BLOCKS_FILENAME = "stacked_blocks.npy" + + +def _calibration_artifact_prefix(weights_path: Path) -> str: + if weights_path.name.endswith(CALIBRATION_WEIGHTS_SUFFIX): + return weights_path.name[: -len(CALIBRATION_WEIGHTS_SUFFIX)] + return "" + + +def _sibling_artifact_path(weights_path: Path, artifact_name: str) -> Path: + prefix = _calibration_artifact_prefix(weights_path) + return weights_path.with_name(f"{prefix}{artifact_name}") + + +def resolve_calibration_geography_paths( + weights_path: Path, + geography_path: Optional[Path] = None, + blocks_path: Optional[Path] = None, +) -> Tuple[Optional[Path], Optional[Path]]: + geo_candidates = [] + block_candidates = [] + if geography_path is not None: + geo_candidates.append(Path(geography_path)) + geo_candidates.append(_sibling_artifact_path(weights_path, GEOGRAPHY_FILENAME)) + + if blocks_path is not None: + block_candidates.append(Path(blocks_path)) + block_candidates.append(_sibling_artifact_path(weights_path, LEGACY_BLOCKS_FILENAME)) + block_candidates.append(weights_path.with_name(LEGACY_BLOCKS_FILENAME)) + + resolved_geo = next((path for path in geo_candidates if path.exists()), None) + resolved_blocks = next( + (path for path in block_candidates if path.exists()), + None, + ) + return resolved_geo, resolved_blocks + + +def _update_hash_from_file(h: "hashlib._Hash", path: Path) -> None: + with open(path, "rb") as f: + while chunk := f.read(8192): + h.update(chunk) + + def compute_input_fingerprint( - weights_path: Path, dataset_path: Path, n_clones: int, seed: int + weights_path: Path, + dataset_path: Path, + n_clones: Optional[int] = None, + seed: int = 42, + geography_path: Optional[Path] = None, + blocks_path: Optional[Path] = None, ) -> str: h = hashlib.sha256() for p in [weights_path, dataset_path]: - with open(p, "rb") as f: - while chunk := f.read(8192): - h.update(chunk) - h.update(f"{n_clones}:{seed}".encode()) + _update_hash_from_file(h, p) + + resolved_geo, resolved_blocks = resolve_calibration_geography_paths( + weights_path=weights_path, + geography_path=geography_path, + blocks_path=blocks_path, + ) + if resolved_geo is not None: + h.update(b"geography_assignment") + _update_hash_from_file(h, resolved_geo) + elif resolved_blocks is not None: + h.update(b"legacy_stacked_blocks") + _update_hash_from_file(h, resolved_blocks) + else: + h.update(f"legacy_regeneration:{n_clones}:{seed}".encode()) return h.hexdigest()[:16] +def load_calibration_geography( + weights_path: Path, + n_records: int, + n_clones: Optional[int] = None, + geography_path: Optional[Path] = None, + blocks_path: Optional[Path] = None, +): + resolved_geo, resolved_blocks = resolve_calibration_geography_paths( + weights_path=weights_path, + geography_path=geography_path, + blocks_path=blocks_path, + ) + + if resolved_geo is not None: + geography = load_geography(resolved_geo) + if geography.n_records != n_records: + raise ValueError( + f"Geography artifact {resolved_geo} has n_records={geography.n_records}, " + f"expected {n_records}" + ) + if n_clones is not None and geography.n_clones != n_clones: + raise ValueError( + f"Geography artifact {resolved_geo} has n_clones={geography.n_clones}, " + f"expected {n_clones}" + ) + print(f"Loaded calibration geography from {resolved_geo}") + return geography + + if resolved_blocks is not None: + block_geoids = np.asarray(np.load(resolved_blocks, allow_pickle=True), dtype=str) + if len(block_geoids) % n_records != 0: + raise ValueError( + f"Legacy blocks artifact {resolved_blocks} has {len(block_geoids)} " + f"rows, not divisible by n_records={n_records}" + ) + inferred_n_clones = len(block_geoids) // n_records + if n_clones is not None and inferred_n_clones != n_clones: + raise ValueError( + f"Legacy blocks artifact {resolved_blocks} implies " + f"n_clones={inferred_n_clones}, expected {n_clones}" + ) + print( + "Reconstructing geography from legacy stacked blocks at " + f"{resolved_blocks}" + ) + return reconstruct_geography_from_blocks( + block_geoids=block_geoids, + n_records=n_records, + n_clones=inferred_n_clones, + ) + + geo_hint = _sibling_artifact_path(weights_path, GEOGRAPHY_FILENAME) + legacy_hint = _sibling_artifact_path(weights_path, LEGACY_BLOCKS_FILENAME) + raise FileNotFoundError( + "No saved calibration geography found. Expected either " + f"{geo_hint} or {legacy_hint}. Re-run calibration on this branch or " + "provide --geography-path." + ) + + def validate_or_clear_checkpoints(fingerprint: str): if META_FILE.exists(): stored = json.loads(META_FILE.read_text()) @@ -849,6 +971,11 @@ def main(): action="store_true", help="Only build and upload city files (e.g., NYC)", ) + parser.add_argument( + "--national-only", + action="store_true", + help="Only build the national US.h5 file", + ) parser.add_argument( "--weights-path", type=str, @@ -867,14 +994,26 @@ def main(): parser.add_argument( "--n-clones", type=int, - required=True, - help="Number of clones used in calibration", + default=None, + help="Clone count override for validating saved geography artifacts", ) parser.add_argument( "--seed", type=int, default=42, - help="Random seed used in calibration (default: 42)", + help="Legacy fallback seed used only if no saved geography is available", + ) + parser.add_argument( + "--geography-path", + type=str, + default=None, + help="Override path to saved geography_assignment.npz", + ) + parser.add_argument( + "--blocks-path", + type=str, + default=None, + help="Override path to legacy stacked_blocks.npy", ) parser.add_argument( "--upload", @@ -917,6 +1056,8 @@ def main(): inputs["dataset"], args.n_clones, args.seed, + geography_path=Path(args.geography_path) if args.geography_path else None, + blocks_path=Path(args.blocks_path) if args.blocks_path else None, ) validate_or_clear_checkpoints(fingerprint) @@ -926,43 +1067,21 @@ def main(): del _sim print(f"\nBase dataset has {n_hh:,} households") - geo_cache = WORK_DIR / f"geography_{n_hh}x{args.n_clones}_s{args.seed}.npz" - if geo_cache.exists(): - print(f"Loading cached geography from {geo_cache}") - npz = np.load(geo_cache, allow_pickle=True) - geography = GeographyAssignment( - block_geoid=npz["block_geoid"], - cd_geoid=npz["cd_geoid"], - county_fips=npz["county_fips"], - state_fips=npz["state_fips"], - n_records=n_hh, - n_clones=args.n_clones, - ) - else: - print( - f"Generating geography: {n_hh} records x " - f"{args.n_clones} clones, seed={args.seed}" - ) - geography = assign_random_geography( - n_records=n_hh, - n_clones=args.n_clones, - seed=args.seed, - ) - np.savez_compressed( - geo_cache, - block_geoid=geography.block_geoid, - cd_geoid=geography.cd_geoid, - county_fips=geography.county_fips, - state_fips=geography.state_fips, - ) - print(f"Saved geography cache to {geo_cache}") + geography = load_calibration_geography( + weights_path=inputs["weights"], + n_records=n_hh, + n_clones=args.n_clones, + geography_path=Path(args.geography_path) if args.geography_path else None, + blocks_path=Path(args.blocks_path) if args.blocks_path else None, + ) takeup_filter = [spec["variable"] for spec in SIMPLE_TAKEUP_VARS] print(f"Takeup filter: {takeup_filter}") # Determine what to build based on flags - do_states = not args.districts_only and not args.cities_only - do_districts = not args.states_only and not args.cities_only - do_cities = not args.states_only and not args.districts_only + do_national = args.national_only + do_states = not args.districts_only and not args.cities_only and not args.national_only + do_districts = not args.states_only and not args.cities_only and not args.national_only + do_cities = not args.states_only and not args.districts_only and not args.national_only # If a specific *-only flag is set, only build that type if args.states_only: @@ -978,6 +1097,22 @@ def main(): do_districts = False do_cities = True + if do_national: + print("\n" + "=" * 60) + print("BUILDING NATIONAL US.h5") + print("=" * 60) + weights = np.load(inputs["weights"]) + national_dir = WORK_DIR / "national" + national_dir.mkdir(parents=True, exist_ok=True) + path = build_h5( + weights=weights, + geography=geography, + dataset_path=inputs["dataset"], + output_path=national_dir / "US.h5", + takeup_filter=takeup_filter, + ) + print(f"Built {path}") + if do_states: print("\n" + "=" * 60) print("BUILDING STATE FILES") diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index e449cea4d..22344fc65 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -316,6 +316,20 @@ def parse_args(argv=None): help="Number of parallel workers for state/county " "precomputation (default: 1, sequential).", ) + parser.add_argument( + "--resume-from", + default=None, + help="Resume fitting from a `.checkpoint.pt` file or " + "warm-start from a `.npy` weights file. " + "If a sibling checkpoint exists next to the weights file, " + "it is preferred automatically.", + ) + parser.add_argument( + "--checkpoint-output", + default=None, + help="Where to save resumable fit checkpoints " + "(default: .checkpoint.pt).", + ) return parser.parse_args(argv) @@ -473,6 +487,134 @@ def load_calibration_package(path: str) -> dict: return package +def default_checkpoint_path(output_path: str) -> Path: + """Derive the default checkpoint artifact path for a weights file.""" + return Path(output_path).with_suffix(".checkpoint.pt") + + +def _hash_string_list(values: list) -> str: + """Hash an ordered list of strings for checkpoint compatibility.""" + import hashlib + + digest = hashlib.sha256() + for value in values or []: + digest.update(str(value).encode("utf-8")) + digest.update(b"\0") + return digest.hexdigest() + + +def build_checkpoint_signature( + X_sparse, + targets: np.ndarray, + target_names: list, + lambda_l0: float, + beta: float, + lambda_l2: float, + learning_rate: float, +) -> dict: + """Build a compact signature to validate resume compatibility.""" + import hashlib + + targets_arr = np.asarray(targets, dtype=np.float64) + return { + "n_features": int(X_sparse.shape[1]), + "n_targets": int(len(targets_arr)), + "target_names_sha256": _hash_string_list(target_names), + "targets_sha256": hashlib.sha256(targets_arr.tobytes()).hexdigest(), + "lambda_l0": float(lambda_l0), + "beta": float(beta), + "lambda_l2": float(lambda_l2), + "learning_rate": float(learning_rate), + } + + +def checkpoint_signature_mismatches(expected: dict, actual: dict) -> list: + """Return human-readable checkpoint compatibility mismatches.""" + mismatches = [] + float_keys = {"lambda_l0", "beta", "lambda_l2", "learning_rate"} + for key, expected_value in expected.items(): + actual_value = actual.get(key) + if key in float_keys: + if actual_value is None or not np.isclose(expected_value, actual_value): + mismatches.append( + f"{key} expected {expected_value}, got {actual_value}" + ) + elif actual_value != expected_value: + mismatches.append( + f"{key} expected {expected_value}, got {actual_value}" + ) + return mismatches + + +def save_fit_checkpoint( + path: str, + model, + epochs_completed: int, + signature: dict, +) -> None: + """Persist model state for resumable calibration fitting.""" + import datetime + import torch + + state_dict = {} + for key, value in model.state_dict().items(): + if hasattr(value, "detach"): + state_dict[key] = value.detach().cpu() + else: + state_dict[key] = value + + checkpoint = { + "format_version": 1, + "saved_at": datetime.datetime.now().isoformat(), + "epochs_completed": int(epochs_completed), + "signature": signature, + "model_state_dict": state_dict, + } + Path(path).parent.mkdir(parents=True, exist_ok=True) + torch.save(checkpoint, path) + logger.info( + "Calibration checkpoint saved to %s (epoch %d)", + path, + epochs_completed, + ) + + +def load_fit_checkpoint(path: str, device: str = "cpu") -> dict: + """Load a saved calibration fit checkpoint.""" + import torch + + if not Path(path).exists(): + raise FileNotFoundError(f"Checkpoint not found: {path}") + checkpoint = torch.load(path, map_location=device) + if "model_state_dict" not in checkpoint: + raise ValueError(f"Invalid checkpoint file: {path}") + return checkpoint + + +def resolve_resume_artifact(resume_from: str) -> tuple: + """Resolve a resume input to a checkpoint or weight artifact. + + When a `.npy` weights path is provided, prefer a sibling checkpoint if + it exists so resume restores the full model state. + """ + resume_path = Path(resume_from) + if not resume_path.exists(): + raise FileNotFoundError(f"Resume artifact not found: {resume_from}") + + if resume_path.suffix == ".npy": + checkpoint_path = default_checkpoint_path(str(resume_path)) + if checkpoint_path.exists(): + logger.info( + "Found sibling checkpoint %s for %s; resuming full model state", + checkpoint_path, + resume_path, + ) + return "checkpoint", checkpoint_path + return "weights", resume_path + + return "checkpoint", resume_path + + def compute_initial_weights( X_sparse, targets_df: "pd.DataFrame", @@ -551,6 +693,8 @@ def fit_l0_weights( initial_weights: np.ndarray = None, targets_df: "pd.DataFrame" = None, achievable: np.ndarray = None, + resume_from: str = None, + checkpoint_path: str = None, ) -> np.ndarray: """Fit L0-regularized calibration weights. @@ -572,6 +716,9 @@ def fit_l0_weights( computed from targets_df age targets. targets_df: Targets DataFrame, used to compute initial_weights when not provided. + resume_from: Path to a `.checkpoint.pt` file or `.npy` + weights file to continue fitting from. + checkpoint_path: Where to save resumable fit checkpoints. Returns: Weight array of shape (n_records,). @@ -591,6 +738,56 @@ def fit_l0_weights( if initial_weights is None: initial_weights = compute_initial_weights(X_sparse, targets_df) + checkpoint_signature = build_checkpoint_signature( + X_sparse=X_sparse, + targets=targets, + target_names=target_names, + lambda_l0=lambda_l0, + beta=beta, + lambda_l2=lambda_l2, + learning_rate=learning_rate, + ) + checkpoint_state_dict = None + start_epoch = 0 + + if resume_from is not None: + resume_kind, resume_path = resolve_resume_artifact(resume_from) + if resume_kind == "weights": + resume_weights = np.load(resume_path) + if resume_weights.shape != (n_total,): + raise ValueError( + f"Resume weights at {resume_path} must have shape " + f"({n_total},), got {resume_weights.shape}" + ) + initial_weights = resume_weights.astype(np.float64, copy=False) + logger.info( + "Warm-starting calibration from saved weights at %s", + resume_path, + ) + else: + checkpoint = load_fit_checkpoint(str(resume_path), device=device) + stored_signature = checkpoint.get("signature") + if stored_signature is None: + raise ValueError( + f"Checkpoint {resume_path} is missing compatibility metadata" + ) + mismatches = checkpoint_signature_mismatches( + stored_signature, + checkpoint_signature, + ) + if mismatches: + raise ValueError( + "Checkpoint is incompatible with the requested run: " + + "; ".join(mismatches) + ) + checkpoint_state_dict = checkpoint["model_state_dict"] + start_epoch = int(checkpoint.get("epochs_completed", 0)) + logger.info( + "Resuming calibration from checkpoint %s at epoch %d", + resume_path, + start_epoch, + ) + logger.info( "L0 calibration: %d targets, %d features, " "lambda_l0=%.1e, beta=%.2f, lambda_l2=%.1e, " @@ -603,6 +800,12 @@ def fit_l0_weights( learning_rate, epochs, ) + if start_epoch > 0: + logger.info( + "Continuing for %d additional epochs (total after run: %d)", + epochs, + start_epoch + epochs, + ) model = SparseCalibrationWeights( n_features=n_total, @@ -615,6 +818,10 @@ def fit_l0_weights( log_alpha_jitter_sd=LOG_ALPHA_JITTER_SD, device=device, ) + if checkpoint_state_dict is not None: + model.load_state_dict(checkpoint_state_dict) + if resume_from is not None: + model.log_weight_jitter_sd = 0.0 if verbose_freq is None: verbose_freq = max(1, epochs // 10) @@ -632,111 +839,132 @@ def _flushed_print(*args, **kwargs): ) if enable_logging: Path(log_path).parent.mkdir(parents=True, exist_ok=True) - with open(log_path, "w") as f: - f.write( - "target_name,estimate,target,epoch," - "error,rel_error,abs_error," - "rel_abs_error,loss,achievable\n" + if start_epoch > 0 and Path(log_path).exists(): + logger.info( + "Appending epoch log to %s from epoch %d", + log_path, + start_epoch, + ) + else: + with open(log_path, "w") as f: + f.write( + "target_name,estimate,target,epoch," + "error,rel_error,abs_error," + "rel_abs_error,loss,achievable\n" + ) + logger.info( + "Epoch logging enabled: freq=%d, path=%s", + log_freq, + log_path, ) - logger.info( - "Epoch logging enabled: freq=%d, path=%s", - log_freq, - log_path, - ) t0 = time.time() - if enable_logging: - epochs_done = 0 - while epochs_done < epochs: - chunk = min(log_freq, epochs - epochs_done) - model.fit( - M=X_sparse, - y=targets, - target_groups=None, - lambda_l0=lambda_l0, - lambda_l2=lambda_l2, - lr=learning_rate, - epochs=chunk, - loss_type="relative", - verbose=False, - ) + try: + if enable_logging: + epochs_done = 0 + while epochs_done < epochs: + chunk = min(log_freq, epochs - epochs_done) + model.fit( + M=X_sparse, + y=targets, + target_groups=None, + lambda_l0=lambda_l0, + lambda_l2=lambda_l2, + lr=learning_rate, + epochs=chunk, + loss_type="relative", + verbose=False, + ) + model.log_weight_jitter_sd = 0.0 - epochs_done += chunk + epochs_done += chunk + absolute_epoch = start_epoch + epochs_done - with torch.no_grad(): - y_pred = model.predict(X_sparse).cpu().numpy() - weights_snap = model.get_weights(deterministic=True).cpu().numpy() + with torch.no_grad(): + y_pred = model.predict(X_sparse).cpu().numpy() + weights_snap = ( + model.get_weights(deterministic=True).cpu().numpy() + ) - active_w = weights_snap[weights_snap > 0] - nz = len(active_w) - sparsity = (1 - nz / n_total) * 100 + if checkpoint_path is not None: + save_fit_checkpoint( + checkpoint_path, + model, + epochs_completed=absolute_epoch, + signature=checkpoint_signature, + ) - rel_errs = np.where( - np.abs(targets) > 0, - (y_pred - targets) / np.abs(targets), - 0.0, - ) - mean_err = np.mean(np.abs(rel_errs)) - max_err = np.max(np.abs(rel_errs)) - total_loss = np.sum(rel_errs**2) - - if nz > 0: - w_tiny = (active_w < 0.01).sum() - w_small = ((active_w >= 0.01) & (active_w < 0.1)).sum() - w_med = ((active_w >= 0.1) & (active_w < 1.0)).sum() - w_normal = ((active_w >= 1.0) & (active_w < 10.0)).sum() - w_large = ((active_w >= 10.0) & (active_w < 1000.0)).sum() - w_huge = (active_w >= 1000.0).sum() - weight_dist = ( - f"[<0.01: {100 * w_tiny / nz:.1f}%, " - f"0.01-0.1: {100 * w_small / nz:.1f}%, " - f"0.1-1: {100 * w_med / nz:.1f}%, " - f"1-10: {100 * w_normal / nz:.1f}%, " - f"10-1000: {100 * w_large / nz:.1f}%, " - f">1000: {100 * w_huge / nz:.1f}%]" - ) - else: - weight_dist = "[no active weights]" - - print( - f"Epoch {epochs_done:4d}: " - f"mean_error={mean_err:.4%}, " - f"max_error={max_err:.1%}, " - f"total_loss={total_loss:.3f}, " - f"active={nz}/{n_total} " - f"({sparsity:.1f}% sparse)\n" - f" Weight dist: {weight_dist}", - flush=True, - ) + active_w = weights_snap[weights_snap > 0] + nz = len(active_w) + sparsity = (1 - nz / n_total) * 100 - ach_flags = achievable if achievable is not None else [True] * len(targets) - with open(log_path, "a") as f: - for i in range(len(targets)): - est = y_pred[i] - tgt = targets[i] - err = est - tgt - rel_err = err / tgt if tgt != 0 else 0 - abs_err = abs(err) - rel_abs = abs(rel_err) - loss = rel_err**2 - f.write( - f'"{target_names[i]}",' - f"{est},{tgt},{epochs_done}," - f"{err},{rel_err},{abs_err}," - f"{rel_abs},{loss}," - f"{ach_flags[i]}\n" + rel_errs = np.where( + np.abs(targets) > 0, + (y_pred - targets) / np.abs(targets), + 0.0, + ) + mean_err = np.mean(np.abs(rel_errs)) + max_err = np.max(np.abs(rel_errs)) + total_loss = np.sum(rel_errs**2) + + if nz > 0: + w_tiny = (active_w < 0.01).sum() + w_small = ((active_w >= 0.01) & (active_w < 0.1)).sum() + w_med = ((active_w >= 0.1) & (active_w < 1.0)).sum() + w_normal = ((active_w >= 1.0) & (active_w < 10.0)).sum() + w_large = ((active_w >= 10.0) & (active_w < 1000.0)).sum() + w_huge = (active_w >= 1000.0).sum() + weight_dist = ( + f"[<0.01: {100 * w_tiny / nz:.1f}%, " + f"0.01-0.1: {100 * w_small / nz:.1f}%, " + f"0.1-1: {100 * w_med / nz:.1f}%, " + f"1-10: {100 * w_normal / nz:.1f}%, " + f"10-1000: {100 * w_large / nz:.1f}%, " + f">1000: {100 * w_huge / nz:.1f}%]" ) + else: + weight_dist = "[no active weights]" + + print( + f"Epoch {absolute_epoch:4d}: " + f"mean_error={mean_err:.4%}, " + f"max_error={max_err:.1%}, " + f"total_loss={total_loss:.3f}, " + f"active={nz}/{n_total} " + f"({sparsity:.1f}% sparse)\n" + f" Weight dist: {weight_dist}", + flush=True, + ) - logger.info( - "Logged %d targets at epoch %d", - len(targets), - epochs_done, - ) + ach_flags = ( + achievable if achievable is not None else [True] * len(targets) + ) + with open(log_path, "a") as f: + for i in range(len(targets)): + est = y_pred[i] + tgt = targets[i] + err = est - tgt + rel_err = err / tgt if tgt != 0 else 0 + abs_err = abs(err) + rel_abs = abs(rel_err) + loss = rel_err**2 + f.write( + f'"{target_names[i]}",' + f"{est},{tgt},{absolute_epoch}," + f"{err},{rel_err},{abs_err}," + f"{rel_abs},{loss}," + f"{ach_flags[i]}\n" + ) + + logger.info( + "Logged %d targets at epoch %d", + len(targets), + absolute_epoch, + ) - if torch.cuda.is_available(): - torch.cuda.empty_cache() - else: - try: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + else: model.fit( M=X_sparse, y=targets, @@ -749,14 +977,22 @@ def _flushed_print(*args, **kwargs): verbose=True, verbose_freq=verbose_freq, ) - finally: - builtins.print = _builtin_print + model.log_weight_jitter_sd = 0.0 + if checkpoint_path is not None: + save_fit_checkpoint( + checkpoint_path, + model, + epochs_completed=start_epoch + epochs, + signature=checkpoint_signature, + ) + finally: + builtins.print = _builtin_print elapsed = time.time() - t0 logger.info( "L0 done in %.1f min (%.1f sec/epoch)", elapsed / 60, - elapsed / epochs, + elapsed / max(epochs, 1), ) with torch.no_grad(): @@ -824,6 +1060,8 @@ def run_calibration( log_freq: int = None, log_path: str = None, workers: int = 1, + resume_from: str = None, + checkpoint_path: str = None, ): """Run unified calibration pipeline. @@ -849,6 +1087,9 @@ def run_calibration( learning_rate: Optimizer learning rate. log_freq: Epochs between per-target CSV logs. log_path: Path for per-target calibration log CSV. + resume_from: Path to a checkpoint or weights file to + continue fitting from. + checkpoint_path: Where to save resumable fit checkpoints. Returns: (weights, targets_df, X_sparse, target_names, geography_info) @@ -890,6 +1131,8 @@ def run_calibration( initial_weights=initial_weights, targets_df=targets_df, achievable=pkg_achievable, + resume_from=resume_from, + checkpoint_path=checkpoint_path, ) logger.info( "Total pipeline (from package): %.1f min", @@ -1144,6 +1387,8 @@ def run_calibration( initial_weights=initial_weights, targets_df=targets_df, achievable=achievable, + resume_from=resume_from, + checkpoint_path=checkpoint_path, ) logger.info( @@ -1206,6 +1451,9 @@ def main(argv=None): else: lambda_l0 = PRESETS["local"] + if args.build_only and args.resume_from: + raise ValueError("--resume-from cannot be used with --build-only") + domain_variables = None if args.domain_variables: domain_variables = [x.strip() for x in args.domain_variables.split(",")] @@ -1230,6 +1478,11 @@ def main(argv=None): cal_log_path = None if args.log_freq is not None: cal_log_path = str(output_dir / "calibration_log.csv") + checkpoint_output_path = None + if not args.build_only: + checkpoint_output_path = args.checkpoint_output or str( + default_checkpoint_path(output_path) + ) ( weights, targets_df, @@ -1259,6 +1512,8 @@ def main(argv=None): log_freq=args.log_freq, log_path=cal_log_path, workers=args.workers, + resume_from=args.resume_from, + checkpoint_path=checkpoint_output_path, ) source_imputed = geography_info.get("dataset_for_matrix") @@ -1293,10 +1548,43 @@ def main(argv=None): np.save(output_path, weights) logger.info("Weights saved to %s", output_path) print(f"OUTPUT_PATH:{output_path}") + if checkpoint_output_path and Path(checkpoint_output_path).exists(): + logger.info("Checkpoint saved to %s", checkpoint_output_path) + print(f"CHECKPOINT_PATH:{checkpoint_output_path}") + + from policyengine_us_data.calibration.clone_and_assign import ( + GeographyAssignment, + save_geography, + ) + + block_geoids = np.asarray(geography_info["block_geoid"], dtype=str) + cd_geoids = np.asarray(geography_info["cd_geoid"], dtype=str) + geography_path = output_dir / "geography_assignment.npz" + save_geography( + GeographyAssignment( + block_geoid=block_geoids, + cd_geoid=cd_geoids, + county_fips=np.fromiter( + (block[:5] for block in block_geoids), + dtype="U5", + count=len(block_geoids), + ), + state_fips=np.fromiter( + (int(block[:2]) for block in block_geoids), + dtype=np.int32, + count=len(block_geoids), + ), + n_records=int(geography_info["base_n_records"]), + n_clones=args.n_clones, + ), + geography_path, + ) + logger.info("Geography saved to %s", geography_path) + print(f"GEOGRAPHY_PATH:{geography_path}") # Save legacy block artifact for backward compatibility blocks_path = output_dir / "stacked_blocks.npy" - np.save(str(blocks_path), geography_info["block_geoid"]) + np.save(str(blocks_path), block_geoids) logger.info("Blocks saved to %s", blocks_path) print(f"BLOCKS_PATH:{blocks_path}") @@ -1322,6 +1610,10 @@ def _sha256(filepath): t_end = time.time() weight_format = "clone_level" + epochs_total = args.epochs + if checkpoint_output_path and Path(checkpoint_output_path).exists(): + checkpoint_meta = load_fit_checkpoint(checkpoint_output_path, device="cpu") + epochs_total = int(checkpoint_meta.get("epochs_completed", args.epochs)) run_config = { "dataset": dataset_path, "db_path": db_path, @@ -1332,8 +1624,11 @@ def _sha256(filepath): "lambda_l2": args.lambda_l2, "learning_rate": args.learning_rate, "epochs": args.epochs, + "epochs_total": epochs_total, "device": args.device, "seed": args.seed, + "resume_from": args.resume_from, + "checkpoint_output": checkpoint_output_path, "domain_variables": domain_variables, "hierarchical_domains": hierarchical_domains, "target_config": args.target_config, @@ -1346,8 +1641,13 @@ def _sha256(filepath): "elapsed_seconds": round(t_end - t_start, 1), "artifacts": { "calibration_weights.npy": _sha256(output_path), + "geography_assignment.npz": _sha256(geography_path), }, } + if checkpoint_output_path and Path(checkpoint_output_path).exists(): + run_config["artifacts"]["calibration_checkpoint.pt"] = _sha256( + checkpoint_output_path + ) run_config.update(get_git_provenance()) config_path = output_dir / "unified_run_config.json" with open(config_path, "w") as f: diff --git a/policyengine_us_data/utils/huggingface.py b/policyengine_us_data/utils/huggingface.py index e43de479d..4b80a29a2 100644 --- a/policyengine_us_data/utils/huggingface.py +++ b/policyengine_us_data/utils/huggingface.py @@ -50,7 +50,8 @@ def download_calibration_inputs( (e.g. "national_") Returns: - dict with keys 'weights', 'dataset' mapping to local paths + dict with keys including weights, geography, dataset, and database + for any artifacts that exist remotely """ from pathlib import Path @@ -80,6 +81,7 @@ def download_calibration_inputs( # but won't exist yet when running calibration from scratch optional_files = { "weights": f"calibration/{prefix}calibration_weights.npy", + "geography": f"calibration/{prefix}geography_assignment.npz", "run_config": (f"calibration/{prefix}unified_run_config.json"), } for key, hf_path in optional_files.items(): @@ -151,6 +153,7 @@ def download_calibration_logs( def upload_calibration_artifacts( weights_path: str = None, + geography_path: str = None, log_dir: str = None, repo: str = "policyengine/policyengine-us-data", prefix: str = "", @@ -178,6 +181,14 @@ def upload_calibration_artifacts( ) ) + if geography_path and os.path.exists(geography_path): + operations.append( + CommitOperationAdd( + path_in_repo=(f"calibration/{prefix}geography_assignment.npz"), + path_or_fileobj=geography_path, + ) + ) + if log_dir: # Upload run config to calibration/ root for artifact validation run_config_local = os.path.join(log_dir, f"{prefix}unified_run_config.json") diff --git a/tests/unit/calibration/test_clone_and_assign.py b/tests/unit/calibration/test_clone_and_assign.py index ab297cd89..e7e70e719 100644 --- a/tests/unit/calibration/test_clone_and_assign.py +++ b/tests/unit/calibration/test_clone_and_assign.py @@ -13,8 +13,12 @@ GeographyAssignment, _build_agi_block_probs, load_global_block_distribution, + load_sorted_block_cd_lookup, assign_random_geography, double_geography_for_puf, + reconstruct_geography_from_blocks, + save_geography, + load_geography, ) MOCK_BLOCKS = pd.DataFrame( @@ -49,8 +53,10 @@ @pytest.fixture(autouse=True) def _clear_lru_cache(): load_global_block_distribution.cache_clear() + load_sorted_block_cd_lookup.cache_clear() yield load_global_block_distribution.cache_clear() + load_sorted_block_cd_lookup.cache_clear() def _mock_distribution(): @@ -221,3 +227,74 @@ def test_puf_half_matches_cps_half(self): r.state_fips[start:mid], r.state_fips[mid:end], ) + + +class TestGeographyArtifacts: + @patch( + "policyengine_us_data.calibration.clone_and_assign" + ".load_global_block_distribution" + ) + def test_reconstruct_geography_from_blocks(self, mock_load): + mock_load.return_value = _mock_distribution() + blocks = np.array( + [ + "010010001001001", + "020010001001002", + "360100001001004", + "010010001001003", + ], + dtype=str, + ) + + geo = reconstruct_geography_from_blocks( + block_geoids=blocks, + n_records=2, + n_clones=2, + ) + + np.testing.assert_array_equal( + geo.cd_geoid, + np.array(["101", "102", "103", "101"]), + ) + np.testing.assert_array_equal( + geo.county_fips, + np.array(["01001", "02001", "36010", "01001"]), + ) + np.testing.assert_array_equal( + geo.state_fips, + np.array([1, 2, 36, 1], dtype=np.int32), + ) + + @patch( + "policyengine_us_data.calibration.clone_and_assign" + ".load_global_block_distribution" + ) + def test_reconstruct_geography_from_blocks_raises_on_unknown_block(self, mock_load): + mock_load.return_value = _mock_distribution() + with pytest.raises(KeyError): + reconstruct_geography_from_blocks( + block_geoids=np.array(["999999999999999"], dtype=str), + n_records=1, + n_clones=1, + ) + + def test_save_and_load_geography_round_trip(self, tmp_path): + geo = GeographyAssignment( + block_geoid=np.array(["010010001001001", "020010001001001"]), + cd_geoid=np.array(["101", "202"]), + county_fips=np.array(["01001", "02001"]), + state_fips=np.array([1, 2], dtype=np.int32), + n_records=1, + n_clones=2, + ) + path = tmp_path / "geography_assignment.npz" + + save_geography(geo, path) + loaded = load_geography(path) + + np.testing.assert_array_equal(loaded.block_geoid, geo.block_geoid) + np.testing.assert_array_equal(loaded.cd_geoid, geo.cd_geoid) + np.testing.assert_array_equal(loaded.county_fips, geo.county_fips) + np.testing.assert_array_equal(loaded.state_fips, geo.state_fips) + assert loaded.n_records == geo.n_records + assert loaded.n_clones == geo.n_clones diff --git a/tests/unit/calibration/test_unified_calibration.py b/tests/unit/calibration/test_unified_calibration.py index 0414bbfb8..1baddd3ff 100644 --- a/tests/unit/calibration/test_unified_calibration.py +++ b/tests/unit/calibration/test_unified_calibration.py @@ -7,6 +7,7 @@ import numpy as np import pytest +import scipy.sparse as sp from types import SimpleNamespace from unittest.mock import patch @@ -391,6 +392,186 @@ def test_skip_takeup_rerandomize_flag(self): args_default = parse_args([]) assert args_default.skip_takeup_rerandomize is False + def test_resume_flags(self): + from policyengine_us_data.calibration.unified_calibration import ( + parse_args, + ) + + args = parse_args( + [ + "--resume-from", + "weights.npy", + "--checkpoint-output", + "weights.checkpoint.pt", + ] + ) + assert args.resume_from == "weights.npy" + assert args.checkpoint_output == "weights.checkpoint.pt" + + +class FakeSparseCalibrationWeights: + def __init__( + self, + n_features, + beta=None, + gamma=None, + zeta=None, + init_keep_prob=None, + init_weights=None, + log_weight_jitter_sd=0.0, + log_alpha_jitter_sd=0.0, + device="cpu", + ): + import torch + + self.n_features = n_features + self.device = device + self.log_weight_jitter_sd = log_weight_jitter_sd + weight_values = ( + np.ones(n_features, dtype=np.float32) + if init_weights is None + else np.asarray(init_weights, dtype=np.float32) + ) + self.weights = torch.tensor(weight_values, dtype=torch.float32) + self.alpha = torch.zeros(n_features, dtype=torch.float32) + + def fit( + self, + M, + y, + lambda_l0=0.0, + lambda_l2=0.0, + lr=0.0, + epochs=1, + loss_type="relative", + verbose=False, + verbose_freq=1, + target_groups=None, + ): + increment = float(epochs) + (self.alpha / 10.0) + self.weights = self.weights + increment + self.alpha = self.alpha + (10.0 * float(epochs)) + return self + + def predict(self, M): + import torch + + weights = self.get_weights(deterministic=True).cpu().numpy() + return torch.tensor(M.dot(weights), dtype=torch.float32) + + def get_weights(self, deterministic=True): + return self.weights.clone() + + def state_dict(self): + return { + "weights": self.weights.clone(), + "alpha": self.alpha.clone(), + } + + def load_state_dict(self, state_dict): + self.weights = state_dict["weights"].clone() + self.alpha = state_dict["alpha"].clone() + + +class TestFitResume: + def _fit_kwargs(self, tmp_path): + return { + "X_sparse": sp.csr_matrix(np.eye(2, dtype=np.float32)), + "targets": np.array([1.0, 2.0], dtype=np.float64), + "lambda_l0": 1e-4, + "epochs": 1, + "device": "cpu", + "beta": 0.65, + "lambda_l2": 1e-12, + "learning_rate": 0.15, + "log_freq": 1, + "log_path": str(tmp_path / "calibration_log.csv"), + "target_names": ["target_a", "target_b"], + "initial_weights": np.array([1.0, 2.0], dtype=np.float64), + "achievable": np.array([True, True]), + } + + def test_resume_from_weights_prefers_sibling_checkpoint(self, tmp_path): + from policyengine_us_data.calibration.unified_calibration import ( + default_checkpoint_path, + fit_l0_weights, + ) + + weights_path = tmp_path / "weights.npy" + checkpoint_path = default_checkpoint_path(str(weights_path)) + kwargs = self._fit_kwargs(tmp_path) + kwargs["checkpoint_path"] = str(checkpoint_path) + + with patch( + "l0.calibration.SparseCalibrationWeights", + FakeSparseCalibrationWeights, + ): + first_weights = fit_l0_weights(**kwargs) + np.save(weights_path, first_weights) + + resumed_weights = fit_l0_weights( + **{ + **kwargs, + "resume_from": str(weights_path), + } + ) + + np.testing.assert_allclose(first_weights, np.array([2.0, 3.0])) + np.testing.assert_allclose(resumed_weights, np.array([4.0, 5.0])) + + with open(kwargs["log_path"]) as f: + lines = f.read().strip().splitlines() + assert len(lines) == 5 + assert lines[1].split(",")[3] == "1" + assert lines[3].split(",")[3] == "2" + + def test_resume_from_weights_falls_back_when_checkpoint_missing(self, tmp_path): + from policyengine_us_data.calibration.unified_calibration import fit_l0_weights + + weights_path = tmp_path / "weights.npy" + np.save(weights_path, np.array([2.0, 3.0], dtype=np.float64)) + kwargs = self._fit_kwargs(tmp_path) + + with patch( + "l0.calibration.SparseCalibrationWeights", + FakeSparseCalibrationWeights, + ): + resumed_weights = fit_l0_weights( + **{ + **kwargs, + "resume_from": str(weights_path), + } + ) + + np.testing.assert_allclose(resumed_weights, np.array([3.0, 4.0])) + + def test_resume_checkpoint_rejects_incompatible_hyperparams(self, tmp_path): + from policyengine_us_data.calibration.unified_calibration import ( + default_checkpoint_path, + fit_l0_weights, + ) + + weights_path = tmp_path / "weights.npy" + checkpoint_path = default_checkpoint_path(str(weights_path)) + kwargs = self._fit_kwargs(tmp_path) + kwargs["checkpoint_path"] = str(checkpoint_path) + + with patch( + "l0.calibration.SparseCalibrationWeights", + FakeSparseCalibrationWeights, + ): + first_weights = fit_l0_weights(**kwargs) + np.save(weights_path, first_weights) + + with pytest.raises(ValueError, match="Checkpoint is incompatible"): + fit_l0_weights( + **{ + **kwargs, + "lambda_l0": 9e-4, + "resume_from": str(checkpoint_path), + } + ) + class TestGeographyAssignmentCountyFips: """Verify county_fips field on GeographyAssignment.""" From 2d40f74b3d4c77e2056a7343c80b47400a236ffe Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Wed, 8 Apr 2026 23:08:09 -0400 Subject: [PATCH 02/12] Run ruff format Co-Authored-By: Claude Opus 4.6 (1M context) --- modal_app/remote_calibration_runner.py | 4 +--- .../calibration/publish_local_area.py | 23 +++++++++++++------ .../calibration/unified_calibration.py | 8 ++----- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index 5fac45209..b22e4a158 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -1066,9 +1066,7 @@ def main( upload_calibration_artifacts( weights_path=output, - geography_path=( - geography_output if result.get("geography") else None - ), + geography_path=(geography_output if result.get("geography") else None), log_dir=".", prefix=prefix, ) diff --git a/policyengine_us_data/calibration/publish_local_area.py b/policyengine_us_data/calibration/publish_local_area.py index f90148e91..2db93f7fc 100644 --- a/policyengine_us_data/calibration/publish_local_area.py +++ b/policyengine_us_data/calibration/publish_local_area.py @@ -81,7 +81,9 @@ def resolve_calibration_geography_paths( if blocks_path is not None: block_candidates.append(Path(blocks_path)) - block_candidates.append(_sibling_artifact_path(weights_path, LEGACY_BLOCKS_FILENAME)) + block_candidates.append( + _sibling_artifact_path(weights_path, LEGACY_BLOCKS_FILENAME) + ) block_candidates.append(weights_path.with_name(LEGACY_BLOCKS_FILENAME)) resolved_geo = next((path for path in geo_candidates if path.exists()), None) @@ -155,7 +157,9 @@ def load_calibration_geography( return geography if resolved_blocks is not None: - block_geoids = np.asarray(np.load(resolved_blocks, allow_pickle=True), dtype=str) + block_geoids = np.asarray( + np.load(resolved_blocks, allow_pickle=True), dtype=str + ) if len(block_geoids) % n_records != 0: raise ValueError( f"Legacy blocks artifact {resolved_blocks} has {len(block_geoids)} " @@ -168,8 +172,7 @@ def load_calibration_geography( f"n_clones={inferred_n_clones}, expected {n_clones}" ) print( - "Reconstructing geography from legacy stacked blocks at " - f"{resolved_blocks}" + f"Reconstructing geography from legacy stacked blocks at {resolved_blocks}" ) return reconstruct_geography_from_blocks( block_geoids=block_geoids, @@ -1079,9 +1082,15 @@ def main(): # Determine what to build based on flags do_national = args.national_only - do_states = not args.districts_only and not args.cities_only and not args.national_only - do_districts = not args.states_only and not args.cities_only and not args.national_only - do_cities = not args.states_only and not args.districts_only and not args.national_only + do_states = ( + not args.districts_only and not args.cities_only and not args.national_only + ) + do_districts = ( + not args.states_only and not args.cities_only and not args.national_only + ) + do_cities = ( + not args.states_only and not args.districts_only and not args.national_only + ) # If a specific *-only flag is set, only build that type if args.states_only: diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index 22344fc65..1535afb58 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -540,9 +540,7 @@ def checkpoint_signature_mismatches(expected: dict, actual: dict) -> list: f"{key} expected {expected_value}, got {actual_value}" ) elif actual_value != expected_value: - mismatches.append( - f"{key} expected {expected_value}, got {actual_value}" - ) + mismatches.append(f"{key} expected {expected_value}, got {actual_value}") return mismatches @@ -882,9 +880,7 @@ def _flushed_print(*args, **kwargs): with torch.no_grad(): y_pred = model.predict(X_sparse).cpu().numpy() - weights_snap = ( - model.get_weights(deterministic=True).cpu().numpy() - ) + weights_snap = model.get_weights(deterministic=True).cpu().numpy() if checkpoint_path is not None: save_fit_checkpoint( From d268effb46ccb3a697275a522d1a01fc0e5fc02e Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Wed, 8 Apr 2026 23:22:07 -0400 Subject: [PATCH 03/12] Add changelog fragment for PR 708 Co-Authored-By: Claude Opus 4.6 (1M context) --- changelog.d/708.added | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/708.added diff --git a/changelog.d/708.added b/changelog.d/708.added new file mode 100644 index 000000000..4d24af1c6 --- /dev/null +++ b/changelog.d/708.added @@ -0,0 +1 @@ +Save calibration geography as a pipeline artifact and add ``--resume-from`` checkpoint support for long-running calibration fits. From 5ca12418b3a024713ee6544be88aa356e8c477da Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Wed, 8 Apr 2026 23:30:44 -0400 Subject: [PATCH 04/12] Stub l0 module in test so patch works without l0-python installed Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/unit/calibration/test_unified_calibration.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/unit/calibration/test_unified_calibration.py b/tests/unit/calibration/test_unified_calibration.py index 1baddd3ff..da2f79523 100644 --- a/tests/unit/calibration/test_unified_calibration.py +++ b/tests/unit/calibration/test_unified_calibration.py @@ -5,12 +5,24 @@ block-level takeup seeding, county precomputation, and CLI flags. """ +import sys +import types + import numpy as np import pytest import scipy.sparse as sp from types import SimpleNamespace from unittest.mock import patch +# Ensure `l0.calibration` is importable so patch() can traverse the path +# even when the real l0-python package is not installed (e.g. CI). +if "l0" not in sys.modules: + _l0 = types.ModuleType("l0") + _l0.calibration = types.ModuleType("l0.calibration") + _l0.calibration.SparseCalibrationWeights = None + sys.modules["l0"] = _l0 + sys.modules["l0.calibration"] = _l0.calibration + from policyengine_us_data.utils.randomness import seeded_rng from policyengine_us_data.utils.takeup import ( SIMPLE_TAKEUP_VARS, From bc12152da8de807515e0133c0b078db1a7987eaf Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Thu, 9 Apr 2026 21:33:27 -0400 Subject: [PATCH 05/12] Add self-employment and SSN card type count targets to calibration config Co-Authored-By: Claude Opus 4.6 (1M context) --- .../calibration/block_assignment.py | 47 ++++++++++++++----- .../calibration/target_config.yaml | 17 +++++++ 2 files changed, 53 insertions(+), 11 deletions(-) diff --git a/policyengine_us_data/calibration/block_assignment.py b/policyengine_us_data/calibration/block_assignment.py index 3754ad5af..7ddb0fe48 100644 --- a/policyengine_us_data/calibration/block_assignment.py +++ b/policyengine_us_data/calibration/block_assignment.py @@ -22,7 +22,7 @@ """ import random -import re +import unicodedata from functools import lru_cache from io import StringIO from typing import Dict, Optional @@ -72,8 +72,9 @@ def _build_county_fips_to_enum() -> Dict[str, str]: """ url = "https://www2.census.gov/geo/docs/reference/codes2020/national_county2020.txt" response = requests.get(url, timeout=60) + response.raise_for_status() df = pd.read_csv( - StringIO(response.text), + StringIO(response.content.decode("utf-8")), delimiter="|", dtype=str, usecols=["STATE", "STATEFP", "COUNTYFP", "COUNTYNAME"], @@ -82,20 +83,44 @@ def _build_county_fips_to_enum() -> Dict[str, str]: valid_enum_names = set(County._member_names_) fips_to_enum = {} + def county_name_candidates(county_name: str, state_code: str) -> list[str]: + """Return normalized enum-name candidates for a Census county label.""" + raw = county_name.strip() + candidates = [] + for candidate_name in ( + raw, + unicodedata.normalize("NFKD", raw) + .encode("ascii", "ignore") + .decode("ascii"), + ): + enum_name = candidate_name.upper() + enum_name = enum_name.replace("-", "_") + enum_name = enum_name.replace(" ", "_") + enum_name = enum_name.replace(".", "") + enum_name = enum_name.replace("'", "_") + candidates.append(f"{enum_name}_{state_code}") + + # Backwards-compatible fallback for names that historically + # dropped apostrophes entirely. + candidates.append(f"{enum_name.replace('_S_', 'S_')}_{state_code}") + + ordered = [] + seen = set() + for candidate in candidates: + if candidate not in seen: + seen.add(candidate) + ordered.append(candidate) + return ordered + for _, row in df.iterrows(): county_fips = row["STATEFP"] + row["COUNTYFP"] state_code = row["STATE"] county_name = row["COUNTYNAME"] - # Transform to enum name format - enum_name = county_name.upper() - enum_name = re.sub(r"[.'\"]", "", enum_name) - enum_name = enum_name.replace("-", "_") - enum_name = enum_name.replace(" ", "_") - enum_name = f"{enum_name}_{state_code}" - - if enum_name in valid_enum_names: - fips_to_enum[county_fips] = enum_name + for enum_name in county_name_candidates(county_name, state_code): + if enum_name in valid_enum_names: + fips_to_enum[county_fips] = enum_name + break return fips_to_enum diff --git a/policyengine_us_data/calibration/target_config.yaml b/policyengine_us_data/calibration/target_config.yaml index 2ec74b1e2..8cd182ec0 100644 --- a/policyengine_us_data/calibration/target_config.yaml +++ b/policyengine_us_data/calibration/target_config.yaml @@ -44,6 +44,10 @@ include: - variable: person_count geo_level: state domain_variable: medicaid_enrolled + # Restore old loss.py's national Medicaid enrollment count target. + - variable: person_count + geo_level: national + domain_variable: medicaid # REMOVED: is_pregnant — 100% unachievable across all 51 state geos - variable: snap geo_level: state @@ -170,12 +174,25 @@ include: - variable: tax_unit_count geo_level: national domain_variable: aca_ptc + # Restore old loss.py's ACA enrollment count target. + - variable: person_count + geo_level: national + domain_variable: aca_ptc - variable: tax_unit_count geo_level: national domain_variable: refundable_ctc - variable: tax_unit_count geo_level: national domain_variable: non_refundable_ctc + # Restore old loss.py's self-employment return-count target. + - variable: tax_unit_count + geo_level: national + domain_variable: self_employment_income + + # === NATIONAL — identity / population count targets from old loss.py === + - variable: person_count + geo_level: national + domain_variable: ssn_card_type # === NATIONAL — SOI deduction totals (non-reform) === - variable: medical_expense_deduction From 57bea317bb2b2b781330422e9bc49c8921ac5bd2 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Thu, 9 Apr 2026 22:21:34 -0400 Subject: [PATCH 06/12] Distinguish ITIN holders from SSN holders in CPS data Fix calibration crash on string constraint variables (ssn_card_type) by falling back from float32 cast when values are non-numeric. Impute ITIN status for undocumented (code-0) persons: select tax units with code-0 earners via weighted random sampling targeting 4.4M ITIN returns (IRS NTA), then mark all code-0 members of those units. Updates has_tin = (ssn_card_type != 0) | has_itin_number so ITIN holders correctly qualify for ODC ($500 credit). Co-Authored-By: Claude Opus 4.6 (1M context) --- .../calibration/unified_matrix_builder.py | 16 +++- policyengine_us_data/datasets/cps/cps.py | 94 +++++++++++++++++-- policyengine_us_data/utils/identification.py | 18 +++- tests/integration/test_enhanced_cps.py | 8 +- tests/integration/test_sparse_enhanced_cps.py | 8 +- .../unit/datasets/test_cps_identification.py | 30 ++++++ 6 files changed, 156 insertions(+), 18 deletions(-) diff --git a/policyengine_us_data/calibration/unified_matrix_builder.py b/policyengine_us_data/calibration/unified_matrix_builder.py index 53408fea1..d1d2ca535 100644 --- a/policyengine_us_data/calibration/unified_matrix_builder.py +++ b/policyengine_us_data/calibration/unified_matrix_builder.py @@ -161,11 +161,15 @@ def _compute_single_state( person = {} for var in constraint_vars: try: - person[var] = state_sim.calculate( + raw = state_sim.calculate( var, time_period, map_to="person", - ).values.astype(np.float32) + ).values + try: + person[var] = raw.astype(np.float32) + except (ValueError, TypeError): + person[var] = raw except Exception as exc: logger.warning( "Cannot calculate constraint '%s' for state %d: %s", @@ -1113,11 +1117,15 @@ def _build_state_values( person = {} for var in constraint_vars: try: - person[var] = state_sim.calculate( + raw = state_sim.calculate( var, self.time_period, map_to="person", - ).values.astype(np.float32) + ).values + try: + person[var] = raw.astype(np.float32) + except (ValueError, TypeError): + person[var] = raw except Exception as exc: logger.warning( "Cannot calculate constraint '%s' for state %d: %s", diff --git a/policyengine_us_data/datasets/cps/cps.py b/policyengine_us_data/datasets/cps/cps.py index b6a7f40a6..30e4d2fae 100644 --- a/policyengine_us_data/datasets/cps/cps.py +++ b/policyengine_us_data/datasets/cps/cps.py @@ -126,7 +126,7 @@ def generate(self): logging.info("Adding previous year income variables") add_previous_year_income(self, cps) logging.info("Adding SSN card type") - add_ssn_card_type( + ssn_card_type = add_ssn_card_type( cps, person, spm_unit, @@ -135,6 +135,9 @@ def generate(self): undocumented_workers_target=8.3e6, undocumented_students_target=0.21 * 1.9e6, ) + logging.info("Imputing ITIN status") + has_itin_number = impute_itin_status(cps, person, ssn_card_type) + _store_identification_variables(cps, ssn_card_type, has_itin_number) logging.info("Adding family variables") add_spm_variables(self, cps, spm_unit) logging.info("Adding household variables") @@ -975,7 +978,7 @@ def add_ssn_card_type( undocumented_target: float = 13e6, undocumented_workers_target: float = 8.3e6, undocumented_students_target: float = 0.21 * 1.9e6, -) -> None: +) -> np.ndarray: """ Assign SSN card type using PRCITSHP, employment status, and ASEC-UA conditions. Codes: @@ -1702,12 +1705,6 @@ def get_arrival_year_midpoint(peinusyr): # Final write (all values now in ImmigrationStatus Enum) # Save as immigration_status_str since that's what PolicyEngine expects cps["immigration_status_str"] = immigration_status.astype("S") - # ============================================================================ - # CONVERT TO STRING LABELS AND STORE - # ============================================================================ - - _store_identification_variables(cps, ssn_card_type) - # Final population summary print(f"\nFinal populations:") code_to_str = { @@ -1756,6 +1753,87 @@ def get_arrival_year_midpoint(peinusyr): # Update documentation with actual numbers _update_documentation_with_numbers(log_df, DOCS_FOLDER) + return ssn_card_type + + +def impute_itin_status( + cps: dict, + person: pd.DataFrame, + ssn_card_type: np.ndarray, + itin_returns_target: float = 4.4e6, + random_seed: int = 98765, +) -> np.ndarray: + """Impute which undocumented (code-0) persons hold ITINs. + + The 4.4M target is ITIN *returns* (tax units), not persons. + Eligible tax units are those with at least one code-0 member who has + employment or self-employment income. This is a provisional rule — + it will miss some ITIN filers with no current earnings and include + some units that wouldn't file. + """ + n_persons = len(person) + has_itin_number = np.zeros(n_persons, dtype=bool) + + person_tax_unit_ids = np.asarray(cps["person_tax_unit_id"]) + is_code_0 = ssn_card_type == 0 + has_earnings = (person.WSAL_VAL > 0) | (person.SEMP_VAL > 0) + + # Identify tax units with at least one code-0 earner + eligible_person_mask = is_code_0 & has_earnings + eligible_tu_ids = np.unique(person_tax_unit_ids[eligible_person_mask]) + + # Build tax-unit weight lookup from household weights + household_ids = np.asarray(cps["household_id"]) + household_weights = np.asarray(cps["household_weight"]) + person_household_ids = np.asarray(cps["person_household_id"]) + hh_to_weight = dict(zip(household_ids, household_weights)) + + # Map each tax unit to its weight (weight of first member's household) + tu_to_weight = {} + for i, tu_id in enumerate(person_tax_unit_ids): + if tu_id not in tu_to_weight: + tu_to_weight[tu_id] = hh_to_weight.get(person_household_ids[i], 0) + + eligible_weights = np.array( + [tu_to_weight.get(tu_id, 0) for tu_id in eligible_tu_ids] + ) + total_eligible_weighted = eligible_weights.sum() + + if total_eligible_weighted <= 0: + print("ITIN imputation: no eligible tax units found") + return has_itin_number + + # Weighted random selection of tax units to hit the 4.4M target + target_share = min(itin_returns_target / total_eligible_weighted, 1.0) + rng = np.random.default_rng(seed=random_seed) + selected_mask = rng.random(len(eligible_tu_ids)) < target_share + selected_tu_ids = set(eligible_tu_ids[selected_mask]) + + # Mark all code-0 members of selected tax units as ITIN holders + for i in range(n_persons): + if is_code_0[i] and person_tax_unit_ids[i] in selected_tu_ids: + has_itin_number[i] = True + + # Logging + person_weights = np.array( + [hh_to_weight.get(hh_id, 0) for hh_id in person_household_ids] + ) + selected_tu_weighted = sum(tu_to_weight.get(tu_id, 0) for tu_id in selected_tu_ids) + itin_filers = has_itin_number & has_earnings + itin_total = has_itin_number + print("\nITIN imputation:") + print( + f" Weighted ITIN filer tax units: {selected_tu_weighted:,.0f} " + f"(target: {itin_returns_target:,.0f})" + ) + print( + f" Weighted code-0 earners with ITIN: " + f"{np.sum(person_weights[itin_filers]):,.0f}" + ) + print(f" Total weighted ITIN persons: {np.sum(person_weights[itin_total]):,.0f}") + + return has_itin_number + def _update_documentation_with_numbers(log_df, docs_dir): """Update the documentation file with actual population numbers from CSV""" diff --git a/policyengine_us_data/utils/identification.py b/policyengine_us_data/utils/identification.py index 69558f068..687ac772e 100644 --- a/policyengine_us_data/utils/identification.py +++ b/policyengine_us_data/utils/identification.py @@ -10,14 +10,24 @@ } -def _derive_has_tin_from_ssn_card_type_codes(ssn_card_type: np.ndarray) -> np.ndarray: +def _derive_has_tin_from_ssn_card_type_codes( + ssn_card_type: np.ndarray, + has_itin_number: np.ndarray | None = None, +) -> np.ndarray: """Return whether a person has any taxpayer ID from CPS ID status codes.""" - return np.asarray(ssn_card_type) != 0 + has_ssn = np.asarray(ssn_card_type) != 0 + if has_itin_number is not None: + return has_ssn | np.asarray(has_itin_number) + return has_ssn -def _store_identification_variables(cps: dict, ssn_card_type: np.ndarray) -> None: +def _store_identification_variables( + cps: dict, + ssn_card_type: np.ndarray, + has_itin_number: np.ndarray | None = None, +) -> None: """Persist identification inputs used by PolicyEngine US.""" - has_tin = _derive_has_tin_from_ssn_card_type_codes(ssn_card_type) + has_tin = _derive_has_tin_from_ssn_card_type_codes(ssn_card_type, has_itin_number) cps["ssn_card_type"] = ( pd.Series(ssn_card_type).map(SSN_CARD_TYPE_CODE_TO_STR).astype("S").values ) diff --git a/tests/integration/test_enhanced_cps.py b/tests/integration/test_enhanced_cps.py index 74c35def5..5f4d897a3 100644 --- a/tests/integration/test_enhanced_cps.py +++ b/tests/integration/test_enhanced_cps.py @@ -230,8 +230,14 @@ def test_has_tin_matches_identification_inputs(ecps_sim): has_itin = _period_array(data["has_itin"], 2024) ssn_card_type = _period_array(data["ssn_card_type"], 2024).astype(str) + # has_itin is still an alias for has_tin np.testing.assert_array_equal(has_itin, has_tin) - np.testing.assert_array_equal(has_tin, ssn_card_type != "NONE") + # Everyone with an SSN card has a TIN + assert has_tin[ssn_card_type != "NONE"].all() + # Some code-0 (NONE) people have TINs via ITIN + none_mask = ssn_card_type == "NONE" + assert none_mask.any(), "Expected some ssn_card_type == NONE" + assert has_tin[none_mask].any(), "Expected some ITIN holders among code-0" def test_aca_calibration(): diff --git a/tests/integration/test_sparse_enhanced_cps.py b/tests/integration/test_sparse_enhanced_cps.py index 2d53c73b8..f5474fd56 100644 --- a/tests/integration/test_sparse_enhanced_cps.py +++ b/tests/integration/test_sparse_enhanced_cps.py @@ -214,8 +214,14 @@ def test_sparse_has_tin_matches_identification_inputs(sim): has_itin = _period_array(data["has_itin"], 2024) ssn_card_type = _period_array(data["ssn_card_type"], 2024).astype(str) + # has_itin is still an alias for has_tin np.testing.assert_array_equal(has_itin, has_tin) - np.testing.assert_array_equal(has_tin, ssn_card_type != "NONE") + # Everyone with an SSN card has a TIN + assert has_tin[ssn_card_type != "NONE"].all() + # Some code-0 (NONE) people have TINs via ITIN + none_mask = ssn_card_type == "NONE" + assert none_mask.any(), "Expected some ssn_card_type == NONE" + assert has_tin[none_mask].any(), "Expected some ITIN holders among code-0" def test_sparse_aca_calibration(sim): diff --git a/tests/unit/datasets/test_cps_identification.py b/tests/unit/datasets/test_cps_identification.py index 690aeeaa9..cb3250e8c 100644 --- a/tests/unit/datasets/test_cps_identification.py +++ b/tests/unit/datasets/test_cps_identification.py @@ -15,6 +15,19 @@ def test_derive_has_tin_from_ssn_card_type_codes(): ) +def test_derive_has_tin_with_itin(): + ssn_codes = np.array([0, 0, 1, 2, 3]) + has_itin = np.array([True, False, False, False, False]) + + result = _derive_has_tin_from_ssn_card_type_codes(ssn_codes, has_itin) + + # code-0 with ITIN → True, code-0 without ITIN → False, others → True + np.testing.assert_array_equal( + result, + np.array([True, False, True, True, True], dtype=bool), + ) + + def test_store_identification_variables_writes_has_tin_and_alias(): cps = {} @@ -31,3 +44,20 @@ def test_store_identification_variables_writes_has_tin_and_alias(): np.array([False, True, True, True], dtype=bool), ) np.testing.assert_array_equal(cps["has_itin"], cps["has_tin"]) + + +def test_store_identification_variables_with_itin(): + cps = {} + ssn_codes = np.array([0, 0, 1, 2, 3]) + has_itin = np.array([True, False, False, False, False]) + + _store_identification_variables(cps, ssn_codes, has_itin) + + # code-0 with ITIN gets has_tin=True + assert cps["has_tin"][0] == True # noqa: E712 + # code-0 without ITIN gets has_tin=False + assert cps["has_tin"][1] == False # noqa: E712 + # SSN holders still have has_tin=True + assert cps["has_tin"][2] == True # noqa: E712 + # alias still matches + np.testing.assert_array_equal(cps["has_itin"], cps["has_tin"]) From d628ac5434dfb011c3e1ba90cb2ad2d2c3b5d42d Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 23:22:18 -0400 Subject: [PATCH 07/12] Fix PR 708 checkpoint and ID regressions --- modal_app/remote_calibration_runner.py | 23 +++- .../calibration/unified_calibration.py | 14 +++ .../calibration/unified_matrix_builder.py | 40 +++++- policyengine_us_data/datasets/cps/cps.py | 7 +- policyengine_us_data/utils/huggingface.py | 10 ++ policyengine_us_data/utils/identification.py | 69 +++++++++- .../calibration/test_unified_calibration.py | 37 ++++++ .../test_unified_matrix_builder_merge.py | 69 ++++++++++ .../unit/datasets/test_cps_identification.py | 118 +++++++++++++----- tests/unit/test_remote_calibration_runner.py | 76 +++++++++++ 10 files changed, 420 insertions(+), 43 deletions(-) create mode 100644 tests/unit/calibration/test_unified_matrix_builder_merge.py create mode 100644 tests/unit/test_remote_calibration_runner.py diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index b22e4a158..c939eebec 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -11,7 +11,7 @@ if _p not in sys.path: sys.path.insert(0, _p) -from modal_app.images import gpu_image as image +from modal_app.images import gpu_image as image # noqa: E402 app = modal.App("policyengine-us-data-fit-weights") @@ -72,6 +72,7 @@ def _collect_outputs(cal_lines): log_path = None cal_log_path = None config_path = None + checkpoint_path = None for line in cal_lines: if "OUTPUT_PATH:" in line: output_path = line.split("OUTPUT_PATH:")[1].strip() @@ -83,6 +84,8 @@ def _collect_outputs(cal_lines): cal_log_path = line.split("CAL_LOG_PATH:")[1].strip() elif "LOG_PATH:" in line: log_path = line.split("LOG_PATH:")[1].strip() + elif "CHECKPOINT_PATH:" in line: + checkpoint_path = line.split("CHECKPOINT_PATH:")[1].strip() with open(output_path, "rb") as f: weights_bytes = f.read() @@ -107,12 +110,18 @@ def _collect_outputs(cal_lines): with open(config_path, "rb") as f: config_bytes = f.read() + checkpoint_bytes = None + if checkpoint_path: + with open(checkpoint_path, "rb") as f: + checkpoint_bytes = f.read() + return { "weights": weights_bytes, "geography": geography_bytes, "log": log_bytes, "cal_log": cal_log_bytes, "config": config_bytes, + "checkpoint": checkpoint_bytes, } @@ -1037,6 +1046,12 @@ def main( f.write(result["config"]) print(f"Run config saved to: {config_output}") + checkpoint_output = f"{prefix}calibration_weights.checkpoint.pt" + if result.get("checkpoint"): + with open(checkpoint_output, "wb") as f: + f.write(result["checkpoint"]) + print(f"Checkpoint saved to: {checkpoint_output}") + # Push weights to pipeline volume for downstream steps from io import BytesIO @@ -1056,6 +1071,11 @@ def main( BytesIO(result["config"]), f"artifacts/{prefix}unified_run_config.json", ) + if result.get("checkpoint"): + batch.put_file( + BytesIO(result["checkpoint"]), + f"artifacts/{prefix}calibration_weights.checkpoint.pt", + ) pipeline_vol.commit() print("Weights committed to pipeline volume", flush=True) @@ -1067,6 +1087,7 @@ def main( upload_calibration_artifacts( weights_path=output, geography_path=(geography_output if result.get("geography") else None), + checkpoint_path=(checkpoint_output if result.get("checkpoint") else None), log_dir=".", prefix=prefix, ) diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index 1535afb58..cf9d0532c 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -503,6 +503,19 @@ def _hash_string_list(values: list) -> str: return digest.hexdigest() +def _hash_sparse_matrix(X_sparse) -> str: + """Hash sparse matrix structure and values for resume compatibility.""" + import hashlib + + X_csr = X_sparse.tocsr() + digest = hashlib.sha256() + digest.update(np.asarray(X_csr.shape, dtype=np.int64).tobytes()) + digest.update(np.asarray(X_csr.indptr, dtype=np.int64).tobytes()) + digest.update(np.asarray(X_csr.indices, dtype=np.int64).tobytes()) + digest.update(np.asarray(X_csr.data).tobytes()) + return digest.hexdigest() + + def build_checkpoint_signature( X_sparse, targets: np.ndarray, @@ -519,6 +532,7 @@ def build_checkpoint_signature( return { "n_features": int(X_sparse.shape[1]), "n_targets": int(len(targets_arr)), + "x_sparse_sha256": _hash_sparse_matrix(X_sparse), "target_names_sha256": _hash_string_list(target_names), "targets_sha256": hashlib.sha256(targets_arr.tobytes()).hexdigest(), "lambda_l0": float(lambda_l0), diff --git a/policyengine_us_data/calibration/unified_matrix_builder.py b/policyengine_us_data/calibration/unified_matrix_builder.py index d1d2ca535..e0befae57 100644 --- a/policyengine_us_data/calibration/unified_matrix_builder.py +++ b/policyengine_us_data/calibration/unified_matrix_builder.py @@ -95,6 +95,28 @@ def _compute_reform_household_values( return reform_hh +def _merged_person_constraint_dtype( + state_values: dict, + states: np.ndarray, + variable: str, +): + """Pick a merge dtype that preserves string-valued constraints.""" + dtypes = [] + for state in states: + person = state_values[int(state)]["person"] + if variable in person: + dtypes.append(np.asarray(person[variable]).dtype) + if not dtypes: + return None + merged_dtype = np.result_type(*dtypes) + if np.issubdtype(merged_dtype, np.number) or np.issubdtype( + merged_dtype, + np.bool_, + ): + return np.float32 + return merged_dtype + + def _compute_single_state( dataset_path: str, time_period: int, @@ -480,9 +502,14 @@ def _assemble_clone_values_standalone( person_vars: dict = {} for var in constraint_vars: - if var not in state_values[unique_clone_states[0]]["person"]: + dtype = _merged_person_constraint_dtype( + state_values, + unique_person_states, + var, + ) + if dtype is None: continue - arr = np.empty(n_persons, dtype=np.float32) + arr = np.empty(n_persons, dtype=dtype) for state in unique_person_states: mask = person_state_masks[int(state)] arr[mask] = state_values[int(state)]["person"][var][mask] @@ -1484,9 +1511,14 @@ def _assemble_clone_values( person_vars = {} for var in constraint_vars: - if var not in state_values[unique_clone_states[0]]["person"]: + dtype = _merged_person_constraint_dtype( + state_values, + unique_person_states, + var, + ) + if dtype is None: continue - arr = np.empty(n_persons, dtype=np.float32) + arr = np.empty(n_persons, dtype=dtype) for state in unique_person_states: mask = person_state_masks[int(state)] arr[mask] = state_values[int(state)]["person"][var][mask] diff --git a/policyengine_us_data/datasets/cps/cps.py b/policyengine_us_data/datasets/cps/cps.py index 30e4d2fae..f9052f7a7 100644 --- a/policyengine_us_data/datasets/cps/cps.py +++ b/policyengine_us_data/datasets/cps/cps.py @@ -137,7 +137,12 @@ def generate(self): ) logging.info("Imputing ITIN status") has_itin_number = impute_itin_status(cps, person, ssn_card_type) - _store_identification_variables(cps, ssn_card_type, has_itin_number) + _store_identification_variables( + cps, + person, + ssn_card_type, + has_itin_number, + ) logging.info("Adding family variables") add_spm_variables(self, cps, spm_unit) logging.info("Adding household variables") diff --git a/policyengine_us_data/utils/huggingface.py b/policyengine_us_data/utils/huggingface.py index 4b80a29a2..187f303f8 100644 --- a/policyengine_us_data/utils/huggingface.py +++ b/policyengine_us_data/utils/huggingface.py @@ -82,6 +82,7 @@ def download_calibration_inputs( optional_files = { "weights": f"calibration/{prefix}calibration_weights.npy", "geography": f"calibration/{prefix}geography_assignment.npz", + "checkpoint": f"calibration/{prefix}calibration_weights.checkpoint.pt", "run_config": (f"calibration/{prefix}unified_run_config.json"), } for key, hf_path in optional_files.items(): @@ -154,6 +155,7 @@ def download_calibration_logs( def upload_calibration_artifacts( weights_path: str = None, geography_path: str = None, + checkpoint_path: str = None, log_dir: str = None, repo: str = "policyengine/policyengine-us-data", prefix: str = "", @@ -189,6 +191,14 @@ def upload_calibration_artifacts( ) ) + if checkpoint_path and os.path.exists(checkpoint_path): + operations.append( + CommitOperationAdd( + path_in_repo=(f"calibration/{prefix}calibration_weights.checkpoint.pt"), + path_or_fileobj=checkpoint_path, + ) + ) + if log_dir: # Upload run config to calibration/ root for artifact validation run_config_local = os.path.join(log_dir, f"{prefix}unified_run_config.json") diff --git a/policyengine_us_data/utils/identification.py b/policyengine_us_data/utils/identification.py index 687ac772e..e744b832a 100644 --- a/policyengine_us_data/utils/identification.py +++ b/policyengine_us_data/utils/identification.py @@ -10,27 +10,84 @@ } -def _derive_has_tin_from_ssn_card_type_codes( +def _derive_has_valid_ssn_from_ssn_card_type_codes( + ssn_card_type: np.ndarray, +) -> np.ndarray: + """Return direct valid-SSN evidence from CPS ID status codes.""" + ssn_card_type = np.asarray(ssn_card_type) + return ssn_card_type == 1 + + +def _derive_taxpayer_id_type_from_identification_flags( + has_valid_ssn: np.ndarray, + has_tin: np.ndarray, +) -> np.ndarray: + """Return statute-facing taxpayer ID classes from ID flags.""" + return np.where( + has_valid_ssn, + "VALID_SSN", + np.where(has_tin, "OTHER_TIN", "NONE"), + ) + + +def _high_confidence_tin_evidence(person: pd.DataFrame) -> np.ndarray: + """Return admin-linked signals that strongly imply TIN possession.""" + social_security = ( + (person.SS_YN == 1) + | np.isin(person.RESNSS1, [1, 2, 3, 4, 5, 6, 7]) + | np.isin(person.RESNSS2, [1, 2, 3, 4, 5, 6, 7]) + ) + medicare = person.MCARE == 1 + federal_pension = np.isin(person.PEN_SC1, [3]) | np.isin(person.PEN_SC2, [3]) + government_worker = np.isin(person.PEIO1COW, [1, 2, 3]) | (person.A_MJOCC == 11) + military_link = (person.MIL == 1) | (person.PEAFEVER == 1) | (person.CHAMPVA == 1) + ssi = person.SSI_YN == 1 + return ( + social_security + | medicare + | federal_pension + | government_worker + | military_link + | ssi + ).to_numpy(dtype=bool) + + +def _derive_has_tin_from_identification_inputs( + person: pd.DataFrame, ssn_card_type: np.ndarray, has_itin_number: np.ndarray | None = None, ) -> np.ndarray: - """Return whether a person has any taxpayer ID from CPS ID status codes.""" - has_ssn = np.asarray(ssn_card_type) != 0 + """Return broad TIN possession without treating proxy codes as direct IDs.""" + has_valid_ssn = _derive_has_valid_ssn_from_ssn_card_type_codes(ssn_card_type) + has_tin = has_valid_ssn.copy() + has_tin |= ~has_valid_ssn & _high_confidence_tin_evidence(person) if has_itin_number is not None: - return has_ssn | np.asarray(has_itin_number) - return has_ssn + has_tin |= np.asarray(has_itin_number, dtype=bool) + return has_tin def _store_identification_variables( cps: dict, + person: pd.DataFrame, ssn_card_type: np.ndarray, has_itin_number: np.ndarray | None = None, ) -> None: """Persist identification inputs used by PolicyEngine US.""" - has_tin = _derive_has_tin_from_ssn_card_type_codes(ssn_card_type, has_itin_number) + has_valid_ssn = _derive_has_valid_ssn_from_ssn_card_type_codes(ssn_card_type) + has_tin = _derive_has_tin_from_identification_inputs( + person=person, + ssn_card_type=ssn_card_type, + has_itin_number=has_itin_number, + ) + taxpayer_id_type = _derive_taxpayer_id_type_from_identification_flags( + has_valid_ssn=has_valid_ssn, + has_tin=has_tin, + ) cps["ssn_card_type"] = ( pd.Series(ssn_card_type).map(SSN_CARD_TYPE_CODE_TO_STR).astype("S").values ) + cps["taxpayer_id_type"] = pd.Series(taxpayer_id_type).astype("S").values cps["has_tin"] = has_tin + cps["has_valid_ssn"] = has_valid_ssn # Temporary compatibility alias while policyengine-us users migrate. cps["has_itin"] = has_tin diff --git a/tests/unit/calibration/test_unified_calibration.py b/tests/unit/calibration/test_unified_calibration.py index da2f79523..61462dd36 100644 --- a/tests/unit/calibration/test_unified_calibration.py +++ b/tests/unit/calibration/test_unified_calibration.py @@ -584,6 +584,43 @@ def test_resume_checkpoint_rejects_incompatible_hyperparams(self, tmp_path): } ) + def test_resume_checkpoint_rejects_changed_matrix_with_same_shape(self, tmp_path): + from policyengine_us_data.calibration.unified_calibration import ( + default_checkpoint_path, + fit_l0_weights, + ) + + weights_path = tmp_path / "weights.npy" + checkpoint_path = default_checkpoint_path(str(weights_path)) + kwargs = self._fit_kwargs(tmp_path) + kwargs["checkpoint_path"] = str(checkpoint_path) + + with patch( + "l0.calibration.SparseCalibrationWeights", + FakeSparseCalibrationWeights, + ): + first_weights = fit_l0_weights(**kwargs) + np.save(weights_path, first_weights) + + changed_matrix = sp.csr_matrix( + np.array( + [ + [0.0, 1.0], + [1.0, 0.0], + ], + dtype=np.float32, + ) + ) + + with pytest.raises(ValueError, match="Checkpoint is incompatible"): + fit_l0_weights( + **{ + **kwargs, + "X_sparse": changed_matrix, + "resume_from": str(checkpoint_path), + } + ) + class TestGeographyAssignmentCountyFips: """Verify county_fips field on GeographyAssignment.""" diff --git a/tests/unit/calibration/test_unified_matrix_builder_merge.py b/tests/unit/calibration/test_unified_matrix_builder_merge.py new file mode 100644 index 000000000..d7bb1bb25 --- /dev/null +++ b/tests/unit/calibration/test_unified_matrix_builder_merge.py @@ -0,0 +1,69 @@ +import numpy as np + +from policyengine_us_data.calibration.unified_matrix_builder import ( + UnifiedMatrixBuilder, + _assemble_clone_values_standalone, +) + + +def _state_values_for_string_constraint(): + return { + 1: { + "hh": {}, + "person": { + "ssn_card_type": np.array( + [b"CITIZEN", b"", b""], + dtype="S24", + ), + }, + "reform_hh": {}, + }, + 2: { + "hh": {}, + "person": { + "ssn_card_type": np.array( + [ + b"", + b"NON_CITIZEN_VALID_EAD", + b"OTHER_NON_CITIZEN", + ], + dtype="S24", + ), + }, + "reform_hh": {}, + }, + } + + +def test_assemble_clone_values_standalone_preserves_string_constraints(): + _, person_vars, _ = _assemble_clone_values_standalone( + state_values=_state_values_for_string_constraint(), + clone_states=np.array([1, 2, 2]), + person_hh_indices=np.array([0, 1, 2]), + target_vars=set(), + constraint_vars={"ssn_card_type"}, + ) + + assert person_vars["ssn_card_type"].tolist() == [ + b"CITIZEN", + b"NON_CITIZEN_VALID_EAD", + b"OTHER_NON_CITIZEN", + ] + + +def test_builder_assemble_clone_values_preserves_string_constraints(): + builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder) + + _, person_vars, _ = builder._assemble_clone_values( + state_values=_state_values_for_string_constraint(), + clone_states=np.array([1, 2, 2]), + person_hh_indices=np.array([0, 1, 2]), + target_vars=set(), + constraint_vars={"ssn_card_type"}, + ) + + assert person_vars["ssn_card_type"].tolist() == [ + b"CITIZEN", + b"NON_CITIZEN_VALID_EAD", + b"OTHER_NON_CITIZEN", + ] diff --git a/tests/unit/datasets/test_cps_identification.py b/tests/unit/datasets/test_cps_identification.py index cb3250e8c..a741ae462 100644 --- a/tests/unit/datasets/test_cps_identification.py +++ b/tests/unit/datasets/test_cps_identification.py @@ -1,63 +1,119 @@ import numpy as np +import pandas as pd from policyengine_us_data.utils.identification import ( - _derive_has_tin_from_ssn_card_type_codes, + _derive_has_tin_from_identification_inputs, + _derive_has_valid_ssn_from_ssn_card_type_codes, + _derive_taxpayer_id_type_from_identification_flags, + _high_confidence_tin_evidence, _store_identification_variables, ) -def test_derive_has_tin_from_ssn_card_type_codes(): - result = _derive_has_tin_from_ssn_card_type_codes(np.array([0, 1, 2, 3])) +def _person_fixture(**overrides): + n = max((len(value) for value in overrides.values()), default=4) + defaults = { + "SS_YN": np.zeros(n, dtype=int), + "RESNSS1": np.zeros(n, dtype=int), + "RESNSS2": np.zeros(n, dtype=int), + "MCARE": np.zeros(n, dtype=int), + "PEN_SC1": np.zeros(n, dtype=int), + "PEN_SC2": np.zeros(n, dtype=int), + "PEIO1COW": np.zeros(n, dtype=int), + "A_MJOCC": np.zeros(n, dtype=int), + "MIL": np.zeros(n, dtype=int), + "PEAFEVER": np.zeros(n, dtype=int), + "CHAMPVA": np.zeros(n, dtype=int), + "SSI_YN": np.zeros(n, dtype=int), + } + defaults.update(overrides) + return pd.DataFrame(defaults) + + +def test_derive_has_valid_ssn_from_ssn_card_type_codes(): + result = _derive_has_valid_ssn_from_ssn_card_type_codes( + np.array([0, 1, 2, 3]), + ) np.testing.assert_array_equal( result, - np.array([False, True, True, True], dtype=bool), + np.array([False, True, False, False], dtype=bool), + ) + + +def test_derive_taxpayer_id_type_from_identification_flags(): + result = _derive_taxpayer_id_type_from_identification_flags( + has_valid_ssn=np.array([False, True, False]), + has_tin=np.array([False, True, True]), ) + assert result.tolist() == ["NONE", "VALID_SSN", "OTHER_TIN"] + + +def test_high_confidence_admin_signal_gets_tin(): + person = _person_fixture(SS_YN=np.array([1, 0]), MCARE=np.array([0, 1])) + + result = _high_confidence_tin_evidence(person) -def test_derive_has_tin_with_itin(): - ssn_codes = np.array([0, 0, 1, 2, 3]) - has_itin = np.array([True, False, False, False, False]) + np.testing.assert_array_equal(result, np.array([True, True])) - result = _derive_has_tin_from_ssn_card_type_codes(ssn_codes, has_itin) - # code-0 with ITIN → True, code-0 without ITIN → False, others → True +def test_derive_has_tin_from_identification_inputs_is_conservative(): + person = _person_fixture(SS_YN=np.zeros(5, dtype=int)) + result = _derive_has_tin_from_identification_inputs( + person=person, + ssn_card_type=np.array([0, 1, 2, 3, 0]), + has_itin_number=np.array([False, False, False, False, True]), + ) + np.testing.assert_array_equal( result, - np.array([True, False, True, True, True], dtype=bool), + np.array([False, True, False, False, True], dtype=bool), ) -def test_store_identification_variables_writes_has_tin_and_alias(): - cps = {} +def test_other_non_citizen_with_admin_signal_gets_tin(): + person = _person_fixture(SS_YN=np.array([1])) + result = _derive_has_tin_from_identification_inputs( + person=person, + ssn_card_type=np.array([3]), + ) + + np.testing.assert_array_equal(result, np.array([True])) + - _store_identification_variables(cps, np.array([0, 1, 2, 3])) +def test_store_identification_variables_writes_id_primitives(): + cps = {} + person = _person_fixture(SS_YN=np.zeros(5, dtype=int)) + has_itin = np.array([False, False, False, False, True]) + + _store_identification_variables( + cps, + person, + np.array([0, 1, 2, 3, 0]), + has_itin, + ) assert cps["ssn_card_type"].tolist() == [ b"NONE", b"CITIZEN", b"NON_CITIZEN_VALID_EAD", b"OTHER_NON_CITIZEN", + b"NONE", + ] + assert cps["taxpayer_id_type"].tolist() == [ + b"NONE", + b"VALID_SSN", + b"NONE", + b"NONE", + b"OTHER_TIN", ] np.testing.assert_array_equal( cps["has_tin"], - np.array([False, True, True, True], dtype=bool), + np.array([False, True, False, False, True], dtype=bool), + ) + np.testing.assert_array_equal( + cps["has_valid_ssn"], + np.array([False, True, False, False, False], dtype=bool), ) - np.testing.assert_array_equal(cps["has_itin"], cps["has_tin"]) - - -def test_store_identification_variables_with_itin(): - cps = {} - ssn_codes = np.array([0, 0, 1, 2, 3]) - has_itin = np.array([True, False, False, False, False]) - - _store_identification_variables(cps, ssn_codes, has_itin) - - # code-0 with ITIN gets has_tin=True - assert cps["has_tin"][0] == True # noqa: E712 - # code-0 without ITIN gets has_tin=False - assert cps["has_tin"][1] == False # noqa: E712 - # SSN holders still have has_tin=True - assert cps["has_tin"][2] == True # noqa: E712 - # alias still matches np.testing.assert_array_equal(cps["has_itin"], cps["has_tin"]) diff --git a/tests/unit/test_remote_calibration_runner.py b/tests/unit/test_remote_calibration_runner.py new file mode 100644 index 000000000..2d0196263 --- /dev/null +++ b/tests/unit/test_remote_calibration_runner.py @@ -0,0 +1,76 @@ +import importlib +import sys +from types import ModuleType, SimpleNamespace + + +def _load_remote_calibration_runner_module(): + fake_modal = ModuleType("modal") + + class _FakeApp: + def __init__(self, *args, **kwargs): + pass + + def function(self, *args, **kwargs): + def decorator(func): + return func + + return decorator + + def local_entrypoint(self, *args, **kwargs): + def decorator(func): + return func + + return decorator + + fake_modal.App = _FakeApp + fake_modal.Secret = SimpleNamespace(from_name=lambda *args, **kwargs: object()) + fake_modal.Volume = SimpleNamespace(from_name=lambda *args, **kwargs: object()) + + fake_images = ModuleType("modal_app.images") + fake_images.gpu_image = object() + + sys.modules["modal"] = fake_modal + sys.modules["modal_app.images"] = fake_images + sys.modules.pop("modal_app.remote_calibration_runner", None) + return importlib.import_module("modal_app.remote_calibration_runner") + + +def test_collect_outputs_reads_checkpoint_bytes(tmp_path): + remote_runner = _load_remote_calibration_runner_module() + weights = tmp_path / "weights.npy" + geography = tmp_path / "geography.npz" + log_path = tmp_path / "diag.csv" + cal_log = tmp_path / "calibration.csv" + config = tmp_path / "config.json" + checkpoint = tmp_path / "weights.checkpoint.pt" + + paths_and_bytes = { + weights: b"weights", + geography: b"geography", + log_path: b"log", + cal_log: b"cal-log", + config: b"config", + checkpoint: b"checkpoint", + } + for path, content in paths_and_bytes.items(): + path.write_bytes(content) + + result = remote_runner._collect_outputs( + [ + f"OUTPUT_PATH:{weights}", + f"GEOGRAPHY_PATH:{geography}", + f"LOG_PATH:{log_path}", + f"CAL_LOG_PATH:{cal_log}", + f"CONFIG_PATH:{config}", + f"CHECKPOINT_PATH:{checkpoint}", + ] + ) + + assert result == { + "weights": b"weights", + "geography": b"geography", + "log": b"log", + "cal_log": b"cal-log", + "config": b"config", + "checkpoint": b"checkpoint", + } From f6ebcbdc0be9d8921e010a56941fab73e5f63b9c Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 10 Apr 2026 08:09:10 -0400 Subject: [PATCH 08/12] Handle string ID fields in PUF cloning --- policyengine_us_data/calibration/puf_impute.py | 4 +--- .../calibration/test_calibration_puf_impute.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/policyengine_us_data/calibration/puf_impute.py b/policyengine_us_data/calibration/puf_impute.py index b87f846f8..92cdc1b28 100644 --- a/policyengine_us_data/calibration/puf_impute.py +++ b/policyengine_us_data/calibration/puf_impute.py @@ -513,7 +513,7 @@ def _map_to_entity(pred_values, variable_name): elif variable in IMPUTED_VARIABLES and y_full: pred = _map_to_entity(y_full[variable], variable) new_data[variable] = {time_period: np.concatenate([values, pred])} - elif "_id" in variable: + elif "_id" in variable and np.issubdtype(values.dtype, np.number): new_data[variable] = { time_period: np.concatenate([values, values + values.max()]) } @@ -582,7 +582,6 @@ def _impute_weeks_unemployed( except (ValueError, KeyError): logger.warning("weeks_unemployed not in CPS, returning zeros") n_persons = len(data["person_id"][time_period]) - del cps_sim return np.zeros(n_persons) WEEKS_PREDICTORS = [ @@ -685,7 +684,6 @@ def _impute_retirement_contributions( except (ValueError, KeyError) as e: logger.warning("Could not build retirement training data: %s", e) n_persons = len(data["person_id"][time_period]) - del cps_sim return {var: np.zeros(n_persons) for var in CPS_RETIREMENT_VARIABLES} # Build test data: demographics from CPS sim, income from PUF diff --git a/tests/unit/calibration/test_calibration_puf_impute.py b/tests/unit/calibration/test_calibration_puf_impute.py index d803486ee..2596ab61f 100644 --- a/tests/unit/calibration/test_calibration_puf_impute.py +++ b/tests/unit/calibration/test_calibration_puf_impute.py @@ -78,6 +78,24 @@ def test_ids_are_unique(self): assert len(np.unique(person_ids)) == len(person_ids) assert len(np.unique(household_ids)) == len(household_ids) + def test_string_id_like_variables_are_duplicated_without_numeric_offset(self): + data = _make_mock_data(n_persons=20, n_households=5) + data["taxpayer_id_type"] = { + 2024: np.array([b"VALID_SSN", b"NONE"] * 10, dtype="S9") + } + state_fips = np.array([1, 2, 36, 6, 48]) + + result = puf_clone_dataset( + data=data, + state_fips=state_fips, + time_period=2024, + skip_qrf=True, + ) + + values = result["taxpayer_id_type"][2024] + n = len(values) // 2 + np.testing.assert_array_equal(values[:n], values[n:]) + def test_puf_half_weight_zero(self): data = _make_mock_data(n_persons=20, n_households=5) state_fips = np.array([1, 2, 36, 6, 48]) From c1158c6fa9f74436807d2c816926713640f33c59 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 10 Apr 2026 08:42:57 -0400 Subject: [PATCH 09/12] Fold taxpayer ID imputation into calibration resume PR --- modal_app/remote_calibration_runner.py | 46 +++++ .../calibration/calibration_utils.py | 5 + .../calibration/unified_calibration.py | 9 +- policyengine_us_data/datasets/cps/cps.py | 84 +------- policyengine_us_data/utils/identification.py | 189 +++++++++++++++++- tests/integration/test_enhanced_cps.py | 32 ++- tests/integration/test_sparse_enhanced_cps.py | 32 ++- .../calibration/test_unified_calibration.py | 31 +++ .../test_unified_matrix_builder_merge.py | 10 + .../unit/datasets/test_cps_identification.py | 144 ++++++++++++- tests/unit/test_remote_calibration_runner.py | 40 ++++ 11 files changed, 509 insertions(+), 113 deletions(-) diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index c939eebec..dfc5410e1 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -65,6 +65,13 @@ def _append_hyperparams(cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq cmd.extend(["--log-freq", str(log_freq)]) +def _append_checkpoint_args(cmd, checkpoint_path: str): + """Save checkpoints on the mounted volume and resume when present.""" + if os.path.exists(checkpoint_path): + cmd.extend(["--resume-from", checkpoint_path]) + cmd.extend(["--checkpoint-output", checkpoint_path]) + + def _collect_outputs(cal_lines): """Extract weights and log bytes from calibration output lines.""" output_path = None @@ -175,6 +182,7 @@ def _fit_weights_impl( skip_county: bool = True, workers: int = 8, artifacts_dir: str = "", + checkpoint_name: str = "calibration_weights.checkpoint.pt", ) -> dict: """Full pipeline: read data from pipeline volume, build matrix, fit.""" _setup_repo() @@ -183,6 +191,7 @@ def _fit_weights_impl( artifacts = artifacts_dir if artifacts_dir else f"{PIPELINE_MOUNT}/artifacts" db_path = f"{artifacts}/policy_data.db" dataset_path = f"{artifacts}/source_imputed_stratified_extended_cps.h5" + checkpoint_path = f"{artifacts}/{checkpoint_name}" for label, p in [("database", db_path), ("dataset", dataset_path)]: if not os.path.exists(p): raise RuntimeError( @@ -211,6 +220,7 @@ def _fit_weights_impl( if workers > 1: cmd.extend(["--workers", str(workers)]) _append_hyperparams(cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq) + _append_checkpoint_args(cmd, checkpoint_path) cal_rc, cal_lines = _run_streaming( cmd, @@ -218,7 +228,11 @@ def _fit_weights_impl( label="calibrate", ) if cal_rc != 0: + if os.path.exists(checkpoint_path): + pipeline_vol.commit() raise RuntimeError(f"Script failed with code {cal_rc}") + if os.path.exists(checkpoint_path): + pipeline_vol.commit() return _collect_outputs(cal_lines) @@ -233,6 +247,7 @@ def _fit_from_package_impl( lambda_l2: float = None, learning_rate: float = None, log_freq: int = None, + checkpoint_name: str = "calibration_weights.checkpoint.pt", ) -> dict: """Fit weights from a pre-built calibration package.""" if not volume_package_path: @@ -240,6 +255,8 @@ def _fit_from_package_impl( _setup_repo() + artifacts = os.path.dirname(volume_package_path) or f"{PIPELINE_MOUNT}/artifacts" + checkpoint_path = f"{artifacts}/{checkpoint_name}" pkg_path = "/root/calibration_package.pkl" import shutil @@ -266,6 +283,7 @@ def _fit_from_package_impl( if target_config: cmd.extend(["--target-config", target_config]) _append_hyperparams(cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq) + _append_checkpoint_args(cmd, checkpoint_path) print(f"Running command: {' '.join(cmd)}", flush=True) @@ -275,7 +293,11 @@ def _fit_from_package_impl( label="calibrate", ) if cal_rc != 0: + if os.path.exists(checkpoint_path): + pipeline_vol.commit() raise RuntimeError(f"Script failed with code {cal_rc}") + if os.path.exists(checkpoint_path): + pipeline_vol.commit() return _collect_outputs(cal_lines) @@ -511,6 +533,7 @@ def fit_weights_t4( skip_county: bool = True, workers: int = 8, artifacts_dir: str = "", + checkpoint_name: str = "calibration_weights.checkpoint.pt", ) -> dict: return _fit_weights_impl( branch, @@ -524,6 +547,7 @@ def fit_weights_t4( skip_county=skip_county, workers=workers, artifacts_dir=artifacts_dir, + checkpoint_name=checkpoint_name, ) @@ -548,6 +572,7 @@ def fit_weights_a10( skip_county: bool = True, workers: int = 8, artifacts_dir: str = "", + checkpoint_name: str = "calibration_weights.checkpoint.pt", ) -> dict: return _fit_weights_impl( branch, @@ -561,6 +586,7 @@ def fit_weights_a10( skip_county=skip_county, workers=workers, artifacts_dir=artifacts_dir, + checkpoint_name=checkpoint_name, ) @@ -585,6 +611,7 @@ def fit_weights_a100_40( skip_county: bool = True, workers: int = 8, artifacts_dir: str = "", + checkpoint_name: str = "calibration_weights.checkpoint.pt", ) -> dict: return _fit_weights_impl( branch, @@ -598,6 +625,7 @@ def fit_weights_a100_40( skip_county=skip_county, workers=workers, artifacts_dir=artifacts_dir, + checkpoint_name=checkpoint_name, ) @@ -622,6 +650,7 @@ def fit_weights_a100_80( skip_county: bool = True, workers: int = 8, artifacts_dir: str = "", + checkpoint_name: str = "calibration_weights.checkpoint.pt", ) -> dict: return _fit_weights_impl( branch, @@ -635,6 +664,7 @@ def fit_weights_a100_80( skip_county=skip_county, workers=workers, artifacts_dir=artifacts_dir, + checkpoint_name=checkpoint_name, ) @@ -659,6 +689,7 @@ def fit_weights_h100( skip_county: bool = True, workers: int = 8, artifacts_dir: str = "", + checkpoint_name: str = "calibration_weights.checkpoint.pt", ) -> dict: return _fit_weights_impl( branch, @@ -672,6 +703,7 @@ def fit_weights_h100( skip_county=skip_county, workers=workers, artifacts_dir=artifacts_dir, + checkpoint_name=checkpoint_name, ) @@ -705,6 +737,7 @@ def fit_from_package_t4( learning_rate: float = None, log_freq: int = None, volume_package_path: str = None, + checkpoint_name: str = "calibration_weights.checkpoint.pt", ) -> dict: return _fit_from_package_impl( branch, @@ -716,6 +749,7 @@ def fit_from_package_t4( lambda_l2=lambda_l2, learning_rate=learning_rate, log_freq=log_freq, + checkpoint_name=checkpoint_name, ) @@ -737,6 +771,7 @@ def fit_from_package_a10( learning_rate: float = None, log_freq: int = None, volume_package_path: str = None, + checkpoint_name: str = "calibration_weights.checkpoint.pt", ) -> dict: return _fit_from_package_impl( branch, @@ -748,6 +783,7 @@ def fit_from_package_a10( lambda_l2=lambda_l2, learning_rate=learning_rate, log_freq=log_freq, + checkpoint_name=checkpoint_name, ) @@ -769,6 +805,7 @@ def fit_from_package_a100_40( learning_rate: float = None, log_freq: int = None, volume_package_path: str = None, + checkpoint_name: str = "calibration_weights.checkpoint.pt", ) -> dict: return _fit_from_package_impl( branch, @@ -780,6 +817,7 @@ def fit_from_package_a100_40( lambda_l2=lambda_l2, learning_rate=learning_rate, log_freq=log_freq, + checkpoint_name=checkpoint_name, ) @@ -801,6 +839,7 @@ def fit_from_package_a100_80( learning_rate: float = None, log_freq: int = None, volume_package_path: str = None, + checkpoint_name: str = "calibration_weights.checkpoint.pt", ) -> dict: return _fit_from_package_impl( branch, @@ -812,6 +851,7 @@ def fit_from_package_a100_80( lambda_l2=lambda_l2, learning_rate=learning_rate, log_freq=log_freq, + checkpoint_name=checkpoint_name, ) @@ -833,6 +873,7 @@ def fit_from_package_h100( learning_rate: float = None, log_freq: int = None, volume_package_path: str = None, + checkpoint_name: str = "calibration_weights.checkpoint.pt", ) -> dict: return _fit_from_package_impl( branch, @@ -844,6 +885,7 @@ def fit_from_package_h100( lambda_l2=lambda_l2, learning_rate=learning_rate, log_freq=log_freq, + checkpoint_name=checkpoint_name, ) @@ -878,6 +920,7 @@ def main( national: bool = False, ): prefix = "national_" if national else "" + checkpoint_name = f"{prefix}calibration_weights.checkpoint.pt" if national: if lambda_l0 is None: lambda_l0 = 1e-4 @@ -930,6 +973,7 @@ def main( learning_rate=learning_rate, log_freq=log_freq, volume_package_path=vol_path, + checkpoint_name=checkpoint_name, ) elif full_pipeline: print( @@ -960,6 +1004,7 @@ def main( log_freq=log_freq, skip_county=not county_level, workers=workers, + checkpoint_name=checkpoint_name, ) else: vol_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package.pkl" @@ -1017,6 +1062,7 @@ def main( learning_rate=learning_rate, log_freq=log_freq, volume_package_path=vol_path, + checkpoint_name=checkpoint_name, ) with open(output, "wb") as f: diff --git a/policyengine_us_data/calibration/calibration_utils.py b/policyengine_us_data/calibration/calibration_utils.py index 8af1bab7a..4ffd4c3a9 100644 --- a/policyengine_us_data/calibration/calibration_utils.py +++ b/policyengine_us_data/calibration/calibration_utils.py @@ -6,6 +6,7 @@ import json import numpy as np import pandas as pd +from scipy import sparse from spm_calculator import SPMCalculator, spm_equivalence_scale from spm_calculator.geoadj import calculate_geoadj_from_rent @@ -263,6 +264,10 @@ def apply_op(values: np.ndarray, op: str, val: str) -> np.ndarray: else: parsed = val + values = np.asarray(values) + if values.dtype.kind == "S" and isinstance(parsed, str): + parsed = parsed.encode() + if op in ("==", "="): return values == parsed if op == ">": diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index cf9d0532c..00b495b42 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -546,10 +546,13 @@ def checkpoint_signature_mismatches(expected: dict, actual: dict) -> list: """Return human-readable checkpoint compatibility mismatches.""" mismatches = [] float_keys = {"lambda_l0", "beta", "lambda_l2", "learning_rate"} - for key, expected_value in expected.items(): - actual_value = actual.get(key) + for key, actual_value in actual.items(): + expected_value = expected.get(key) + if expected_value is None: + mismatches.append(f"{key} missing from checkpoint") + continue if key in float_keys: - if actual_value is None or not np.isclose(expected_value, actual_value): + if not np.isclose(expected_value, actual_value): mismatches.append( f"{key} expected {expected_value}, got {actual_value}" ) diff --git a/policyengine_us_data/datasets/cps/cps.py b/policyengine_us_data/datasets/cps/cps.py index f9052f7a7..195a63db3 100644 --- a/policyengine_us_data/datasets/cps/cps.py +++ b/policyengine_us_data/datasets/cps/cps.py @@ -135,13 +135,12 @@ def generate(self): undocumented_workers_target=8.3e6, undocumented_students_target=0.21 * 1.9e6, ) - logging.info("Imputing ITIN status") - has_itin_number = impute_itin_status(cps, person, ssn_card_type) + logging.info("Adding taxpayer ID variables") _store_identification_variables( cps, person, ssn_card_type, - has_itin_number, + self.time_period, ) logging.info("Adding family variables") add_spm_variables(self, cps, spm_unit) @@ -1761,85 +1760,6 @@ def get_arrival_year_midpoint(peinusyr): return ssn_card_type -def impute_itin_status( - cps: dict, - person: pd.DataFrame, - ssn_card_type: np.ndarray, - itin_returns_target: float = 4.4e6, - random_seed: int = 98765, -) -> np.ndarray: - """Impute which undocumented (code-0) persons hold ITINs. - - The 4.4M target is ITIN *returns* (tax units), not persons. - Eligible tax units are those with at least one code-0 member who has - employment or self-employment income. This is a provisional rule — - it will miss some ITIN filers with no current earnings and include - some units that wouldn't file. - """ - n_persons = len(person) - has_itin_number = np.zeros(n_persons, dtype=bool) - - person_tax_unit_ids = np.asarray(cps["person_tax_unit_id"]) - is_code_0 = ssn_card_type == 0 - has_earnings = (person.WSAL_VAL > 0) | (person.SEMP_VAL > 0) - - # Identify tax units with at least one code-0 earner - eligible_person_mask = is_code_0 & has_earnings - eligible_tu_ids = np.unique(person_tax_unit_ids[eligible_person_mask]) - - # Build tax-unit weight lookup from household weights - household_ids = np.asarray(cps["household_id"]) - household_weights = np.asarray(cps["household_weight"]) - person_household_ids = np.asarray(cps["person_household_id"]) - hh_to_weight = dict(zip(household_ids, household_weights)) - - # Map each tax unit to its weight (weight of first member's household) - tu_to_weight = {} - for i, tu_id in enumerate(person_tax_unit_ids): - if tu_id not in tu_to_weight: - tu_to_weight[tu_id] = hh_to_weight.get(person_household_ids[i], 0) - - eligible_weights = np.array( - [tu_to_weight.get(tu_id, 0) for tu_id in eligible_tu_ids] - ) - total_eligible_weighted = eligible_weights.sum() - - if total_eligible_weighted <= 0: - print("ITIN imputation: no eligible tax units found") - return has_itin_number - - # Weighted random selection of tax units to hit the 4.4M target - target_share = min(itin_returns_target / total_eligible_weighted, 1.0) - rng = np.random.default_rng(seed=random_seed) - selected_mask = rng.random(len(eligible_tu_ids)) < target_share - selected_tu_ids = set(eligible_tu_ids[selected_mask]) - - # Mark all code-0 members of selected tax units as ITIN holders - for i in range(n_persons): - if is_code_0[i] and person_tax_unit_ids[i] in selected_tu_ids: - has_itin_number[i] = True - - # Logging - person_weights = np.array( - [hh_to_weight.get(hh_id, 0) for hh_id in person_household_ids] - ) - selected_tu_weighted = sum(tu_to_weight.get(tu_id, 0) for tu_id in selected_tu_ids) - itin_filers = has_itin_number & has_earnings - itin_total = has_itin_number - print("\nITIN imputation:") - print( - f" Weighted ITIN filer tax units: {selected_tu_weighted:,.0f} " - f"(target: {itin_returns_target:,.0f})" - ) - print( - f" Weighted code-0 earners with ITIN: " - f"{np.sum(person_weights[itin_filers]):,.0f}" - ) - print(f" Total weighted ITIN persons: {np.sum(person_weights[itin_total]):,.0f}") - - return has_itin_number - - def _update_documentation_with_numbers(log_df, docs_dir): """Update the documentation file with actual population numbers from CSV""" doc_path = docs_dir / "SSN_statuses_imputation.ipynb" diff --git a/policyengine_us_data/utils/identification.py b/policyengine_us_data/utils/identification.py index e744b832a..032506603 100644 --- a/policyengine_us_data/utils/identification.py +++ b/policyengine_us_data/utils/identification.py @@ -2,6 +2,12 @@ import pandas as pd +NON_SSN_FILER_TIN_TARGET_BY_YEAR = { + # Latest available public IRS/TAS figure: about 3.8M TY 2023 returns + # included an ITIN. Use it as a recent proxy for non-SSN filer TINs. + 2024: 3.8e6, +} + SSN_CARD_TYPE_CODE_TO_STR = { 0: "NONE", 1: "CITIZEN", @@ -18,6 +24,11 @@ def _derive_has_valid_ssn_from_ssn_card_type_codes( return ssn_card_type == 1 +def _impute_has_valid_ssn(ssn_card_type: np.ndarray) -> np.ndarray: + """Impute valid SSNs without treating EAD or documented-status proxies as IDs.""" + return _derive_has_valid_ssn_from_ssn_card_type_codes(ssn_card_type) + + def _derive_taxpayer_id_type_from_identification_flags( has_valid_ssn: np.ndarray, has_tin: np.ndarray, @@ -52,6 +63,52 @@ def _high_confidence_tin_evidence(person: pd.DataFrame) -> np.ndarray: ).to_numpy(dtype=bool) +def _person_weights(cps: dict) -> np.ndarray: + """Return person weights from household IDs and weights.""" + household_to_weight = dict(zip(cps["household_id"], cps["household_weight"])) + return np.array( + [ + household_to_weight.get(household_id, 0) + for household_id in cps["person_household_id"] + ], + dtype=float, + ) + + +def _proxy_tax_unit_filers( + person_tax_unit_ids: np.ndarray, + age: np.ndarray, +) -> np.ndarray: + """Proxy tax-unit head/spouse as the two oldest adults in each tax unit.""" + person_tax_unit_ids = np.asarray(person_tax_unit_ids) + age = np.asarray(age) + adult = age >= 18 + ranks = pd.Series(np.inf, index=np.arange(len(age)), dtype=float) + if adult.any(): + adults = pd.DataFrame( + { + "tax_unit_id": person_tax_unit_ids[adult], + "age": age[adult], + }, + index=np.flatnonzero(adult), + ) + ranks.loc[adults.index] = adults.groupby("tax_unit_id")["age"].rank( + method="first", + ascending=False, + ) + return adult & (ranks.to_numpy() <= 2) + + +def _aggregate_by_tax_unit( + values: np.ndarray, + tax_unit_index: np.ndarray, + n_tax_units: int, +) -> np.ndarray: + total = np.zeros(n_tax_units, dtype=float) + np.add.at(total, tax_unit_index, values) + return total + + def _derive_has_tin_from_identification_inputs( person: pd.DataFrame, ssn_card_type: np.ndarray, @@ -66,18 +123,138 @@ def _derive_has_tin_from_identification_inputs( return has_tin +def _impute_has_tin( + cps: dict, + person: pd.DataFrame, + ssn_card_type: np.ndarray, + time_period: int, + non_ssn_filer_tin_target: float | None = None, + has_valid_ssn: np.ndarray | None = None, +) -> np.ndarray: + """Impute broad TIN possession without treating legal-status proxies as IDs.""" + ssn_card_type = np.asarray(ssn_card_type) + if has_valid_ssn is None: + has_valid_ssn = _impute_has_valid_ssn(ssn_card_type) + has_tin = has_valid_ssn.copy() + + high_confidence_tin = ~has_valid_ssn & _high_confidence_tin_evidence(person) + has_tin |= high_confidence_tin + + target = non_ssn_filer_tin_target + if target is None: + target = NON_SSN_FILER_TIN_TARGET_BY_YEAR.get(time_period) + if target is None or target <= 0: + return has_tin + + age = np.asarray(cps["age"]) + person_tax_unit_ids = np.asarray(cps["person_tax_unit_id"]) + tax_unit_ids, person_tax_unit_index = np.unique( + person_tax_unit_ids, + return_inverse=True, + ) + n_tax_units = len(tax_unit_ids) + person_weights = _person_weights(cps) + tax_unit_weights = np.zeros(n_tax_units, dtype=float) + np.maximum.at(tax_unit_weights, person_tax_unit_index, person_weights) + + proxy_filer = _proxy_tax_unit_filers(person_tax_unit_ids, age) + non_ssn_proxy_filer = proxy_filer & ~has_valid_ssn + + current_non_ssn_tin_units = np.zeros(n_tax_units, dtype=bool) + np.logical_or.at( + current_non_ssn_tin_units, + person_tax_unit_index, + non_ssn_proxy_filer & has_tin, + ) + current_weighted_units = tax_unit_weights[current_non_ssn_tin_units].sum() + additional_target = target - current_weighted_units + if additional_target <= 0: + return has_tin + + employment_income = np.asarray(cps.get("employment_income", np.zeros(len(age)))) + self_employment_income = np.asarray( + cps.get("self_employment_income", np.zeros(len(age))) + ) + prior_year_income = np.asarray( + cps.get("employment_income_last_year", np.zeros(len(age))) + ) + np.asarray(cps.get("self_employment_income_last_year", np.zeros(len(age)))) + + has_filing_income = ( + (employment_income > 0) | (self_employment_income > 0) | (prior_year_income > 0) + ) + candidate_person = ( + non_ssn_proxy_filer & ~has_tin & (ssn_card_type == 0) & has_filing_income + ) + candidate_units = np.zeros(n_tax_units, dtype=bool) + np.logical_or.at(candidate_units, person_tax_unit_index, candidate_person) + if not candidate_units.any(): + return has_tin + + unit_employment_income = _aggregate_by_tax_unit( + np.maximum(employment_income, 0), + person_tax_unit_index, + n_tax_units, + ) + unit_self_employment_income = _aggregate_by_tax_unit( + np.maximum(self_employment_income, 0), + person_tax_unit_index, + n_tax_units, + ) + unit_prior_year_income = _aggregate_by_tax_unit( + np.maximum(prior_year_income, 0), + person_tax_unit_index, + n_tax_units, + ) + unit_non_ssn_filer_count = _aggregate_by_tax_unit( + candidate_person.astype(float), + person_tax_unit_index, + n_tax_units, + ) + unit_has_minor = np.zeros(n_tax_units, dtype=bool) + np.logical_or.at(unit_has_minor, person_tax_unit_index, age < 18) + + score = ( + 4.0 * (unit_self_employment_income > 0) + + 2.0 * (unit_employment_income > 0) + + 1.0 * (unit_prior_year_income > 0) + + 1.0 * unit_has_minor + + 0.5 * (unit_non_ssn_filer_count > 1) + ) + + candidate_idx = np.flatnonzero(candidate_units) + rng = np.random.default_rng(seed=17_000 + int(time_period)) + priority = score[candidate_idx] + rng.random(len(candidate_idx)) * 0.01 + ordered_idx = candidate_idx[np.argsort(-priority)] + + selected_units = np.zeros(n_tax_units, dtype=bool) + cumulative_weight = 0.0 + for tax_unit_index in ordered_idx: + if cumulative_weight >= additional_target: + break + selected_units[tax_unit_index] = True + cumulative_weight += tax_unit_weights[tax_unit_index] + + selected_person_unit = selected_units[person_tax_unit_index] + selected_non_ssn_filers = selected_person_unit & non_ssn_proxy_filer + selected_minor_dependents = selected_person_unit & ~proxy_filer & (age < 18) + has_tin |= selected_non_ssn_filers | (selected_minor_dependents & ~has_valid_ssn) + return has_tin + + def _store_identification_variables( cps: dict, person: pd.DataFrame, ssn_card_type: np.ndarray, - has_itin_number: np.ndarray | None = None, + time_period: int, ) -> None: """Persist identification inputs used by PolicyEngine US.""" - has_valid_ssn = _derive_has_valid_ssn_from_ssn_card_type_codes(ssn_card_type) - has_tin = _derive_has_tin_from_identification_inputs( - person=person, - ssn_card_type=ssn_card_type, - has_itin_number=has_itin_number, + has_valid_ssn = _impute_has_valid_ssn(ssn_card_type) + has_tin = _impute_has_tin( + cps, + person, + ssn_card_type, + time_period, + has_valid_ssn=has_valid_ssn, ) taxpayer_id_type = _derive_taxpayer_id_type_from_identification_flags( has_valid_ssn=has_valid_ssn, diff --git a/tests/integration/test_enhanced_cps.py b/tests/integration/test_enhanced_cps.py index 5f4d897a3..8faa87502 100644 --- a/tests/integration/test_enhanced_cps.py +++ b/tests/integration/test_enhanced_cps.py @@ -8,6 +8,16 @@ def _period_array(period_values, period): return period_values.get(period, period_values[str(period)]) +def _require_identification_fields(data): + required_fields = ("has_tin", "has_itin", "has_valid_ssn", "taxpayer_id_type") + missing = [field for field in required_fields if field not in data] + if missing: + pytest.skip( + "enhanced_cps_2024.h5 fixture predates raw identification fields: " + + ", ".join(missing) + ) + + @pytest.fixture(scope="module") def ecps_sim(): from policyengine_us_data.datasets.cps import EnhancedCPS_2024 @@ -226,18 +236,26 @@ def test_undocumented_matches_ssn_none(): def test_has_tin_matches_identification_inputs(ecps_sim): data = ecps_sim.dataset.load_dataset() + _require_identification_fields(data) has_tin = _period_array(data["has_tin"], 2024) has_itin = _period_array(data["has_itin"], 2024) + has_valid_ssn = _period_array(data["has_valid_ssn"], 2024) ssn_card_type = _period_array(data["ssn_card_type"], 2024).astype(str) + taxpayer_id_type = _period_array(data["taxpayer_id_type"], 2024).astype(str) - # has_itin is still an alias for has_tin np.testing.assert_array_equal(has_itin, has_tin) - # Everyone with an SSN card has a TIN - assert has_tin[ssn_card_type != "NONE"].all() - # Some code-0 (NONE) people have TINs via ITIN - none_mask = ssn_card_type == "NONE" - assert none_mask.any(), "Expected some ssn_card_type == NONE" - assert has_tin[none_mask].any(), "Expected some ITIN holders among code-0" + np.testing.assert_array_equal(has_valid_ssn, taxpayer_id_type == "VALID_SSN") + np.testing.assert_array_equal(has_tin, taxpayer_id_type != "NONE") + assert np.all(has_tin[has_valid_ssn]) + np.testing.assert_array_equal(has_valid_ssn[ssn_card_type == "NONE"], False) + np.testing.assert_array_equal( + taxpayer_id_type, + np.where( + has_valid_ssn, + "VALID_SSN", + np.where(has_tin, "OTHER_TIN", "NONE"), + ), + ) def test_aca_calibration(): diff --git a/tests/integration/test_sparse_enhanced_cps.py b/tests/integration/test_sparse_enhanced_cps.py index f5474fd56..488dda666 100644 --- a/tests/integration/test_sparse_enhanced_cps.py +++ b/tests/integration/test_sparse_enhanced_cps.py @@ -21,6 +21,16 @@ def _period_array(period_values, period): return period_values.get(period, period_values[str(period)]) +def _require_identification_fields(data): + required_fields = ("has_tin", "has_itin", "has_valid_ssn", "taxpayer_id_type") + missing = [field for field in required_fields if field not in data] + if missing: + pytest.skip( + "enhanced_cps_2024.h5 fixture predates raw identification fields: " + + ", ".join(missing) + ) + + @pytest.fixture(scope="session") def data(): return Dataset.from_file(STORAGE_FOLDER / "enhanced_cps_2024.h5") @@ -210,18 +220,26 @@ def test_sparse_ssn_card_type_none_target(sim): def test_sparse_has_tin_matches_identification_inputs(sim): data = sim.dataset.load_dataset() + _require_identification_fields(data) has_tin = _period_array(data["has_tin"], 2024) has_itin = _period_array(data["has_itin"], 2024) + has_valid_ssn = _period_array(data["has_valid_ssn"], 2024) ssn_card_type = _period_array(data["ssn_card_type"], 2024).astype(str) + taxpayer_id_type = _period_array(data["taxpayer_id_type"], 2024).astype(str) - # has_itin is still an alias for has_tin np.testing.assert_array_equal(has_itin, has_tin) - # Everyone with an SSN card has a TIN - assert has_tin[ssn_card_type != "NONE"].all() - # Some code-0 (NONE) people have TINs via ITIN - none_mask = ssn_card_type == "NONE" - assert none_mask.any(), "Expected some ssn_card_type == NONE" - assert has_tin[none_mask].any(), "Expected some ITIN holders among code-0" + np.testing.assert_array_equal(has_valid_ssn, taxpayer_id_type == "VALID_SSN") + np.testing.assert_array_equal(has_tin, taxpayer_id_type != "NONE") + assert np.all(has_tin[has_valid_ssn]) + np.testing.assert_array_equal(has_valid_ssn[ssn_card_type == "NONE"], False) + np.testing.assert_array_equal( + taxpayer_id_type, + np.where( + has_valid_ssn, + "VALID_SSN", + np.where(has_tin, "OTHER_TIN", "NONE"), + ), + ) def test_sparse_aca_calibration(sim): diff --git a/tests/unit/calibration/test_unified_calibration.py b/tests/unit/calibration/test_unified_calibration.py index 61462dd36..d3fe05bc0 100644 --- a/tests/unit/calibration/test_unified_calibration.py +++ b/tests/unit/calibration/test_unified_calibration.py @@ -621,6 +621,37 @@ def test_resume_checkpoint_rejects_changed_matrix_with_same_shape(self, tmp_path } ) + def test_resume_checkpoint_rejects_missing_matrix_fingerprint(self, tmp_path): + import torch + from policyengine_us_data.calibration.unified_calibration import ( + default_checkpoint_path, + fit_l0_weights, + ) + + weights_path = tmp_path / "weights.npy" + checkpoint_path = default_checkpoint_path(str(weights_path)) + kwargs = self._fit_kwargs(tmp_path) + kwargs["checkpoint_path"] = str(checkpoint_path) + + with patch( + "l0.calibration.SparseCalibrationWeights", + FakeSparseCalibrationWeights, + ): + first_weights = fit_l0_weights(**kwargs) + np.save(weights_path, first_weights) + + checkpoint = torch.load(checkpoint_path, map_location="cpu") + checkpoint["signature"].pop("x_sparse_sha256") + torch.save(checkpoint, checkpoint_path) + + with pytest.raises(ValueError, match="x_sparse_sha256"): + fit_l0_weights( + **{ + **kwargs, + "resume_from": str(checkpoint_path), + } + ) + class TestGeographyAssignmentCountyFips: """Verify county_fips field on GeographyAssignment.""" diff --git a/tests/unit/calibration/test_unified_matrix_builder_merge.py b/tests/unit/calibration/test_unified_matrix_builder_merge.py index d7bb1bb25..216fc5da6 100644 --- a/tests/unit/calibration/test_unified_matrix_builder_merge.py +++ b/tests/unit/calibration/test_unified_matrix_builder_merge.py @@ -1,5 +1,6 @@ import numpy as np +from policyengine_us_data.calibration.calibration_utils import apply_op from policyengine_us_data.calibration.unified_matrix_builder import ( UnifiedMatrixBuilder, _assemble_clone_values_standalone, @@ -51,6 +52,15 @@ def test_assemble_clone_values_standalone_preserves_string_constraints(): ] +def test_apply_op_matches_fixed_width_byte_string_constraints(): + values = np.array([b"NONE", b"CITIZEN", b"NONE"], dtype="S9") + + np.testing.assert_array_equal( + apply_op(values, "==", "NONE"), + np.array([True, False, True]), + ) + + def test_builder_assemble_clone_values_preserves_string_constraints(): builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder) diff --git a/tests/unit/datasets/test_cps_identification.py b/tests/unit/datasets/test_cps_identification.py index a741ae462..6064d56f4 100644 --- a/tests/unit/datasets/test_cps_identification.py +++ b/tests/unit/datasets/test_cps_identification.py @@ -6,6 +6,9 @@ _derive_has_valid_ssn_from_ssn_card_type_codes, _derive_taxpayer_id_type_from_identification_flags, _high_confidence_tin_evidence, + _impute_has_tin, + _impute_has_valid_ssn, + _proxy_tax_unit_filers, _store_identification_variables, ) @@ -25,11 +28,53 @@ def _person_fixture(**overrides): "PEAFEVER": np.zeros(n, dtype=int), "CHAMPVA": np.zeros(n, dtype=int), "SSI_YN": np.zeros(n, dtype=int), + "WSAL_VAL": np.zeros(n, dtype=int), + "SEMP_VAL": np.zeros(n, dtype=int), } defaults.update(overrides) return pd.DataFrame(defaults) +def _cps_fixture( + *, + age, + tax_unit_ids, + weights=None, + employment_income=None, + self_employment_income=None, + prior_employment_income=None, + prior_self_employment_income=None, +): + n = len(age) + weights = np.ones(n) if weights is None else np.asarray(weights) + household_ids = np.arange(n) + return { + "age": np.asarray(age), + "person_tax_unit_id": np.asarray(tax_unit_ids), + "person_household_id": household_ids, + "household_id": household_ids, + "household_weight": weights, + "employment_income": ( + np.zeros(n) if employment_income is None else np.asarray(employment_income) + ), + "self_employment_income": ( + np.zeros(n) + if self_employment_income is None + else np.asarray(self_employment_income) + ), + "employment_income_last_year": ( + np.zeros(n) + if prior_employment_income is None + else np.asarray(prior_employment_income) + ), + "self_employment_income_last_year": ( + np.zeros(n) + if prior_self_employment_income is None + else np.asarray(prior_self_employment_income) + ), + } + + def test_derive_has_valid_ssn_from_ssn_card_type_codes(): result = _derive_has_valid_ssn_from_ssn_card_type_codes( np.array([0, 1, 2, 3]), @@ -41,6 +86,14 @@ def test_derive_has_valid_ssn_from_ssn_card_type_codes(): ) +def test_impute_has_valid_ssn_does_not_treat_ead_proxy_as_direct_evidence(): + result = _impute_has_valid_ssn( + ssn_card_type=np.array([0, 1, 2, 3]), + ) + + np.testing.assert_array_equal(result, np.array([False, True, False, False])) + + def test_derive_taxpayer_id_type_from_identification_flags(): result = _derive_taxpayer_id_type_from_identification_flags( has_valid_ssn=np.array([False, True, False]), @@ -58,6 +111,15 @@ def test_high_confidence_admin_signal_gets_tin(): np.testing.assert_array_equal(result, np.array([True, True])) +def test_medicaid_only_is_not_high_confidence_tin_evidence(): + person = _person_fixture() + person["CAID"] = np.array([1, 0, 0, 0]) + + result = _high_confidence_tin_evidence(person) + + np.testing.assert_array_equal(result, np.zeros(4, dtype=bool)) + + def test_derive_has_tin_from_identification_inputs_is_conservative(): person = _person_fixture(SS_YN=np.zeros(5, dtype=int)) result = _derive_has_tin_from_identification_inputs( @@ -82,16 +144,84 @@ def test_other_non_citizen_with_admin_signal_gets_tin(): np.testing.assert_array_equal(result, np.array([True])) +def test_other_non_citizen_without_evidence_does_not_get_tin(): + person = _person_fixture() + cps = _cps_fixture(age=[40], tax_unit_ids=[1]) + + result = _impute_has_tin( + cps, + person.iloc[:1], + ssn_card_type=np.array([3]), + time_period=2024, + non_ssn_filer_tin_target=0, + ) + + np.testing.assert_array_equal(result, np.array([False])) + + +def test_tin_target_does_not_select_other_non_citizen_without_evidence(): + person = _person_fixture() + cps = _cps_fixture( + age=[40], + tax_unit_ids=[1], + self_employment_income=[5_000], + ) + + result = _impute_has_tin( + cps, + person.iloc[:1], + ssn_card_type=np.array([3]), + time_period=2024, + non_ssn_filer_tin_target=1, + ) + + np.testing.assert_array_equal(result, np.array([False])) + + +def test_proxy_tax_unit_filers_selects_two_oldest_adults(): + result = _proxy_tax_unit_filers( + person_tax_unit_ids=np.array([1, 1, 1, 2, 2]), + age=np.array([16, 40, 38, 12, 50]), + ) + + np.testing.assert_array_equal(result, np.array([False, True, True, False, True])) + + +def test_impute_has_tin_targets_likely_itin_filer_unit_and_minor_children(): + person = _person_fixture( + SS_YN=np.zeros(4, dtype=int), + MCARE=np.zeros(4, dtype=int), + ) + cps = _cps_fixture( + age=[40, 8, 40, 8], + tax_unit_ids=[1, 1, 2, 2], + self_employment_income=[5_000, 0, 0, 0], + ) + + result = _impute_has_tin( + cps, + person, + ssn_card_type=np.array([0, 0, 0, 0]), + time_period=2024, + non_ssn_filer_tin_target=1, + ) + + np.testing.assert_array_equal(result, np.array([True, True, False, False])) + + def test_store_identification_variables_writes_id_primitives(): cps = {} - person = _person_fixture(SS_YN=np.zeros(5, dtype=int)) - has_itin = np.array([False, False, False, False, True]) + person = _person_fixture(SS_YN=np.zeros(4, dtype=int)) + cps = _cps_fixture( + age=[40, 40, 40, 40], + tax_unit_ids=[1, 2, 3, 4], + ) _store_identification_variables( cps, person, - np.array([0, 1, 2, 3, 0]), - has_itin, + np.array([0, 1, 2, 3]), + time_period=2023, ) assert cps["ssn_card_type"].tolist() == [ @@ -99,21 +229,19 @@ def test_store_identification_variables_writes_id_primitives(): b"CITIZEN", b"NON_CITIZEN_VALID_EAD", b"OTHER_NON_CITIZEN", - b"NONE", ] assert cps["taxpayer_id_type"].tolist() == [ b"NONE", b"VALID_SSN", b"NONE", b"NONE", - b"OTHER_TIN", ] np.testing.assert_array_equal( cps["has_tin"], - np.array([False, True, False, False, True], dtype=bool), + np.array([False, True, False, False], dtype=bool), ) np.testing.assert_array_equal( cps["has_valid_ssn"], - np.array([False, True, False, False, False], dtype=bool), + np.array([False, True, False, False], dtype=bool), ) np.testing.assert_array_equal(cps["has_itin"], cps["has_tin"]) diff --git a/tests/unit/test_remote_calibration_runner.py b/tests/unit/test_remote_calibration_runner.py index 2d0196263..0c0fb9f91 100644 --- a/tests/unit/test_remote_calibration_runner.py +++ b/tests/unit/test_remote_calibration_runner.py @@ -1,6 +1,7 @@ import importlib import sys from types import ModuleType, SimpleNamespace +from unittest.mock import Mock def _load_remote_calibration_runner_module(): @@ -74,3 +75,42 @@ def test_collect_outputs_reads_checkpoint_bytes(tmp_path): "config": b"config", "checkpoint": b"checkpoint", } + + +def test_fit_weights_impl_saves_and_resumes_checkpoint_on_volume( + monkeypatch, + tmp_path, +): + remote_runner = _load_remote_calibration_runner_module() + (tmp_path / "policy_data.db").write_bytes(b"db") + (tmp_path / "source_imputed_stratified_extended_cps.h5").write_bytes(b"h5") + checkpoint = tmp_path / "test.checkpoint.pt" + checkpoint.write_bytes(b"old-checkpoint") + weights = tmp_path / "weights.npy" + + volume = SimpleNamespace(reload=Mock(), commit=Mock()) + monkeypatch.setattr(remote_runner, "pipeline_vol", volume) + monkeypatch.setattr(remote_runner, "_setup_repo", lambda: None) + + def fake_run_streaming(cmd, env=None, label=""): + assert "--resume-from" in cmd + assert cmd[cmd.index("--resume-from") + 1] == str(checkpoint) + assert "--checkpoint-output" in cmd + assert cmd[cmd.index("--checkpoint-output") + 1] == str(checkpoint) + weights.write_bytes(b"weights") + checkpoint.write_bytes(b"new-checkpoint") + return 0, [f"OUTPUT_PATH:{weights}", f"CHECKPOINT_PATH:{checkpoint}"] + + monkeypatch.setattr(remote_runner, "_run_streaming", fake_run_streaming) + + result = remote_runner._fit_weights_impl( + branch="main", + epochs=1, + artifacts_dir=str(tmp_path), + checkpoint_name="test.checkpoint.pt", + ) + + assert result["weights"] == b"weights" + assert result["checkpoint"] == b"new-checkpoint" + volume.reload.assert_called_once() + volume.commit.assert_called_once() From b4b85c91e55342e257b5b015c9e5de10a5291f37 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 10 Apr 2026 09:25:11 -0400 Subject: [PATCH 10/12] Fix PUF subsample logging format --- .../calibration/puf_impute.py | 20 ++++++++++++++++--- .../test_calibration_puf_impute.py | 13 ++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/policyengine_us_data/calibration/puf_impute.py b/policyengine_us_data/calibration/puf_impute.py index 92cdc1b28..f6bd0eeda 100644 --- a/policyengine_us_data/calibration/puf_impute.py +++ b/policyengine_us_data/calibration/puf_impute.py @@ -810,9 +810,7 @@ def _run_qrf_imputation( del puf_sim sub_idx = _stratified_subsample_index(puf_agi) - logger.info( - "Stratified PUF subsample: %d -> %d records " - "(top %.1f%% preserved, AGI threshold $%,.0f)", + _log_stratified_subsample( len(puf_agi), len(sub_idx), 100 - PUF_TOP_PERCENTILE, @@ -881,6 +879,22 @@ def _stratified_subsample_index( return selected +def _log_stratified_subsample( + original_n: int, + selected_n: int, + top_percent_preserved: float, + agi_threshold: float, +) -> None: + logger.info( + "Stratified PUF subsample: %d -> %d records " + "(top %.1f%% preserved, AGI threshold $%s)", + original_n, + selected_n, + top_percent_preserved, + f"{agi_threshold:,.0f}", + ) + + def _sequential_qrf( X_train: pd.DataFrame, X_test: pd.DataFrame, diff --git a/tests/unit/calibration/test_calibration_puf_impute.py b/tests/unit/calibration/test_calibration_puf_impute.py index 2596ab61f..8c55f731e 100644 --- a/tests/unit/calibration/test_calibration_puf_impute.py +++ b/tests/unit/calibration/test_calibration_puf_impute.py @@ -10,6 +10,7 @@ DEMOGRAPHIC_PREDICTORS, IMPUTED_VARIABLES, OVERRIDDEN_IMPUTED_VARIABLES, + _log_stratified_subsample, _stratified_subsample_index, puf_clone_dataset, ) @@ -191,3 +192,15 @@ def test_indices_sorted(self): income = np.random.default_rng(0).normal(50000, 20000, size=50_000) idx = _stratified_subsample_index(income, target_n=10_000) assert np.all(idx[1:] >= idx[:-1]) + + def test_log_handles_grouped_currency_threshold(self, caplog): + threshold = np.float32(8.934329e7) + caplog.set_level( + "INFO", + logger="policyengine_us_data.calibration.puf_impute", + ) + + _log_stratified_subsample(484_015, 20_000, 0.5, threshold) + + assert "Stratified PUF subsample: 484015 -> 20000 records" in caplog.text + assert f"${threshold:,.0f}" in caplog.text From c2cab9325b16a818ffdc29e8c0773bbd3f85bb5e Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Fri, 10 Apr 2026 13:43:14 -0400 Subject: [PATCH 11/12] fixes --- .../calibration/test_unified_calibration.py | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/tests/unit/calibration/test_unified_calibration.py b/tests/unit/calibration/test_unified_calibration.py index d3fe05bc0..aaa838cf8 100644 --- a/tests/unit/calibration/test_unified_calibration.py +++ b/tests/unit/calibration/test_unified_calibration.py @@ -907,6 +907,118 @@ def test_non_county_var_uses_state_values(self): np.testing.assert_array_equal(hh_vars["snap"], expected) +class TestAssembleCloneValuesStringConstraint: + """Verify string constraint vars (e.g. ssn_card_type) are + assembled without crashing on float32 conversion.""" + + def _make_state_values(self): + n = 4 + return { + 1: { + "hh": {"snap": np.array([50] * n, dtype=np.float32)}, + "person": { + "ssn_card_type": np.array( + ["CITIZEN", "CITIZEN", "UNDOCUMENTED", "CITIZEN"], + dtype=object, + ), + }, + "entity": {}, + }, + 2: { + "hh": {"snap": np.array([60] * n, dtype=np.float32)}, + "person": { + "ssn_card_type": np.array( + ["UNDOCUMENTED", "CITIZEN", "CITIZEN", "UNDOCUMENTED"], + dtype=object, + ), + }, + "entity": {}, + }, + } + + def test_string_constraint_var_assembled(self): + from policyengine_us_data.calibration.unified_matrix_builder import ( + UnifiedMatrixBuilder, + ) + + state_values = self._make_state_values() + clone_states = np.array([1, 1, 2, 2]) + person_hh_idx = np.array([0, 1, 2, 3]) + + builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder) + _, person_vars, _ = builder._assemble_clone_values( + state_values, + clone_states, + person_hh_idx, + {"snap"}, + {"ssn_card_type"}, + ) + assert "ssn_card_type" in person_vars + arr = person_vars["ssn_card_type"] + assert arr.dtype == object + expected = np.array( + ["CITIZEN", "CITIZEN", "CITIZEN", "UNDOCUMENTED"], dtype=object + ) + np.testing.assert_array_equal(arr, expected) + + def test_string_constraint_var_standalone(self): + from policyengine_us_data.calibration.unified_matrix_builder import ( + _assemble_clone_values_standalone, + ) + + state_values = self._make_state_values() + clone_states = np.array([1, 1, 2, 2]) + person_hh_idx = np.array([0, 1, 2, 3]) + + _, person_vars, _ = _assemble_clone_values_standalone( + state_values, + clone_states, + person_hh_idx, + {"snap"}, + {"ssn_card_type"}, + ) + assert "ssn_card_type" in person_vars + arr = person_vars["ssn_card_type"] + assert arr.dtype == object + expected = np.array( + ["CITIZEN", "CITIZEN", "CITIZEN", "UNDOCUMENTED"], dtype=object + ) + np.testing.assert_array_equal(arr, expected) + + def test_string_constraint_with_equality_op(self): + import pandas as pd + + from policyengine_us_data.calibration.unified_matrix_builder import ( + _assemble_clone_values_standalone, + _evaluate_constraints_standalone, + ) + + state_values = self._make_state_values() + clone_states = np.array([1, 1, 2, 2]) + person_hh_idx = np.array([0, 1, 2, 3]) + + _, person_vars, _ = _assemble_clone_values_standalone( + state_values, + clone_states, + person_hh_idx, + {"snap"}, + {"ssn_card_type"}, + ) + + household_ids = np.array([0, 1, 2, 3]) + entity_rel = pd.DataFrame( + {"household_id": person_hh_idx, "person_id": np.arange(4)} + ) + constraints = [ + {"variable": "ssn_card_type", "operation": "==", "value": "CITIZEN"} + ] + mask = _evaluate_constraints_standalone( + constraints, person_vars, entity_rel, household_ids, 4 + ) + expected = np.array([True, True, True, False]) + np.testing.assert_array_equal(mask, expected) + + class TestTakeupDrawConsistency: """Verify the matrix builder's inline takeup loop and compute_block_takeup_for_entities produce identical draws From 807551d2c7fa3f68d97ee7c33628147da5b36838 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 10 Apr 2026 14:46:43 -0400 Subject: [PATCH 12/12] Expand PR 708 changelog summary --- changelog.d/708.added | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/708.added b/changelog.d/708.added index 4d24af1c6..22364e1e7 100644 --- a/changelog.d/708.added +++ b/changelog.d/708.added @@ -1 +1 @@ -Save calibration geography as a pipeline artifact and add ``--resume-from`` checkpoint support for long-running calibration fits. +Save calibration geography as a pipeline artifact, add ``--resume-from`` and checkpoint support for long-running calibration fits, and fix resume/artifact handling in the remote calibration pipeline. This also adds conservative CPS taxpayer-ID outputs (``has_tin``, ``has_valid_ssn``, and a temporary ``has_itin`` compatibility alias), plus string-valued constraint handling needed for ID-target calibration.