diff --git a/docs/internals/README.md b/docs/internals/README.md index 65d4a195b..f2fdecbf8 100644 --- a/docs/internals/README.md +++ b/docs/internals/README.md @@ -31,6 +31,14 @@ expansion. ______________________________________________________________________ +## Refactor notes + +| Note | Purpose | +| ---- | ------- | +| [`local_h5_refactor_status.md`](local_h5_refactor_status.md) | Records the landed `local_h5` architecture for area publishing, the thin adapter layers that still remain, and the work that was explicitly deferred in the H5 refactor PR. | + +______________________________________________________________________ + ## Pipeline orchestration reference The pipeline runs on [Modal](https://modal.com) via `modal_app/pipeline.py`. It chains five steps diff --git a/docs/internals/local_h5_refactor_status.md b/docs/internals/local_h5_refactor_status.md new file mode 100644 index 000000000..9a33aa79f --- /dev/null +++ b/docs/internals/local_h5_refactor_status.md @@ -0,0 +1,173 @@ +# Local H5 Refactor Status + +Date: 2026-04-09 + +This note records what actually landed in the `fix/target-architecture-h5` +refactor for the US local and national H5 publishing path. + +It is intentionally narrower than the broader architecture planning docs. The goal here is to +describe the code that now exists, the remaining thin spots, and the work that was explicitly +deferred. 
+ +## What Landed + +The H5 path now has explicit internal contracts and a request-driven architecture: + +- `policyengine_us_data.calibration.local_h5.contracts` + - request, filter, validation, and worker result contracts + - `to_dict()` / `from_dict()` support for adapter boundaries +- `policyengine_us_data.calibration.local_h5.partitioning` + - tested weighted work partitioning +- `policyengine_us_data.calibration.local_h5.package_geography` + - exact calibration-package geography loading +- `policyengine_us_data.calibration.local_h5.fingerprinting` + - typed publish fingerprint inputs and records +- `policyengine_us_data.calibration.local_h5.selection` + - clone-weight layout and area selection +- `policyengine_us_data.calibration.local_h5.source_dataset` + - worker-scoped source snapshot with lazy variable access +- `policyengine_us_data.calibration.local_h5.reindexing` + - pure entity reindexing +- `policyengine_us_data.calibration.local_h5.variables` + - variable cloning and export policy +- `policyengine_us_data.calibration.local_h5.us_augmentations` + - US-only payload augmentation +- `policyengine_us_data.calibration.local_h5.builder` + - `LocalAreaDatasetBuilder` as the one-area orchestration root +- `policyengine_us_data.calibration.local_h5.writer` + - `H5Writer` as the H5 persistence boundary +- `policyengine_us_data.calibration.local_h5.worker_service` + - `WorkerSession` + - `LocalH5WorkerService` + - validation context loading + - request/result adaptation helpers +- `policyengine_us_data.calibration.local_h5.area_catalog` + - concrete `USAreaCatalog` + +The public entrypoints still exist, but they are now adapters over the internal components: + +- `policyengine_us_data.calibration.publish_local_area.build_h5(...)` +- `modal_app.worker_script` +- `modal_app.local_area.coordinate_publish(...)` +- `modal_app.local_area.coordinate_national_publish(...)` + +## Current Shape + +The current H5 publishing path is: + +1. 
coordinator derives publish inputs and fingerprint +2. coordinator builds concrete US requests from `USAreaCatalog` +3. coordinator partitions weighted requests across workers +4. worker script loads one `WorkerSession` +5. worker service iterates requests in the chunk +6. builder creates one in-memory payload per request +7. writer persists the H5 +8. validation runs per output when enabled +9. coordinator aggregates structured worker results + +In other words: + +- one-area build logic now lives in `LocalAreaDatasetBuilder` +- one-worker-chunk logic now lives in `LocalH5WorkerService` +- coordinator logic is thinner and request-driven + +## What Stayed Concrete And US-Specific + +This refactor deliberately did **not** try to create a fake shared cross-country core. + +Still US-specific by design: + +- `CloneWeightMatrix` +- `USAreaCatalog` +- `USAugmentationService` +- the current local-H5 coordinator/orchestration adapters + +That is intentional. The code was only generalized where there was already a real stable seam. + +## Test Status + +The refactor added a cheap unit-first suite around the new seams. At the end of +the coordinator refactor, the targeted local-H5 suite was passing: + +```text +81 passed +``` + +Coverage now exists for: + +- contracts +- partitioning +- validation helpers and worker validation contract +- package geography loading +- fingerprinting +- selection +- source snapshot loading +- reindexing +- variable cloning +- US augmentations +- builder and writer seams +- worker service behavior +- US area catalog behavior +- coordinator contract behavior +- calibration package serialized geography round-trip + +The deliberate gap is heavy runtime integration. The branch does **not** add a broad slow parity +suite. + +This was intentional. The PR was designed so most correctness lives in unit-testable +components, with only thin compatibility or seam coverage on top. 
+ +## Deferred Follow-Ups + +These items were explicitly left out or only partially handled: + +1. Heavy compatibility and invariant testing + - broader `build_h5` runtime parity + - deeper `X @ w` / area-aggregate invariants + - full Modal-like integration coverage + +2. Validator unification + - per-area target validation is now structurally correct + - national validation is still partly separate + - only `ValidationPolicy.enabled` is enforced today; the finer-grained + validation policy fields are present but not fully wired through + +3. Fingerprint schema simplification + - clone count is now canonicalized from weights + - long-term package-backed fingerprinting should stop treating `n_clones` and `seed` as + semantic equality inputs + +4. Possible later shared-core extraction + - nothing in this branch proves that the US abstractions are yet the right shared abstractions + for UK or another country + +5. Coordinator cleanup beyond the H5 scope + - Modal upload/promotion/manifest logic remains adapter-heavy + - that is outside the intended scope of this refactor + +## What This Documentation Does Not Claim + +This branch does **not** establish a reusable cross-country core library. + +It does establish a cleaner set of seams that another country pipeline could +learn from: + +- request/result contracts +- builder and worker-service boundaries +- package-backed geography loading +- lazy source snapshot handling + +Whether any of those should later move into a real shared abstraction should be +decided only after a second concrete implementation proves the shape. + +## Reading Order + +If you need to understand the landed architecture quickly, read in this order: + +1. `policyengine_us_data/calibration/local_h5/contracts.py` +2. `policyengine_us_data/calibration/local_h5/builder.py` +3. `policyengine_us_data/calibration/local_h5/worker_service.py` +4. `policyengine_us_data/calibration/local_h5/area_catalog.py` +5. 
`policyengine_us_data/calibration/publish_local_area.py` +6. `modal_app/worker_script.py` +7. `modal_app/local_area.py` diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 036f069a8..a637a6fb3 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -1,14 +1,17 @@ """ -Modal app for publishing local area H5 files with parallel workers. +Modal coordinator for local and national H5 publishing. -Architecture: -1. Coordinator partitions work across N workers -2. Workers build H5 files in parallel, writing to shared Volume -3. Validation generates manifest with checksums -4. Atomic upload to versioned paths, updates latest.json last +This module is now an adapter over the internal `local_h5` components: -Usage: - modal run modal_app/local_area.py --branch=main --num-workers=8 +1. Resolve concrete US publish requests from `USAreaCatalog` +2. Reconcile the staging run directory against the publish fingerprint +3. Partition request work across Modal workers +4. Invoke the worker script with serialized request payloads +5. Aggregate structured worker results, validation rows, and errors +6. Stage manifests and uploads + +The one-area build logic no longer lives here. That now sits under +`policyengine_us_data.calibration.local_h5`. 
""" import heapq @@ -28,8 +31,28 @@ if _p not in sys.path: sys.path.insert(0, _p) -from modal_app.images import cpu_image as image + from modal_app.images import cpu_image as image from modal_app.resilience import reconcile_run_dir_fingerprint +from policyengine_us_data.calibration.local_h5.area_catalog import ( + USAreaCatalog, +) +from policyengine_us_data.calibration.local_h5.contracts import ( + AreaBuildRequest, + AreaBuildResult, + ValidationIssue, + WorkerResult, +) +from policyengine_us_data.calibration.local_h5.fingerprinting import ( + FingerprintService, +) +from policyengine_us_data.calibration.local_h5.package_geography import ( + CalibrationPackageGeographyLoader, + require_calibration_package_path, +) +from policyengine_us_data.calibration.local_h5.partitioning import ( + partition_weighted_work_items, + work_item_key, +) app = modal.App("policyengine-us-data-local-area") @@ -140,32 +163,6 @@ def get_version() -> str: return pyproject["project"]["version"] -def partition_work( - work_items: List[Dict], - num_workers: int, - completed: set, -) -> List[List[Dict]]: - """Partition work items across N workers using LPT scheduling.""" - remaining = [ - item for item in work_items if f"{item['type']}:{item['id']}" not in completed - ] - remaining.sort(key=lambda x: -x["weight"]) - - n_workers = min(num_workers, len(remaining)) - if n_workers == 0: - return [] - - heap = [(0, i) for i in range(n_workers)] - chunks = [[] for _ in range(n_workers)] - - for item in remaining: - load, idx = heapq.heappop(heap) - chunks[idx].append(item) - heapq.heappush(heap, (load + item["weight"], idx)) - - return [c for c in chunks if c] - - def get_completed_from_volume(run_dir: Path) -> set: """Scan volume to find already-built files.""" completed = set() @@ -188,9 +185,161 @@ def get_completed_from_volume(run_dir: Path) -> set: return completed +def _derive_canonical_n_clones( + *, + weights_path: Path, + package_path: Path, + requested_n_clones: int, +) -> int: + """Use 
weights length as the canonical clone-count source for publishing.""" + + import numpy as np + + from policyengine_us_data.calibration.local_h5.package_geography import ( + CalibrationPackageGeographyLoader, + ) + from policyengine_us_data.calibration.local_h5.weights import ( + infer_clone_count_from_weight_length, + ) + + weights = np.load(weights_path, mmap_mode="r") + loader = CalibrationPackageGeographyLoader() + loaded = loader.load(package_path) + if loaded is None: + raise RuntimeError( + f"Calibration package at {package_path} does not contain usable geography" + ) + + canonical_n_clones = infer_clone_count_from_weight_length( + weights.shape[0], + loaded.geography.n_records, + ) + if requested_n_clones != canonical_n_clones: + print( + f"WARNING: requested n_clones={requested_n_clones} but " + f"weights imply {canonical_n_clones}; using weights-derived value" + ) + return canonical_n_clones + + +def _request_key(request: AreaBuildRequest) -> str: + return f"{request.area_type}:{request.area_id}" + + +def _phase_errors_from_worker_result(worker_result: WorkerResult) -> list[dict]: + phase_errors: list[dict] = [] + + for result in worker_result.failed: + phase_errors.append( + { + "type": "build_failure", + "item": _request_key(result.request), + "error": result.build_error, + } + ) + + for issue in worker_result.worker_issues: + phase_errors.append( + { + "type": "worker_issue", + "item": "worker", + "error": issue.message, + "code": issue.code, + "details": dict(issue.details), + } + ) + + return phase_errors + + +def _validation_rows_from_worker_result(worker_result: WorkerResult) -> list[dict]: + rows: list[dict] = [] + for result in worker_result.completed: + if result.validation.status in ("passed", "failed"): + rows.extend(dict(row) for row in result.validation.rows) + return rows + + +def _validation_errors_from_worker_result(worker_result: WorkerResult) -> list[dict]: + errors: list[dict] = [] + for result in worker_result.completed: + if 
result.validation.status != "error": + continue + item_key = _request_key(result.request) + for issue in result.validation.issues: + errors.append( + { + "item": item_key, + "error": issue.message, + "code": issue.code, + "details": dict(issue.details), + } + ) + return errors + + +def _worker_failure_result( + requests: List[Dict], + *, + error: str, + code: str, +) -> Dict: + failed_results = [] + for payload in requests: + request = AreaBuildRequest.from_dict(payload) + failed_results.append( + AreaBuildResult( + request=request, + build_status="failed", + build_error=error, + ) + ) + + result = WorkerResult( + completed=(), + failed=tuple(failed_results), + worker_issues=( + ValidationIssue( + code=code, + message=error, + severity="error", + ), + ), + ) + return result.to_dict() + + +def _load_catalog_geography( + package_path: Path, + *, + weights_path: Path, + n_clones: int, + seed: int, +): + import numpy as np + + loader = CalibrationPackageGeographyLoader() + weights = np.load(weights_path, mmap_mode="r") + weights_length = int(np.asarray(weights).size) + if weights_length % n_clones != 0: + raise RuntimeError( + "Weights are incompatible with the requested clone count: " + f"length={weights_length}, n_clones={n_clones}" + ) + loaded = loader.resolve_for_weights( + package_path=package_path, + weights_length=weights_length, + n_records=weights_length // n_clones, + n_clones=n_clones, + seed=seed, + allow_seed_fallback=False, + ) + return loaded.geography + + def run_phase( phase_name: str, - work_items: List[Dict], + entries: List, num_workers: int, completed: set, branch: str, @@ -202,12 +351,18 @@ def run_phase( """Run a single build phase, spawning workers and collecting results. 
Returns: - A tuple of (volume_completed, phase_errors, validation_rows) + A tuple of (volume_completed, phase_errors, validation_rows, + validation_errors) where phase_errors is a list of error dicts from workers - and crashes, and validation_rows is a list of per-target - validation result dicts. + and crashes, validation_rows is a list of per-target + validation result dicts, and validation_errors is a list + of structured validation execution failures. """ - work_chunks = partition_work(work_items, num_workers, completed) + work_items = [entry.to_partition_item() for entry in entries] + requests_by_key = { + entry.key: entry.request.to_dict() for entry in entries + } + work_chunks = partition_weighted_work_items(work_items, num_workers, completed) total_remaining = sum(len(c) for c in work_chunks) print(f"\n--- Phase: {phase_name} ---") @@ -215,16 +370,17 @@ def run_phase( if total_remaining == 0: print(f"All {phase_name} items already built!") - return completed, [], [] + return completed, [], [], [] handles = [] for i, chunk in enumerate(work_chunks): total_weight = sum(item["weight"] for item in chunk) print(f" Worker {i}: {len(chunk)} items, weight {total_weight}") + requests = [requests_by_key[work_item_key(item)] for item in chunk] handle = build_areas_worker.spawn( branch=branch, run_id=run_id, - work_items=chunk, + requests=requests, calibration_inputs=calibration_inputs, validate=validate, ) @@ -235,34 +391,45 @@ def run_phase( all_results = [] all_errors = [] all_validation_rows = [] + all_validation_errors = [] for i, handle in enumerate(handles): try: - result = handle.get() - if result is None: + payload = handle.get() + if payload is None: all_errors.append({"worker": i, "error": "Worker returned None"}) print(f" Worker {i}: returned None (no results)") continue + result = WorkerResult.from_dict(payload) all_results.append(result) print( - f" Worker {i}: {len(result['completed'])} completed, " - f"{len(result['failed'])} failed" + f" Worker {i}: 
{len(result.completed)} completed, " + f"{len(result.failed)} failed" ) - if result["errors"]: - all_errors.extend(result["errors"]) - # Collect validation rows - v_rows = result.get("validation_rows", []) + worker_errors = _phase_errors_from_worker_result(result) + if worker_errors: + all_errors.extend(worker_errors) + v_rows = _validation_rows_from_worker_result(result) if v_rows: all_validation_rows.extend(v_rows) print(f" Worker {i}: {len(v_rows)} validation rows") + v_errors = _validation_errors_from_worker_result(result) + if v_errors: + all_validation_errors.extend(v_errors) + print(f" Worker {i}: {len(v_errors)} validation errors") except Exception as e: all_errors.append( - {"worker": i, "error": str(e), "traceback": traceback.format_exc()} + { + "type": "transport_error", + "worker": i, + "error": str(e), + "traceback": traceback.format_exc(), + } ) print(f" Worker {i}: CRASHED - {e}") - total_completed = sum(len(r["completed"]) for r in all_results) - total_failed = sum(len(r["failed"]) for r in all_results) + total_completed = sum(len(r.completed) for r in all_results) + total_failed = sum(len(r.failed) for r in all_results) staging_volume.reload() volume_completed = get_completed_from_volume(run_dir) @@ -283,7 +450,7 @@ def run_phase( if len(all_errors) > 5: print(f" ... 
and {len(all_errors) - 5} more") - return volume_completed, all_errors, all_validation_rows + return volume_completed, all_errors, all_validation_rows, all_validation_errors @app.function( @@ -302,7 +469,7 @@ def run_phase( def build_areas_worker( branch: str, run_id: str, - work_items: List[Dict], + requests: List[Dict], calibration_inputs: Dict[str, str], validate: bool = True, ) -> Dict: @@ -316,15 +483,15 @@ def build_areas_worker( output_dir = Path(VOLUME_MOUNT) / run_id output_dir.mkdir(parents=True, exist_ok=True) - work_items_json = json.dumps(work_items) + requests_json = json.dumps(requests) worker_cmd = [ "uv", "run", "python", "modal_app/worker_script.py", - "--work-items", - work_items_json, + "--requests-json", + requests_json, "--weights-path", calibration_inputs["weights"], "--dataset-path", @@ -334,6 +501,10 @@ def build_areas_worker( "--output-dir", str(output_dir), ] + if "package" in calibration_inputs: + worker_cmd.extend( + ["--calibration-package-path", calibration_inputs["package"]] + ) if "n_clones" in calibration_inputs: worker_cmd.extend(["--n-clones", str(calibration_inputs["n_clones"])]) if "seed" in calibration_inputs: @@ -364,20 +535,22 @@ def build_areas_worker( if result.returncode != 0: print(f"Worker stderr:\n{result.stderr}", file=__import__("sys").stderr) - return { - "completed": [], - "failed": [f"{item['type']}:{item['id']}" for item in work_items], - "errors": [{"error": (result.stderr or "No stderr")[:2000]}], - } + return _worker_failure_result( + requests, + error=(result.stderr or "No stderr")[:2000], + code="worker_subprocess_failed", + ) try: results = json.loads(result.stdout) except json.JSONDecodeError: - results = { - "completed": [], - "failed": [], - "errors": [{"error": f"Failed to parse output: {result.stdout}"}], - } + results = _worker_failure_result( + requests, + error=f"Failed to parse output: {result.stdout}", + code="worker_output_parse_failed", + ) + + results = WorkerResult.from_dict(results).to_dict() 
staging_volume.commit() return results @@ -661,6 +834,9 @@ def coordinate_publish( weights_path = artifacts / "calibration_weights.npy" db_path = artifacts / "policy_data.db" dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5" + package_path = require_calibration_package_path( + artifacts / "calibration_package.pkl" + ) config_json_path = artifacts / "unified_run_config.json" required = { @@ -676,13 +852,21 @@ def coordinate_publish( ) print("All required pipeline artifacts found on volume.") + canonical_n_clones = _derive_canonical_n_clones( + weights_path=weights_path, + package_path=package_path, + requested_n_clones=n_clones, + ) + calibration_inputs = { "weights": str(weights_path), "dataset": str(dataset_path), "database": str(db_path), - "n_clones": n_clones, + "n_clones": canonical_n_clones, "seed": 42, + "package": str(package_path), } + print(f"Using calibration package geography from {package_path}") validate_artifacts(config_json_path, artifacts) if validate: @@ -703,84 +887,52 @@ def coordinate_publish( # Fingerprint-based cache invalidation if expected_fingerprint: - fingerprint = expected_fingerprint - print(f"Using pinned fingerprint from pipeline: {fingerprint}") - else: - fp_result = subprocess.run( - [ - "uv", - "run", - "python", - "-c", - f""" -from policyengine_us_data.calibration.publish_local_area import ( - compute_input_fingerprint, -) -print(compute_input_fingerprint("{weights_path}", "{dataset_path}", {n_clones}, seed=42)) -""", - ], - capture_output=True, - text=True, - env=os.environ.copy(), + print(f"Using pinned fingerprint from pipeline: {expected_fingerprint}") + + fingerprint_service = FingerprintService() + fingerprint_record = fingerprint_service.create_publish_fingerprint( + weights_path=weights_path, + dataset_path=dataset_path, + calibration_package_path=package_path, + n_clones=canonical_n_clones, + seed=42, + ) + fingerprint = fingerprint_record.digest + + if expected_fingerprint and expected_fingerprint 
!= fingerprint: + raise RuntimeError( + "Pinned fingerprint does not match current publish inputs.\n" + f" Expected: {expected_fingerprint}\n" + f" Current: {fingerprint}\n" + "Start a fresh run instead of resuming." ) - if fp_result.returncode != 0: - raise RuntimeError(f"Failed to compute fingerprint: {fp_result.stderr}") - fingerprint = fp_result.stdout.strip() - reconcile_action = reconcile_run_dir_fingerprint(run_dir, fingerprint) + + reconcile_action = reconcile_run_dir_fingerprint( + run_dir, + fingerprint_record, + scope="regional", + ) if reconcile_action == "resume": print(f"Inputs unchanged ({fingerprint}), resuming...") else: print(f"Prepared staging directory for fingerprint {fingerprint}") staging_volume.commit() - result = subprocess.run( - [ - "uv", - "run", - "python", - "-c", - f""" -import json -from policyengine_us_data.calibration.calibration_utils import ( - get_all_cds_from_database, - STATE_CODES, -) -from policyengine_us_data.calibration.publish_local_area import ( - get_district_friendly_name, -) - -db_uri = "sqlite:///{db_path}" -cds = get_all_cds_from_database(db_uri) -states = list(STATE_CODES.values()) -districts = [get_district_friendly_name(cd) for cd in cds] -print(json.dumps({{"states": states, "districts": districts, "cities": ["NYC"], "cds": cds}})) -""", - ], - capture_output=True, - text=True, - env=os.environ.copy(), + catalog = USAreaCatalog() + catalog_geography = _load_catalog_geography( + package_path, + weights_path=weights_path, + n_clones=canonical_n_clones, + seed=42, ) - - if result.returncode != 0: - raise RuntimeError(f"Failed to get work items: {result.stderr}") - - work_info = json.loads(result.stdout) - states = work_info["states"] - districts = work_info["districts"] - cities = work_info["cities"] - - from collections import Counter - - cds_per_state = Counter(d.split("-")[0] for d in districts) - - CITY_WEIGHTS = {"NYC": 11} - - work_items = [] - for s in states: - work_items.append({"type": "state", "id": s, 
"weight": cds_per_state.get(s, 1)}) - for d in districts: - work_items.append({"type": "district", "id": d, "weight": 1}) - for c in cities: - work_items.append({"type": "city", "id": c, "weight": CITY_WEIGHTS.get(c, 3)}) + entries = list( + catalog.resolved_regional_entries( + f"sqlite:///{db_path}", + geography=catalog_geography, + ) + ) + states = [e for e in entries if e.request.area_type == "state"] + districts = [e for e in entries if e.request.area_type == "district"] + cities = [e for e in entries if e.request.area_type == "city"] staging_volume.reload() completed = get_completed_from_volume(run_dir) @@ -797,22 +949,50 @@ def coordinate_publish( accumulated_errors = [] accumulated_validation_rows = [] + accumulated_validation_errors = [] - completed, phase_errors, v_rows = run_phase( + completed, phase_errors, v_rows, v_errors = run_phase( "All areas", - work_items=work_items, + entries=entries, completed=completed, **phase_args, ) accumulated_errors.extend(phase_errors) accumulated_validation_rows.extend(v_rows) + accumulated_validation_errors.extend(v_errors) - expected_total = len(states) + len(districts) + len(cities) + expected_total = len(entries) - # If workers crashed but all files landed on the volume, - # treat as transient infrastructure errors (e.g. gRPC stream resets). if accumulated_errors: - crash_errors = [e for e in accumulated_errors if "worker" in e] + build_failures = [ + error + for error in accumulated_errors + if error.get("type") == "build_failure" + ] + if build_failures: + raise RuntimeError( + f"Build failed: {len(build_failures)} build failure(s) reported. " + f"Errors: {build_failures[:3]}" + ) + + worker_issues = [ + error + for error in accumulated_errors + if error.get("type") == "worker_issue" + ] + if worker_issues: + raise RuntimeError( + f"Build failed: {len(worker_issues)} worker issue(s) reported. 
" + f"Errors: {worker_issues[:3]}" + ) + + # If workers crashed in transit but all files landed on the volume, + # treat that as transient infrastructure noise (e.g. gRPC stream resets). + crash_errors = [ + error + for error in accumulated_errors + if error.get("type") == "transport_error" + ] if crash_errors and len(completed) >= expected_total: print( f"WARNING: {len(crash_errors)} worker error(s) occurred " @@ -840,13 +1020,14 @@ def coordinate_publish( return { "message": (f"Build complete for version {version}. Upload skipped."), "validation_rows": accumulated_validation_rows, + "validation_errors": accumulated_validation_errors, "fingerprint": fingerprint, } print("\nValidating staging...") manifest = validate_staging.remote(branch=branch, run_id=run_id, version=version) - expected_total = len(states) + len(districts) + len(cities) + expected_total = len(entries) actual_total = ( manifest["totals"]["states"] + manifest["totals"]["districts"] @@ -876,6 +1057,7 @@ def coordinate_publish( "message": result, "run_id": run_id, "validation_rows": accumulated_validation_rows, + "validation_errors": accumulated_validation_errors, "fingerprint": fingerprint, } @@ -918,6 +1100,7 @@ def coordinate_national_publish( n_clones: int = 430, validate: bool = True, run_id: str = "", + expected_fingerprint: str = "", ) -> Dict: """Build and upload a national US.h5 from national weights.""" setup_gcp_credentials() @@ -945,6 +1128,9 @@ def coordinate_national_publish( weights_path = artifacts / "national_calibration_weights.npy" db_path = artifacts / "policy_data.db" dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5" + package_path = require_calibration_package_path( + artifacts / "calibration_package.pkl" + ) config_json_path = artifacts / "national_unified_run_config.json" required = { @@ -960,13 +1146,21 @@ def coordinate_national_publish( ) print("All required national pipeline artifacts found.") + canonical_n_clones = _derive_canonical_n_clones( + 
weights_path=weights_path, + package_path=package_path, + requested_n_clones=n_clones, + ) + calibration_inputs = { "weights": str(weights_path), "dataset": str(dataset_path), "database": str(db_path), - "n_clones": n_clones, + "n_clones": canonical_n_clones, "seed": 42, + "package": str(package_path), } + print(f"Using calibration package geography from {package_path}") validate_artifacts( config_json_path, artifacts, @@ -975,31 +1169,75 @@ def coordinate_national_publish( }, ) run_dir = staging_dir / run_id - run_dir.mkdir(parents=True, exist_ok=True) - work_items = [{"type": "national", "id": "US"}] - print("Spawning worker for national H5 build...") - worker_result = build_areas_worker.remote( - branch=branch, - run_id=run_id, - work_items=work_items, - calibration_inputs=calibration_inputs, - validate=validate, + if expected_fingerprint: + print(f"Using pinned fingerprint from pipeline: {expected_fingerprint}") + + fingerprint_service = FingerprintService() + fingerprint_record = fingerprint_service.create_publish_fingerprint( + weights_path=weights_path, + dataset_path=dataset_path, + calibration_package_path=package_path, + n_clones=canonical_n_clones, + seed=42, ) + fingerprint = fingerprint_record.digest - print( - f"Worker result: " - f"{len(worker_result['completed'])} completed, " - f"{len(worker_result['failed'])} failed" - ) + if expected_fingerprint and expected_fingerprint != fingerprint: + raise RuntimeError( + "Pinned fingerprint does not match current publish inputs.\n" + f" Expected: {expected_fingerprint}\n" + f" Current: {fingerprint}\n" + "Start a fresh run instead of resuming." 
+ ) - if worker_result["failed"]: - raise RuntimeError(f"National build failed: {worker_result['errors']}") + reconcile_action = reconcile_run_dir_fingerprint( + run_dir, + fingerprint_record, + scope="national", + ) + if reconcile_action == "resume": + print(f"Inputs unchanged ({fingerprint}), resuming...") + else: + print(f"Prepared staging directory for fingerprint {fingerprint}") + staging_volume.commit() + catalog = USAreaCatalog() + national_entry = catalog.resolved_national_entry() staging_volume.reload() national_h5 = run_dir / "national" / "US.h5" - if not national_h5.exists(): - raise RuntimeError(f"Expected {national_h5} not found after build") + if reconcile_action == "resume" and national_h5.exists(): + print("National H5 already present for matching fingerprint; skipping rebuild.") + worker_result = WorkerResult( + completed=(), + failed=(), + ) + else: + print("Spawning worker for national H5 build...") + worker_payload = build_areas_worker.remote( + branch=branch, + run_id=run_id, + requests=[national_entry.request.to_dict()], + calibration_inputs=calibration_inputs, + validate=validate, + ) + worker_result = WorkerResult.from_dict(worker_payload) + + print( + f"Worker result: " + f"{len(worker_result.completed)} completed, " + f"{len(worker_result.failed)} failed" + ) + + phase_errors = _phase_errors_from_worker_result(worker_result) + if worker_result.failed or phase_errors: + raise RuntimeError(f"National build failed: {phase_errors}") + + staging_volume.reload() + if not national_h5.exists(): + raise RuntimeError(f"Expected {national_h5} not found after build") + + validation_errors = _validation_errors_from_worker_result(worker_result) # Compute SHA256 checksum before upload for integrity verification import hashlib @@ -1082,6 +1320,8 @@ def coordinate_national_publish( ), "run_id": run_id, "national_validation": national_validation_output, + "validation_errors": validation_errors, + "fingerprint": fingerprint, } diff --git 
a/modal_app/pipeline.py b/modal_app/pipeline.py index 413a12d18..6d539a63a 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -104,13 +104,23 @@ class RunMetadata: error: Optional[str] = None resume_history: list = field(default_factory=list) fingerprint: Optional[str] = None + regional_fingerprint: Optional[str] = None + national_fingerprint: Optional[str] = None def to_dict(self) -> dict: - return asdict(self) + payload = asdict(self) + if payload.get("fingerprint") is None and payload.get("regional_fingerprint"): + payload["fingerprint"] = payload["regional_fingerprint"] + return payload @classmethod def from_dict(cls, data: dict) -> "RunMetadata": - return cls(**data) + payload = dict(data) + if payload.get("regional_fingerprint") is None and payload.get("fingerprint"): + payload["regional_fingerprint"] = payload["fingerprint"] + if payload.get("fingerprint") is None and payload.get("regional_fingerprint"): + payload["fingerprint"] = payload["regional_fingerprint"] + return cls(**payload) def generate_run_id(version: str, sha: str) -> str: @@ -457,11 +467,17 @@ def _write_validation_diagnostics( Extracts validation_rows from coordinate_publish and national_validation from coordinate_national_publish, writes them to runs/{run_id}/diagnostics/validation_results.csv, + writes validation execution failures to + runs/{run_id}/diagnostics/validation_errors.json, and records a summary in meta.json. 
""" import csv + from policyengine_us_data.calibration.local_h5.validation import ( + tag_validation_errors, + ) validation_rows = [] + validation_errors = [] # Extract regional validation rows if isinstance(regional_result, dict): @@ -469,6 +485,12 @@ def _write_validation_diagnostics( if v_rows: validation_rows.extend(v_rows) print(f" Collected {len(v_rows)} regional validation rows") + v_errors = regional_result.get("validation_errors", []) + if v_errors: + validation_errors.extend( + tag_validation_errors(v_errors, source="regional") + ) + print(f" Collected {len(v_errors)} regional validation errors") # Extract national validation output national_output = "" @@ -476,8 +498,14 @@ def _write_validation_diagnostics( national_output = national_result.get("national_validation", "") if national_output: print(" National validation output captured") + v_errors = national_result.get("validation_errors", []) + if v_errors: + validation_errors.extend( + tag_validation_errors(v_errors, source="national") + ) + print(f" Collected {len(v_errors)} national validation errors") - if not validation_rows and not national_output: + if not validation_rows and not national_output and not validation_errors: print(" No validation data to write") return @@ -547,6 +575,7 @@ def _write_validation_diagnostics( "total_targets": len(validation_rows), "sanity_failures": n_sanity_fail, "mean_rel_abs_error": round(mean_rae, 4), + "validation_errors": len(validation_errors), "worst_areas": [ { "area": k, @@ -571,9 +600,7 @@ def _write_validation_diagnostics( f"mean RAE={mean_rae:.4f}" ) - # Record in meta.json meta.step_timings["validation"] = validation_summary - write_run_meta(meta, vol) # Write national validation output if national_output: @@ -582,6 +609,17 @@ def _write_validation_diagnostics( f.write(national_output) print(f" Wrote national validation to {nat_path}") + if validation_errors: + errors_path = diag_dir / "validation_errors.json" + with open(errors_path, "w") as f: + 
json.dump(validation_errors, f, indent=2) + print(f" Wrote {len(validation_errors)} validation errors to {errors_path}") + meta.step_timings.setdefault("validation", {}) + meta.step_timings["validation"]["validation_errors"] = len(validation_errors) + + if validation_rows or validation_errors: + write_run_meta(meta, vol) + vol.commit() @@ -902,7 +940,9 @@ def run_pipeline( n_clones=n_clones, validate=True, run_id=run_id, - expected_fingerprint=meta.fingerprint or "", + expected_fingerprint=( + meta.regional_fingerprint or meta.fingerprint or "" + ), ) print(f" → coordinate_publish fc: {regional_h5_handle.object_id}") @@ -914,6 +954,7 @@ def run_pipeline( n_clones=n_clones, validate=True, run_id=run_id, + expected_fingerprint=meta.national_fingerprint or "", ) print( f" → coordinate_national_publish fc: {national_h5_handle.object_id}" @@ -938,6 +979,7 @@ def run_pipeline( if isinstance(regional_h5_result, dict) and regional_h5_result.get( "fingerprint" ): + meta.regional_fingerprint = regional_h5_result["fingerprint"] meta.fingerprint = regional_h5_result["fingerprint"] write_run_meta(meta, pipeline_volume) @@ -951,6 +993,11 @@ def run_pipeline( else national_h5_result ) print(f" National H5: {national_msg}") + if isinstance(national_h5_result, dict) and national_h5_result.get( + "fingerprint" + ): + meta.national_fingerprint = national_h5_result["fingerprint"] + write_run_meta(meta, pipeline_volume) # ── Aggregate validation results ── _write_validation_diagnostics( diff --git a/modal_app/resilience.py b/modal_app/resilience.py index beed25317..0a5702572 100644 --- a/modal_app/resilience.py +++ b/modal_app/resilience.py @@ -1,6 +1,6 @@ """Helpers for retry and resume safety in Modal workflows.""" -import json +from dataclasses import dataclass import shutil import subprocess import time @@ -8,6 +8,28 @@ from typing import Optional +@dataclass(frozen=True) +class PublishScope: + name: str + owned_dirs: tuple[str, ...] 
+ + +_PUBLISH_SCOPES = { + "all": PublishScope( + name="all", + owned_dirs=("states", "districts", "cities", "national"), + ), + "regional": PublishScope( + name="regional", + owned_dirs=("states", "districts", "cities"), + ), + "national": PublishScope( + name="national", + owned_dirs=("national",), + ), +} + + def run_with_retry( cmd: list[str], max_retries: int = 3, @@ -91,7 +113,9 @@ def ensure_resume_sha_compatible( def reconcile_run_dir_fingerprint( run_dir: Path, - fingerprint: str, + fingerprint, + *, + scope: str = "all", ) -> str: """Prepare a staging run directory for a specific fingerprint. @@ -100,37 +124,113 @@ def reconcile_run_dir_fingerprint( - changed or missing fingerprint with existing H5s: stop and preserve - changed or missing fingerprint without H5s: clear stale directory """ - fingerprint_file = run_dir / "fingerprint.json" + from policyengine_us_data.calibration.local_h5.fingerprinting import ( + FingerprintRecord, + FingerprintService, + ) + + service = FingerprintService() + if isinstance(fingerprint, FingerprintRecord): + current = fingerprint + else: + current = service.legacy_record(str(fingerprint)) + publish_scope = _resolve_publish_scope(scope) + fingerprint_file = _fingerprint_file_for_scope(run_dir, publish_scope) + legacy_fingerprint_file = run_dir / "fingerprint.json" if not run_dir.exists(): run_dir.mkdir(parents=True, exist_ok=True) - fingerprint_file.write_text(json.dumps({"fingerprint": fingerprint})) + fingerprint_file.parent.mkdir(parents=True, exist_ok=True) + service.write_record(fingerprint_file, current) return "initialized" - h5_count = len(list(run_dir.rglob("*.h5"))) - if fingerprint_file.exists(): - stored = json.loads(fingerprint_file.read_text()) - stored_fingerprint = stored.get("fingerprint") - if stored_fingerprint == fingerprint: + h5_count = _count_owned_h5_files(run_dir, publish_scope) + stored_file = _stored_fingerprint_file( + fingerprint_file=fingerprint_file, + 
legacy_fingerprint_file=legacy_fingerprint_file, + publish_scope=publish_scope, + ) + if stored_file is not None: + stored = service.read_record(stored_file) + stored_fingerprint = stored.digest + if service.matches(stored, current): + if stored_file != fingerprint_file: + fingerprint_file.parent.mkdir(parents=True, exist_ok=True) + service.write_record(fingerprint_file, current) return "resume" if h5_count > 0: raise RuntimeError( - "Fingerprint mismatch with existing staged H5 files.\n" + f"Fingerprint mismatch with existing staged {publish_scope.name} H5 files.\n" f" Stored: {stored_fingerprint}\n" - f" Current: {fingerprint}\n" + f" Current: {current.digest}\n" f" H5 files preserved: {h5_count}\n" "Start a fresh version or clear the stale outputs explicitly." ) - shutil.rmtree(run_dir) + _clear_scope_outputs(run_dir, publish_scope) else: if h5_count > 0: raise RuntimeError( - "Missing fingerprint metadata with existing staged H5 files.\n" + f"Missing fingerprint metadata with existing staged {publish_scope.name} H5 files.\n" f" H5 files preserved: {h5_count}\n" "Start a fresh version or clear the stale outputs explicitly." 
) - shutil.rmtree(run_dir) + _clear_scope_outputs(run_dir, publish_scope) run_dir.mkdir(parents=True, exist_ok=True) - fingerprint_file.write_text(json.dumps({"fingerprint": fingerprint})) + fingerprint_file.parent.mkdir(parents=True, exist_ok=True) + service.write_record(fingerprint_file, current) return "initialized" + + +def _resolve_publish_scope(scope: str) -> PublishScope: + try: + return _PUBLISH_SCOPES[scope] + except KeyError as error: + raise ValueError(f"Unknown publish scope: {scope!r}") from error + + +def _fingerprint_file_for_scope(run_dir: Path, publish_scope: PublishScope) -> Path: + if publish_scope.name == "all": + return run_dir / "fingerprint.json" + return run_dir / ".publish_scopes" / publish_scope.name / "fingerprint.json" + + +def _stored_fingerprint_file( + *, + fingerprint_file: Path, + legacy_fingerprint_file: Path, + publish_scope: PublishScope, +) -> Path | None: + if fingerprint_file.exists(): + return fingerprint_file + if publish_scope.name != "all" and legacy_fingerprint_file.exists(): + return legacy_fingerprint_file + return None + + +def _count_owned_h5_files(run_dir: Path, publish_scope: PublishScope) -> int: + return sum( + len(list((run_dir / owned_dir).rglob("*.h5"))) + for owned_dir in publish_scope.owned_dirs + if (run_dir / owned_dir).exists() + ) + + +def _clear_scope_outputs(run_dir: Path, publish_scope: PublishScope) -> None: + if publish_scope.name == "all": + if run_dir.exists(): + shutil.rmtree(run_dir) + return + + for owned_dir in publish_scope.owned_dirs: + target = run_dir / owned_dir + if target.exists(): + shutil.rmtree(target) + + scope_meta_dir = run_dir / ".publish_scopes" / publish_scope.name + if scope_meta_dir.exists(): + shutil.rmtree(scope_meta_dir) + + publish_scopes_root = run_dir / ".publish_scopes" + if publish_scopes_root.exists() and not any(publish_scopes_root.iterdir()): + publish_scopes_root.rmdir() diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index 27dbb8c2a..f4023bd97 
100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -1,160 +1,37 @@ #!/usr/bin/env python """ -Worker script for building local area H5 files. +Thin CLI adapter for the local H5 worker service. -Called by Modal workers via subprocess to avoid import conflicts. +Modal launches this script in a subprocess so the worker runtime stays isolated +from the coordinator process. The actual chunk-level build behavior now lives +in `policyengine_us_data.calibration.local_h5.worker_service`. """ import argparse import json import sys -import traceback -import numpy as np from pathlib import Path - -def _validate_in_subprocess( - h5_path, - area_type, - area_id, - display_id, - area_targets, - area_training, - constraints_map, - db_path, - period, -): - """Run validation for one area inside a subprocess. - - All Microsimulation memory is reclaimed when the - subprocess exits. - """ - import logging - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s %(message)s", - ) - from policyengine_us import Microsimulation - from sqlalchemy import create_engine as _ce - from policyengine_us_data.calibration.validate_staging import ( - validate_area, - _build_variable_entity_map, - ) - - engine = _ce(f"sqlite:///{db_path}") - sim = Microsimulation(dataset=h5_path) - variable_entity_map = _build_variable_entity_map(sim) - - results = validate_area( - sim=sim, - targets_df=area_targets, - engine=engine, - area_type=area_type, - area_id=area_id, - display_id=display_id, - period=period, - training_mask=area_training, - variable_entity_map=variable_entity_map, - constraints_map=constraints_map, - ) - return results - - -def _validate_h5_subprocess( - h5_path, - item_type, - item_id, - state_fips, - candidate, - cd_subset, - validation_targets, - training_mask_full, - constraints_map, - db_path, - period, -): - """Spawn a subprocess to validate one H5 file. - - Uses multiprocessing spawn to isolate memory. 
- """ - import multiprocessing as _mp - - # Determine geo_level and geographic_id for filtering targets - if item_type == "state": - geo_level = "state" - geographic_id = str(state_fips) - area_type = "states" - display_id = item_id - elif item_type == "district": - geo_level = "district" - geographic_id = str(candidate) - area_type = "districts" - display_id = item_id - elif item_type == "city": - # NYC: aggregate targets for NYC CDs - geo_level = "district" - area_type = "cities" - display_id = item_id - elif item_type == "national": - geo_level = "national" - geographic_id = "US" - area_type = "national" - display_id = "US" - else: - return [] - - # Filter targets to matching area - if item_type == "city": - # Match targets for any NYC CD - nyc_cd_set = set(str(cd) for cd in cd_subset) - mask = (validation_targets["geo_level"] == geo_level) & validation_targets[ - "geographic_id" - ].astype(str).isin(nyc_cd_set) - elif item_type == "national": - mask = validation_targets["geo_level"] == geo_level - else: - mask = (validation_targets["geo_level"] == geo_level) & ( - validation_targets["geographic_id"].astype(str) == geographic_id - ) - - area_targets = validation_targets[mask].reset_index(drop=True) - area_training = training_mask_full[mask.values] - - if len(area_targets) == 0: - return [] - - # Filter constraints_map to relevant strata - area_strata = area_targets["stratum_id"].unique().tolist() - area_constraints = {int(s): constraints_map.get(int(s), []) for s in area_strata} - - ctx = _mp.get_context("spawn") - with ctx.Pool(1) as pool: - results = pool.apply( - _validate_in_subprocess, - ( - h5_path, - area_type, - item_id, - display_id, - area_targets, - area_training, - area_constraints, - db_path, - period, - ), - ) - - return results - - def main(): parser = argparse.ArgumentParser() - parser.add_argument("--work-items", required=True, help="JSON work items") + # Kept for backward compatibility with older worker launchers. 
+ # New callers should pass fully resolved AreaBuildRequests via + # --requests-json instead. + parser.add_argument("--work-items", default=None, help="JSON work items") + parser.add_argument( + "--requests-json", + default=None, + help="JSON serialized AreaBuildRequest list", + ) parser.add_argument("--weights-path", required=True) parser.add_argument("--dataset-path", required=True) parser.add_argument("--db-path", required=True) parser.add_argument("--output-dir", required=True) + parser.add_argument( + "--calibration-package-path", + default=None, + help="Optional calibration package path for exact geography reuse", + ) parser.add_argument( "--n-clones", type=int, @@ -191,316 +68,123 @@ def main(): ) args = parser.parse_args() - work_items = json.loads(args.work_items) + if not args.requests_json and not args.work_items: + raise ValueError("Either --requests-json or --work-items is required") + + work_items = json.loads(args.work_items) if args.work_items else None + request_payloads = json.loads(args.requests_json) if args.requests_json else None weights_path = Path(args.weights_path) dataset_path = Path(args.dataset_path) db_path = Path(args.db_path) output_dir = Path(args.output_dir) - - from policyengine_us_data.utils.takeup import ( - SIMPLE_TAKEUP_VARS, + calibration_package_path = ( + Path(args.calibration_package_path) + if args.calibration_package_path + else None ) - takeup_filter = [spec["variable"] for spec in SIMPLE_TAKEUP_VARS] - - original_stdout = sys.stdout - sys.stdout = sys.stderr - - from policyengine_us_data.calibration.publish_local_area import ( - build_h5, - NYC_COUNTY_FIPS, - AT_LARGE_DISTRICTS, + from policyengine_us_data.utils.takeup import SIMPLE_TAKEUP_VARS + from policyengine_us_data.calibration.local_h5.contracts import ( + AreaBuildRequest, + ValidationPolicy, ) - from policyengine_us_data.calibration.calibration_utils import ( - STATE_CODES, + from policyengine_us_data.calibration.local_h5.package_geography import ( + 
require_calibration_package_path, ) - from policyengine_us_data.calibration.clone_and_assign import ( - assign_random_geography, + from policyengine_us_data.calibration.local_h5.worker_service import ( + LocalH5WorkerService, + WorkerSession, + build_requests_from_work_items, + load_validation_context, ) - - weights = np.load(weights_path) - - from policyengine_us import Microsimulation - - _sim = Microsimulation(dataset=str(dataset_path)) - n_records = len(_sim.calculate("household_id", map_to="household").values) - del _sim - - geography = assign_random_geography( - n_records=n_records, - n_clones=args.n_clones, - seed=args.seed, + from policyengine_us_data.calibration.publish_local_area import ( + AT_LARGE_DISTRICTS, + NYC_COUNTY_FIPS, + SUB_ENTITIES, ) - cds_to_calibrate = sorted(set(geography.cd_geoid.astype(str))) - geo_labels = cds_to_calibrate - print( - f"Generated geography: " - f"{geography.n_clones} clones x " - f"{geography.n_records} records", - file=sys.stderr, + from policyengine_us_data.calibration.calibration_utils import STATE_CODES + from policyengine_us_data.calibration.local_h5.source_dataset import ( + PolicyEngineDatasetReader, ) - # ── Validation setup (once per worker) ── - validation_targets = None - training_mask_full = None - constraints_map = None - if not args.no_validate: - from sqlalchemy import create_engine - from policyengine_us_data.calibration.validate_staging import ( - _query_all_active_targets, - _batch_stratum_constraints, - ) - from policyengine_us_data.calibration.unified_calibration import ( - load_target_config, - _match_rules, + takeup_filter = tuple(spec["variable"] for spec in SIMPLE_TAKEUP_VARS) + + original_stdout = sys.stdout + sys.stdout = sys.stderr + if calibration_package_path is not None: + calibration_package_path = require_calibration_package_path( + calibration_package_path ) - engine = create_engine(f"sqlite:///{db_path}") - validation_targets = _query_all_active_targets(engine, args.period) + 
validation_policy = ValidationPolicy(enabled=not args.no_validate) + validation_context = load_validation_context( + db_path=db_path, + period=args.period, + target_config_path=args.target_config, + validation_config_path=args.validation_config, + policy=validation_policy, + ) + if validation_context is not None: print( - f"Loaded {len(validation_targets)} validation targets", + f"Validation ready: {len(validation_context.validation_targets)} targets, " + f"{len(validation_context.constraints_map)} strata", file=sys.stderr, ) - # Apply exclude/include from validation config - if args.validation_config: - val_cfg = load_target_config(args.validation_config) - exc_rules = val_cfg.get("exclude", []) - if exc_rules: - exc_mask = _match_rules(validation_targets, exc_rules) - validation_targets = validation_targets[~exc_mask].reset_index( - drop=True - ) - inc_rules = val_cfg.get("include", []) - if inc_rules: - inc_mask = _match_rules(validation_targets, inc_rules) - validation_targets = validation_targets[inc_mask].reset_index(drop=True) - - # Compute training mask from training config - if args.target_config: - tr_cfg = load_target_config(args.target_config) - tr_inc = tr_cfg.get("include", []) - if tr_inc: - training_mask_full = np.asarray( - _match_rules(validation_targets, tr_inc), - dtype=bool, - ) - else: - training_mask_full = np.ones(len(validation_targets), dtype=bool) - else: - training_mask_full = np.ones(len(validation_targets), dtype=bool) - - # Batch-load constraints - stratum_ids = validation_targets["stratum_id"].unique().tolist() - constraints_map = _batch_stratum_constraints(engine, stratum_ids) + session = WorkerSession.load( + weights_path=weights_path, + dataset_path=dataset_path, + output_dir=output_dir, + calibration_package_path=calibration_package_path, + requested_n_clones=args.n_clones, + seed=args.seed, + takeup_filter=takeup_filter, + validation_policy=validation_policy, + validation_context=validation_context, + 
source_reader=PolicyEngineDatasetReader(tuple(SUB_ENTITIES)), + allow_seed_fallback=False, + ) + if session.requested_n_clones is not None and session.requested_n_clones != session.n_clones: print( - f"Validation ready: {len(validation_targets)} targets, " - f"{len(stratum_ids)} strata", + f"WARNING: weights imply {session.n_clones} clones " + f"but --n-clones={session.requested_n_clones}; using weights-derived value", file=sys.stderr, ) + print( + f"Loaded geography from {session.geography_source}: " + f"{session.geography.n_clones} clones x {session.geography.n_records} records", + file=sys.stderr, + ) + print( + f"Loaded source snapshot once for worker: " + f"{session.source_snapshot.n_households} households", + file=sys.stderr, + ) + for warning in session.geography_warnings: + print(f"WARNING: {warning}", file=sys.stderr) - results = { - "completed": [], - "failed": [], - "errors": [], - "validation_rows": [], - "validation_summary": {}, - } - - for item in work_items: - item_type = item["type"] - item_id = item["id"] - state_fips = None - candidate = None - cd_subset = None - - try: - if item_type == "state": - state_fips = None - for fips, code in STATE_CODES.items(): - if code == item_id: - state_fips = fips - break - if state_fips is None: - raise ValueError(f"Unknown state code: {item_id}") - cd_subset = [ - cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips - ] - if not cd_subset: - print( - f"No CDs for {item_id}, skipping", - file=sys.stderr, - ) - continue - states_dir = output_dir / "states" - states_dir.mkdir(parents=True, exist_ok=True) - path = build_h5( - weights=weights, - geography=geography, - dataset_path=dataset_path, - output_path=states_dir / f"{item_id}.h5", - cd_subset=cd_subset, - takeup_filter=takeup_filter, - ) - - elif item_type == "district": - state_code, dist_num = item_id.split("-") - state_fips = None - for fips, code in STATE_CODES.items(): - if code == state_code: - state_fips = fips - break - if state_fips is None: 
- raise ValueError(f"Unknown state in district: {item_id}") - - candidate = f"{state_fips}{int(dist_num):02d}" - if candidate in geo_labels: - geoid = candidate - else: - state_cds = [ - cd for cd in geo_labels if int(cd) // 100 == state_fips - ] - if len(state_cds) == 1: - geoid = state_cds[0] - else: - raise ValueError( - f"CD {candidate} not found and " - f"state {state_code} has " - f"{len(state_cds)} CDs" - ) - - cd_int = int(geoid) - district_num = cd_int % 100 - if district_num in AT_LARGE_DISTRICTS: - district_num = 1 - friendly_name = f"{state_code}-{district_num:02d}" - - districts_dir = output_dir / "districts" - districts_dir.mkdir(parents=True, exist_ok=True) - path = build_h5( - weights=weights, - geography=geography, - dataset_path=dataset_path, - output_path=districts_dir / f"{friendly_name}.h5", - cd_subset=[geoid], - takeup_filter=takeup_filter, - ) - - elif item_type == "city": - cities_dir = output_dir / "cities" - cities_dir.mkdir(parents=True, exist_ok=True) - path = build_h5( - weights=weights, - geography=geography, - dataset_path=dataset_path, - output_path=cities_dir / "NYC.h5", - county_fips_filter=NYC_COUNTY_FIPS, - takeup_filter=takeup_filter, - ) - - elif item_type == "national": - national_dir = output_dir / "national" - national_dir.mkdir(parents=True, exist_ok=True) - n_clones_from_weights = weights.shape[0] // n_records - if n_clones_from_weights != geography.n_clones: - print( - f"National weights have {n_clones_from_weights} clones " - f"but geography has {geography.n_clones}; " - f"regenerating geography", - file=sys.stderr, - ) - national_geo = assign_random_geography( - n_records=n_records, - n_clones=n_clones_from_weights, - seed=args.seed, - ) - else: - national_geo = geography - path = build_h5( - weights=weights, - geography=national_geo, - dataset_path=dataset_path, - output_path=national_dir / "US.h5", - ) - else: - raise ValueError(f"Unknown item type: {item_type}") - - if path: - 
results["completed"].append(f"{item_type}:{item_id}") - print( - f"Completed {item_type}:{item_id}", - file=sys.stderr, - ) - - # ── Per-item validation ── - if not args.no_validate and validation_targets is not None: - try: - v_rows = _validate_h5_subprocess( - h5_path=str(path), - item_type=item_type, - item_id=item_id, - state_fips=( - state_fips - if item_type in ("state", "district") - else None - ), - candidate=(candidate if item_type == "district" else None), - cd_subset=(cd_subset if item_type == "city" else None), - validation_targets=validation_targets, - training_mask_full=training_mask_full, - constraints_map=constraints_map, - db_path=str(db_path), - period=args.period, - ) - results["validation_rows"].extend(v_rows) - key = f"{item_type}:{item_id}" - n_fail = sum( - 1 for r in v_rows if r.get("sanity_check") == "FAIL" - ) - rae_vals = [ - r["rel_abs_error"] - for r in v_rows - if isinstance( - r.get("rel_abs_error"), - (int, float), - ) - and r["rel_abs_error"] != float("inf") - ] - mean_rae = sum(rae_vals) / len(rae_vals) if rae_vals else 0.0 - results["validation_summary"][key] = { - "n_targets": len(v_rows), - "n_sanity_fail": n_fail, - "mean_rel_abs_error": round(mean_rae, 4), - } - print( - f" Validated {key}: " - f"{len(v_rows)} targets, " - f"{n_fail} sanity fails, " - f"mean RAE={mean_rae:.4f}", - file=sys.stderr, - ) - except Exception as ve: - print( - f" Validation failed for {item_type}:{item_id}: {ve}", - file=sys.stderr, - ) - - except Exception as e: - results["failed"].append(f"{item_type}:{item_id}") - results["errors"].append( - { - "item": f"{item_type}:{item_id}", - "error": str(e), - "traceback": traceback.format_exc(), - } - ) - print( - f"FAILED {item_type}:{item_id}: {e}", - file=sys.stderr, - ) + if request_payloads is not None: + requests = tuple( + AreaBuildRequest.from_dict(payload) for payload in request_payloads + ) + initial_failures = () + else: + requests, initial_failures = build_requests_from_work_items( + 
work_items, + geography=session.geography, + state_codes=STATE_CODES, + at_large_districts=AT_LARGE_DISTRICTS, + nyc_county_fips=NYC_COUNTY_FIPS, + ) + service = LocalH5WorkerService() + worker_result = service.run( + session, + requests, + initial_failures=initial_failures, + ) sys.stdout = original_stdout - print(json.dumps(results)) + print(json.dumps(worker_result.to_dict())) if __name__ == "__main__": diff --git a/policyengine_us_data/calibration/local_h5/__init__.py b/policyengine_us_data/calibration/local_h5/__init__.py new file mode 100644 index 000000000..e285db655 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/__init__.py @@ -0,0 +1,114 @@ +"""Internal contracts and components for the local H5 refactor.""" + +from .contracts import ( + AreaBuildRequest, + AreaBuildResult, + AreaFilter, + BuildStatus, + FilterOp, + PublishingInputBundle, + ValidationIssue, + ValidationPolicy, + ValidationResult, + ValidationStatus, + WorkerResult, +) +from .area_catalog import USAreaCatalog, USCatalogEntry +from .builder import LocalAreaBuildArtifacts, LocalAreaDatasetBuilder +from .entity_graph import EntityGraph, EntityGraphExtractor +from .fingerprinting import ( + FingerprintComponents, + FingerprintInputs, + FingerprintRecord, + FingerprintService, +) +from .reindexing import EntityReindexer, ReindexedEntities +from .selection import AreaSelector, CloneSelection +from .partitioning import partition_weighted_work_items, work_item_key +from .package_geography import ( + CalibrationPackageGeographyLoader, + LoadedPackageGeography, + require_calibration_package_path, +) +from .source_dataset import ( + PolicyEngineDatasetReader, + PolicyEngineVariableArrayProvider, + SourceDatasetSnapshot, +) +from .us_augmentations import USAugmentationService +from .validation import ( + make_validation_error, + summarize_validation_rows, + tag_validation_errors, + validation_geo_level_for_area_type, +) +from .variables import H5Payload, VariableCloner, 
VariableExportPolicy +from .weights import CloneWeightMatrix, infer_clone_count_from_weight_length +from .worker_service import ( + LocalH5WorkerService, + ValidationContext, + WorkerSession, + build_request_from_work_item, + build_requests_from_work_items, + load_validation_context, + validate_in_subprocess, + validate_output_subprocess, + worker_result_to_legacy_dict, +) +from .writer import H5Writer + +__all__ = [ + "AreaBuildRequest", + "AreaBuildResult", + "AreaFilter", + "AreaSelector", + "BuildStatus", + "CalibrationPackageGeographyLoader", + "CloneSelection", + "CloneWeightMatrix", + "EntityGraph", + "EntityGraphExtractor", + "EntityReindexer", + "FingerprintComponents", + "FingerprintInputs", + "FingerprintRecord", + "FingerprintService", + "FilterOp", + "LoadedPackageGeography", + "H5Payload", + "H5Writer", + "LocalAreaBuildArtifacts", + "LocalAreaDatasetBuilder", + "PolicyEngineDatasetReader", + "PolicyEngineVariableArrayProvider", + "PublishingInputBundle", + "ReindexedEntities", + "SourceDatasetSnapshot", + "USAugmentationService", + "ValidationIssue", + "ValidationContext", + "ValidationPolicy", + "ValidationResult", + "ValidationStatus", + "WorkerResult", + "WorkerSession", + "USAreaCatalog", + "USCatalogEntry", + "LocalH5WorkerService", + "build_request_from_work_item", + "build_requests_from_work_items", + "infer_clone_count_from_weight_length", + "load_validation_context", + "make_validation_error", + "partition_weighted_work_items", + "require_calibration_package_path", + "summarize_validation_rows", + "tag_validation_errors", + "validate_in_subprocess", + "validate_output_subprocess", + "validation_geo_level_for_area_type", + "VariableCloner", + "VariableExportPolicy", + "worker_result_to_legacy_dict", + "work_item_key", +] diff --git a/policyengine_us_data/calibration/local_h5/area_catalog.py b/policyengine_us_data/calibration/local_h5/area_catalog.py new file mode 100644 index 000000000..9caf6471a --- /dev/null +++ 
b/policyengine_us_data/calibration/local_h5/area_catalog.py @@ -0,0 +1,255 @@ +"""Concrete US area request catalog for local H5 publishing.""" + +from __future__ import annotations + +from collections import Counter +from dataclasses import dataclass +from typing import Mapping, Sequence + +from policyengine_us_data.calibration.calibration_utils import ( + STATE_CODES, + get_all_cds_from_database, +) + +from .contracts import AreaBuildRequest, AreaFilter + + +AT_LARGE_DISTRICTS = {0, 98} +NYC_COUNTY_FIPS = {"36005", "36047", "36061", "36081", "36085"} +CITY_WEIGHTS = {"NYC": 11} + + +@dataclass(frozen=True) +class USCatalogEntry: + request: AreaBuildRequest + weight: int + + @property + def key(self) -> str: + return f"{self.request.area_type}:{self.request.area_id}" + + def to_partition_item(self) -> dict[str, object]: + return { + "type": self.request.area_type, + "id": self.request.area_id, + "weight": self.weight, + } + + +class USAreaCatalog: + """Build concrete US local-area requests from the calibration artifacts.""" + + def __init__( + self, + *, + state_codes: Mapping[int, str] = STATE_CODES, + at_large_districts: set[int] | None = None, + nyc_county_fips: set[str] | None = None, + city_weights: Mapping[str, int] | None = None, + ) -> None: + self.state_codes = dict(state_codes) + self.at_large_districts = set(at_large_districts or AT_LARGE_DISTRICTS) + self.nyc_county_fips = set(nyc_county_fips or NYC_COUNTY_FIPS) + self.city_weights = dict(city_weights or CITY_WEIGHTS) + + def load_regional_entries(self, db_uri: str) -> tuple[USCatalogEntry, ...]: + cds = tuple(str(cd) for cd in get_all_cds_from_database(db_uri)) + return self.regional_entries_from_cds(cds) + + def regional_entries_from_cds( + self, + cds: Sequence[str], + ) -> tuple[USCatalogEntry, ...]: + cds = tuple(str(cd) for cd in cds) + districts = tuple(self._district_friendly_name(cd) for cd in cds) + cds_per_state = Counter(district.split("-")[0] for district in districts) + states_with_cds = [ 
+ state_code + for state_code in self.state_codes.values() + if cds_per_state.get(state_code, 0) > 0 + ] + + entries: list[USCatalogEntry] = [] + for state_code in states_with_cds: + entries.append( + USCatalogEntry( + request=AreaBuildRequest( + area_type="state", + area_id=state_code, + display_name=state_code, + output_relative_path=f"states/{state_code}.h5", + ), + weight=cds_per_state.get(state_code, 1), + ) + ) + + for cd, friendly_name in zip(cds, districts): + entries.append( + USCatalogEntry( + request=AreaBuildRequest( + area_type="district", + area_id=friendly_name, + display_name=friendly_name, + output_relative_path=f"districts/{friendly_name}.h5", + ), + weight=1, + ) + ) + + entries.append( + USCatalogEntry( + request=AreaBuildRequest( + area_type="city", + area_id="NYC", + display_name="NYC", + output_relative_path="cities/NYC.h5", + ), + weight=self.city_weights.get("NYC", 3), + ) + ) + return tuple(entries) + + def national_entry(self) -> USCatalogEntry: + return USCatalogEntry( + request=AreaBuildRequest.national(), + weight=1, + ) + + def _district_friendly_name(self, cd_geoid: str) -> str: + cd_int = int(cd_geoid) + state_fips = cd_int // 100 + district_num = cd_int % 100 + if district_num in self.at_large_districts: + district_num = 1 + state_code = self.state_codes.get(state_fips, str(state_fips)) + return f"{state_code}-{district_num:02d}" + + def filters_for_request( + self, + request: AreaBuildRequest, + *, + cds: Sequence[str], + ) -> tuple[AreaFilter, ...]: + if request.area_type == "state": + state_code = request.area_id + state_fips = self._state_fips_for_code(state_code) + cd_subset = tuple(cd for cd in cds if int(cd) // 100 == state_fips) + return ( + AreaFilter( + geography_field="cd_geoid", + op="in", + value=cd_subset, + ), + ) + + if request.area_type == "district": + matching_cd = next( + cd for cd in cds if self._district_friendly_name(cd) == request.area_id + ) + return ( + AreaFilter( + geography_field="cd_geoid", + op="in", 
+ value=(matching_cd,), + ), + ) + + if request.area_type == "city": + return ( + AreaFilter( + geography_field="county_fips", + op="in", + value=tuple(sorted(self.nyc_county_fips)), + ), + ) + + return () + + def with_filters( + self, + entry: USCatalogEntry, + *, + cds: Sequence[str], + geography=None, + ) -> USCatalogEntry: + request = entry.request + filters = self.filters_for_request(request, cds=cds) + validation_geo_level = None + validation_geographic_ids: tuple[str, ...] = () + + if request.area_type == "state": + validation_geo_level = "state" + validation_geographic_ids = (str(self._state_fips_for_code(request.area_id)),) + elif request.area_type == "district": + validation_geo_level = "district" + validation_geographic_ids = tuple(item.value[0] for item in filters) + elif request.area_type == "city": + validation_geo_level = "district" + validation_geographic_ids = self._city_validation_cd_geoids( + cds=cds, + geography=geography, + ) + elif request.area_type == "national": + validation_geo_level = "national" + validation_geographic_ids = ("US",) + + return USCatalogEntry( + request=AreaBuildRequest( + area_type=request.area_type, + area_id=request.area_id, + display_name=request.display_name, + output_relative_path=request.output_relative_path, + filters=filters, + validation_geo_level=validation_geo_level, + validation_geographic_ids=validation_geographic_ids, + metadata=dict(request.metadata), + ), + weight=entry.weight, + ) + + def resolved_regional_entries( + self, + db_uri: str, + *, + geography=None, + ) -> tuple[USCatalogEntry, ...]: + cds = tuple(str(cd) for cd in get_all_cds_from_database(db_uri)) + return tuple( + self.with_filters(entry, cds=cds, geography=geography) + for entry in self.regional_entries_from_cds(cds) + ) + + def resolved_national_entry(self) -> USCatalogEntry: + return self.with_filters(self.national_entry(), cds=(), geography=None) + + def _state_fips_for_code(self, state_code: str) -> int: + for fips, code in 
self.state_codes.items(): + if code == state_code: + return int(fips) + raise ValueError(f"Unknown state code: {state_code}") + + def _city_validation_cd_geoids( + self, + *, + cds: Sequence[str], + geography, + ) -> tuple[str, ...]: + if geography is None: + return () + + county_fips = getattr(geography, "county_fips", None) + cd_geoids = getattr(geography, "cd_geoid", None) + if county_fips is None or cd_geoids is None: + return () + + available_cds = set(str(cd) for cd in cds) + return tuple( + sorted( + { + str(cd) + for cd, county in zip(cd_geoids, county_fips) + if str(county) in self.nyc_county_fips + and str(cd) in available_cds + } + ) + ) diff --git a/policyengine_us_data/calibration/local_h5/builder.py b/policyengine_us_data/calibration/local_h5/builder.py new file mode 100644 index 000000000..f199264a1 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/builder.py @@ -0,0 +1,157 @@ +"""One-area orchestration for local H5 publishing.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence + +import numpy as np + +from .contracts import AreaFilter +from .reindexing import EntityReindexer, ReindexedEntities +from .selection import AreaSelector, CloneSelection +from .source_dataset import SourceDatasetSnapshot +from .us_augmentations import USAugmentationService +from .variables import H5Payload, VariableCloner, VariableExportPolicy +from .weights import CloneWeightMatrix + + +@dataclass(frozen=True) +class LocalAreaBuildArtifacts: + payload: H5Payload + selection: CloneSelection + reindexed: ReindexedEntities + time_period: int | str + + +class LocalAreaDatasetBuilder: + """Compose the pure local-H5 build steps for one output area.""" + + def __init__( + self, + *, + selector: AreaSelector | None = None, + reindexer: EntityReindexer | None = None, + variable_cloner: VariableCloner | None = None, + us_augmentations: USAugmentationService | None = None, + export_policy: VariableExportPolicy | None 
= None, + ) -> None: + self.selector = selector or AreaSelector() + self.reindexer = reindexer or EntityReindexer() + self.variable_cloner = variable_cloner or VariableCloner() + self.us_augmentations = us_augmentations or USAugmentationService() + self.export_policy = export_policy or VariableExportPolicy( + include_input_variables=True + ) + + def build( + self, + *, + weights: np.ndarray, + geography, + source: SourceDatasetSnapshot, + filters: tuple[AreaFilter, ...] = (), + takeup_filter: Sequence[str] | None = None, + ) -> LocalAreaBuildArtifacts: + weight_matrix = CloneWeightMatrix.from_vector(weights, source.n_households) + selection = self.selector.select( + weight_matrix, + geography, + filters=filters, + ) + self._validate_selection(selection=selection, filters=filters) + + reindexed = self.reindexer.reindex(source, selection) + time_period = source.time_period + cloned = self.variable_cloner.clone( + source, + reindexed, + self.export_policy, + ) + + data = { + variable: dict(periods) + for variable, periods in cloned.variables.items() + } + self._inject_entity_ids( + data=data, + time_period=time_period, + reindexed=reindexed, + ) + self._inject_household_weights( + data=data, + time_period=time_period, + active_weights=selection.active_weights, + ) + + self.us_augmentations.apply_all( + data, + time_period=time_period, + selection=selection, + source=source, + reindexed=reindexed, + takeup_filter=takeup_filter, + ) + + return LocalAreaBuildArtifacts( + payload=H5Payload(variables=data), + selection=selection, + reindexed=reindexed, + time_period=time_period, + ) + + def build_payload(self, **kwargs) -> H5Payload: + return self.build(**kwargs).payload + + def _inject_entity_ids( + self, + *, + data: dict[str, dict[int | str, np.ndarray]], + time_period: int | str, + reindexed: ReindexedEntities, + ) -> None: + data["household_id"] = {time_period: reindexed.new_household_ids} + data["person_id"] = {time_period: reindexed.new_person_ids} + 
data["person_household_id"] = { + time_period: reindexed.new_person_household_ids, + } + for entity_key, entity_ids in reindexed.new_entity_ids.items(): + data[f"{entity_key}_id"] = {time_period: entity_ids} + data[f"person_{entity_key}_id"] = { + time_period: reindexed.new_person_entity_ids[entity_key], + } + + def _inject_household_weights( + self, + *, + data: dict[str, dict[int | str, np.ndarray]], + time_period: int | str, + active_weights: np.ndarray, + ) -> None: + data["household_weight"] = { + time_period: active_weights.astype(np.float32), + } + + def _validate_selection( + self, + *, + selection: CloneSelection, + filters: tuple[AreaFilter, ...], + ) -> None: + if selection.is_empty: + raise ValueError( + "No active clones after filtering. " + f"filters={self._format_filters(filters)}" + ) + + empty_count = int(np.sum(selection.active_block_geoids == "")) + if empty_count > 0: + raise ValueError(f"{empty_count} active clones have empty block GEOIDs") + + def _format_filters(self, filters: tuple[AreaFilter, ...]) -> str: + if not filters: + return "()" + return ", ".join( + f"{area_filter.geography_field} {area_filter.op} {area_filter.value}" + for area_filter in filters + ) diff --git a/policyengine_us_data/calibration/local_h5/contracts.py b/policyengine_us_data/calibration/local_h5/contracts.py new file mode 100644 index 000000000..26637b6e8 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/contracts.py @@ -0,0 +1,345 @@ +"""Core value contracts for the local H5 refactor. + +These contracts intentionally avoid any PolicyEngine, Modal, or H5 IO. +They define the shapes that later services will exchange. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal, Mapping + +AreaType = Literal["national", "state", "district", "city", "custom"] +BuildStatus = Literal["completed", "failed"] +ValidationStatus = Literal["not_run", "passed", "failed", "error"] +FilterOp = Literal["eq", "in"] + + +def _jsonable(value: Any) -> Any: + """Convert common contract values into JSON-serializable primitives.""" + + if isinstance(value, Path): + return str(value) + if isinstance(value, tuple): + return [_jsonable(item) for item in value] + if isinstance(value, list): + return [_jsonable(item) for item in value] + if isinstance(value, Mapping): + return {str(key): _jsonable(item) for key, item in value.items()} + if hasattr(value, "to_dict") and callable(value.to_dict): + return value.to_dict() + return value + + +@dataclass(frozen=True) +class AreaFilter: + geography_field: str + op: FilterOp + value: str | int | tuple[str | int, ...] + + def __post_init__(self) -> None: + if not self.geography_field: + raise ValueError("geography_field must be non-empty") + if self.op == "in" and not isinstance(self.value, tuple): + raise ValueError("AreaFilter value must be a tuple when op='in'") + if self.op == "eq" and isinstance(self.value, tuple): + raise ValueError("AreaFilter value must not be a tuple when op='eq'") + + def to_dict(self) -> dict[str, Any]: + return { + "geography_field": self.geography_field, + "op": self.op, + "value": _jsonable(self.value), + } + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "AreaFilter": + value = data["value"] + if data["op"] == "in": + value = tuple(value) + return cls( + geography_field=str(data["geography_field"]), + op=data["op"], + value=value, + ) + + +@dataclass(frozen=True) +class AreaBuildRequest: + area_type: AreaType + area_id: str + display_name: str + output_relative_path: str + filters: tuple[AreaFilter, ...] 
= () + validation_geo_level: str | None = None + validation_geographic_ids: tuple[str, ...] = () + metadata: Mapping[str, str] = field(default_factory=dict) + + def __post_init__(self) -> None: + if not self.area_id: + raise ValueError("area_id must be non-empty") + if not self.display_name: + raise ValueError("display_name must be non-empty") + if not self.output_relative_path: + raise ValueError("output_relative_path must be non-empty") + if self.validation_geographic_ids and self.validation_geo_level is None: + raise ValueError( + "validation_geo_level must be set when validation_geographic_ids " + "are provided" + ) + + @classmethod + def national(cls, area_id: str = "US") -> "AreaBuildRequest": + return cls( + area_type="national", + area_id=area_id, + display_name=area_id, + output_relative_path="national/US.h5", + validation_geo_level="national", + validation_geographic_ids=(area_id,), + ) + + def to_dict(self) -> dict[str, Any]: + return { + "area_type": self.area_type, + "area_id": self.area_id, + "display_name": self.display_name, + "output_relative_path": self.output_relative_path, + "filters": [_jsonable(item) for item in self.filters], + "validation_geo_level": self.validation_geo_level, + "validation_geographic_ids": list(self.validation_geographic_ids), + "metadata": dict(self.metadata), + } + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "AreaBuildRequest": + return cls( + area_type=data["area_type"], + area_id=str(data["area_id"]), + display_name=str(data["display_name"]), + output_relative_path=str(data["output_relative_path"]), + filters=tuple( + AreaFilter.from_dict(item) + for item in data.get("filters", ()) + ), + validation_geo_level=data.get("validation_geo_level"), + validation_geographic_ids=tuple( + str(item) for item in data.get("validation_geographic_ids", ()) + ), + metadata=dict(data.get("metadata", {})), + ) + + +@dataclass(frozen=True) +class PublishingInputBundle: + weights_path: Path + source_dataset_path: Path + 
target_db_path: Path | None + calibration_package_path: Path | None + run_config_path: Path | None + run_id: str + version: str + n_clones: int | None + seed: int + + def __post_init__(self) -> None: + if not self.run_id: + raise ValueError("run_id must be non-empty") + if not self.version: + raise ValueError("version must be non-empty") + if self.n_clones is not None and self.n_clones <= 0: + raise ValueError("n_clones must be positive when provided") + + def required_paths(self) -> tuple[Path, ...]: + required = [self.weights_path, self.source_dataset_path] + if self.target_db_path is not None: + required.append(self.target_db_path) + if self.calibration_package_path is not None: + required.append(self.calibration_package_path) + return tuple(required) + + def to_dict(self) -> dict[str, Any]: + return { + "weights_path": str(self.weights_path), + "source_dataset_path": str(self.source_dataset_path), + "target_db_path": _jsonable(self.target_db_path), + "calibration_package_path": _jsonable(self.calibration_package_path), + "run_config_path": _jsonable(self.run_config_path), + "run_id": self.run_id, + "version": self.version, + "n_clones": self.n_clones, + "seed": self.seed, + } + + +@dataclass(frozen=True) +class ValidationPolicy: + """Validation controls for H5 worker execution. + + Only `enabled` is enforced today. The finer-grained failure and + sub-check flags are intentionally present as forward-compatible + contract fields, but they are not yet fully wired through the + validator implementations. 
+ """ + + enabled: bool = True + fail_on_exception: bool = False + fail_on_validation_failure: bool = False + run_sanity_checks: bool = True + run_target_validation: bool = True + run_national_validation: bool = True + + def to_dict(self) -> dict[str, Any]: + return { + "enabled": self.enabled, + "fail_on_exception": self.fail_on_exception, + "fail_on_validation_failure": self.fail_on_validation_failure, + "run_sanity_checks": self.run_sanity_checks, + "run_target_validation": self.run_target_validation, + "run_national_validation": self.run_national_validation, + } + + +@dataclass(frozen=True) +class ValidationIssue: + code: str + message: str + severity: Literal["warning", "error"] + details: Mapping[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + if not self.code: + raise ValueError("code must be non-empty") + if not self.message: + raise ValueError("message must be non-empty") + + def to_dict(self) -> dict[str, Any]: + return { + "code": self.code, + "message": self.message, + "severity": self.severity, + "details": _jsonable(self.details), + } + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "ValidationIssue": + return cls( + code=str(data["code"]), + message=str(data["message"]), + severity=data["severity"], + details=dict(data.get("details", {})), + ) + + +@dataclass(frozen=True) +class ValidationResult: + status: ValidationStatus + rows: tuple[Mapping[str, Any], ...] = () + issues: tuple[ValidationIssue, ...] 
= () + summary: Mapping[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return { + "status": self.status, + "rows": [_jsonable(item) for item in self.rows], + "issues": [_jsonable(item) for item in self.issues], + "summary": _jsonable(self.summary), + } + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "ValidationResult": + return cls( + status=data["status"], + rows=tuple(dict(item) for item in data.get("rows", ())), + issues=tuple( + ValidationIssue.from_dict(item) + for item in data.get("issues", ()) + ), + summary=dict(data.get("summary", {})), + ) + + +@dataclass(frozen=True) +class AreaBuildResult: + request: AreaBuildRequest + build_status: BuildStatus + output_path: Path | None = None + build_error: str | None = None + validation: ValidationResult = field( + default_factory=lambda: ValidationResult(status="not_run") + ) + + def __post_init__(self) -> None: + if self.build_status == "completed": + if self.output_path is None: + raise ValueError("completed AreaBuildResult requires output_path") + if self.build_error is not None: + raise ValueError( + "completed AreaBuildResult must not include build_error" + ) + else: + if not self.build_error: + raise ValueError("failed AreaBuildResult requires build_error") + + def to_dict(self) -> dict[str, Any]: + return { + "request": self.request.to_dict(), + "build_status": self.build_status, + "output_path": _jsonable(self.output_path), + "build_error": self.build_error, + "validation": self.validation.to_dict(), + } + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "AreaBuildResult": + output_path = data.get("output_path") + return cls( + request=AreaBuildRequest.from_dict(data["request"]), + build_status=data["build_status"], + output_path=Path(output_path) if output_path is not None else None, + build_error=data.get("build_error"), + validation=ValidationResult.from_dict( + data.get("validation", {"status": "not_run"}) + ), + ) + + 
+@dataclass(frozen=True) +class WorkerResult: + completed: tuple[AreaBuildResult, ...] + failed: tuple[AreaBuildResult, ...] + worker_issues: tuple[ValidationIssue, ...] = () + + def __post_init__(self) -> None: + if any(item.build_status != "completed" for item in self.completed): + raise ValueError("all results in completed must have build_status='completed'") + if any(item.build_status != "failed" for item in self.failed): + raise ValueError("all results in failed must have build_status='failed'") + + def all_results(self) -> tuple[AreaBuildResult, ...]: + return self.completed + self.failed + + def to_dict(self) -> dict[str, Any]: + return { + "completed": [_jsonable(item) for item in self.completed], + "failed": [_jsonable(item) for item in self.failed], + "worker_issues": [_jsonable(item) for item in self.worker_issues], + } + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "WorkerResult": + return cls( + completed=tuple( + AreaBuildResult.from_dict(item) + for item in data.get("completed", ()) + ), + failed=tuple( + AreaBuildResult.from_dict(item) + for item in data.get("failed", ()) + ), + worker_issues=tuple( + ValidationIssue.from_dict(item) + for item in data.get("worker_issues", ()) + ), + ) diff --git a/policyengine_us_data/calibration/local_h5/entity_graph.py b/policyengine_us_data/calibration/local_h5/entity_graph.py new file mode 100644 index 000000000..2fbe5fc84 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/entity_graph.py @@ -0,0 +1,118 @@ +"""Source entity-relationship extraction for local H5 publishing.""" + +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Mapping, Sequence + +import numpy as np + + +@dataclass(frozen=True) +class EntityGraph: + """Static source entity relationships derived from the base dataset.""" + + household_ids: np.ndarray + person_household_ids: np.ndarray + hh_id_to_index: Mapping[int, int] + 
hh_to_persons: Mapping[int, tuple[int, ...]] + entity_id_arrays: Mapping[str, np.ndarray] + person_entity_id_arrays: Mapping[str, np.ndarray] + hh_to_entity: Mapping[str, Mapping[int, tuple[int, ...]]] + + +class EntityGraphExtractor: + """Build source entity-relationship maps from source arrays.""" + + def __init__(self, sub_entities: Sequence[str]): + self.sub_entities = tuple(sub_entities) + + def extract(self, simulation: Any, household_ids: np.ndarray) -> EntityGraph: + person_household_ids = np.asarray( + simulation.calculate("household_id", map_to="person").values + ) + entity_id_arrays = { + entity_key: np.asarray( + simulation.calculate(f"{entity_key}_id", map_to=entity_key).values + ) + for entity_key in self.sub_entities + } + person_entity_id_arrays = { + entity_key: np.asarray( + simulation.calculate( + f"person_{entity_key}_id", + map_to="person", + ).values + ) + for entity_key in self.sub_entities + } + return self.extract_from_arrays( + household_ids=np.asarray(household_ids), + person_household_ids=person_household_ids, + entity_id_arrays=entity_id_arrays, + person_entity_id_arrays=person_entity_id_arrays, + ) + + def extract_from_arrays( + self, + *, + household_ids: np.ndarray, + person_household_ids: np.ndarray, + entity_id_arrays: Mapping[str, np.ndarray], + person_entity_id_arrays: Mapping[str, np.ndarray], + ) -> EntityGraph: + household_ids = np.asarray(household_ids) + person_household_ids = np.asarray(person_household_ids) + + hh_id_to_index = {int(hid): idx for idx, hid in enumerate(household_ids)} + + hh_to_persons_lists: dict[int, list[int]] = defaultdict(list) + for person_idx, household_id in enumerate(person_household_ids): + hh_to_persons_lists[hh_id_to_index[int(household_id)]].append(person_idx) + hh_to_persons = { + hh_idx: tuple(person_indices) + for hh_idx, person_indices in hh_to_persons_lists.items() + } + + hh_to_entity: dict[str, dict[int, tuple[int, ...]]] = {} + normalized_entity_id_arrays = { + entity_key: 
np.asarray(entity_values) + for entity_key, entity_values in entity_id_arrays.items() + } + normalized_person_entity_id_arrays = { + entity_key: np.asarray(entity_values) + for entity_key, entity_values in person_entity_id_arrays.items() + } + + for entity_key in self.sub_entities: + entity_ids = normalized_entity_id_arrays[entity_key] + person_entity_ids = normalized_person_entity_id_arrays[entity_key] + entity_id_to_index = { + int(entity_id): entity_idx + for entity_idx, entity_id in enumerate(entity_ids) + } + + mapping_lists: dict[int, list[int]] = defaultdict(list) + seen: dict[int, set[int]] = defaultdict(set) + for person_idx, household_id in enumerate(person_household_ids): + hh_idx = hh_id_to_index[int(household_id)] + entity_idx = entity_id_to_index[int(person_entity_ids[person_idx])] + if entity_idx not in seen[hh_idx]: + seen[hh_idx].add(entity_idx) + mapping_lists[hh_idx].append(entity_idx) + + hh_to_entity[entity_key] = { + hh_idx: tuple(sorted(entity_indices)) + for hh_idx, entity_indices in mapping_lists.items() + } + + return EntityGraph( + household_ids=household_ids, + person_household_ids=person_household_ids, + hh_id_to_index=hh_id_to_index, + hh_to_persons=hh_to_persons, + entity_id_arrays=normalized_entity_id_arrays, + person_entity_id_arrays=normalized_person_entity_id_arrays, + hh_to_entity=hh_to_entity, + ) diff --git a/policyengine_us_data/calibration/local_h5/fingerprinting.py b/policyengine_us_data/calibration/local_h5/fingerprinting.py new file mode 100644 index 000000000..531c60977 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/fingerprinting.py @@ -0,0 +1,298 @@ +"""Semantic fingerprinting for local H5 publish inputs.""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Mapping + +import numpy as np + +from policyengine_us_data.calibration.local_h5.package_geography import ( + 
CalibrationPackageGeographyLoader, + require_calibration_package_path, +) + + +def _require_file(path: str | Path, *, label: str) -> Path: + path = Path(path) + if not path.is_file(): + raise FileNotFoundError(f"Required {label} file not found at {path}") + return path + + +@dataclass(frozen=True) +class FingerprintInputs: + weights_path: Path + dataset_path: Path + calibration_package_path: Path + n_clones: int + seed: int + weights_length: int + n_records: int + + def to_dict(self) -> dict[str, str | int]: + return { + "weights_path": str(self.weights_path), + "dataset_path": str(self.dataset_path), + "calibration_package_path": str(self.calibration_package_path), + "n_clones": self.n_clones, + "seed": self.seed, + "weights_length": self.weights_length, + "n_records": self.n_records, + } + + +@dataclass(frozen=True) +class FingerprintComponents: + weights_sha256: str + dataset_sha256: str + geography_sha256: str + n_clones: int + seed: int + + def to_dict(self) -> dict[str, str | int]: + return { + "weights_sha256": self.weights_sha256, + "dataset_sha256": self.dataset_sha256, + "geography_sha256": self.geography_sha256, + "n_clones": self.n_clones, + "seed": self.seed, + } + + @classmethod + def from_dict(cls, payload: Mapping[str, Any]) -> "FingerprintComponents": + return cls( + weights_sha256=str(payload["weights_sha256"]), + dataset_sha256=str(payload["dataset_sha256"]), + geography_sha256=str(payload["geography_sha256"]), + n_clones=int(payload["n_clones"]), + seed=int(payload["seed"]), + ) + + +@dataclass(frozen=True) +class FingerprintRecord: + schema_version: str + algorithm: str + digest: str + components: FingerprintComponents | None = None + inputs: Mapping[str, str | int] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "fingerprint": self.digest, + "digest": self.digest, + "schema_version": self.schema_version, + "algorithm": self.algorithm, + } + if self.components is not None: + 
payload["components"] = self.components.to_dict() + if self.inputs: + payload["inputs"] = dict(self.inputs) + return payload + + +class FingerprintService: + """Single authoritative definition of local H5 publish identity.""" + + SCHEMA_VERSION = "local_h5_publish_v1" + ALGORITHM = "sha256-truncated-16" + DIGEST_HEX_CHARS = 16 + _GEOGRAPHY_FIELDS = ( + "block_geoid", + "cd_geoid", + "county_fips", + "state_fips", + "n_records", + "n_clones", + ) + + def __init__(self, loader: CalibrationPackageGeographyLoader | None = None): + self.loader = loader or CalibrationPackageGeographyLoader() + + def build_inputs( + self, + *, + weights_path: str | Path, + dataset_path: str | Path, + calibration_package_path: str | Path, + n_clones: int, + seed: int, + ) -> FingerprintInputs: + if n_clones <= 0: + raise ValueError("n_clones must be positive") + + weights_path = _require_file(weights_path, label="weights") + weights_length = self._weight_length(weights_path) + if weights_length % n_clones != 0: + raise ValueError( + "Weight vector length " + f"{weights_length} is not divisible by n_clones={n_clones}" + ) + n_records = weights_length // n_clones + + return FingerprintInputs( + weights_path=weights_path, + dataset_path=_require_file(dataset_path, label="dataset"), + calibration_package_path=require_calibration_package_path( + calibration_package_path + ), + n_clones=int(n_clones), + seed=int(seed), + weights_length=int(weights_length), + n_records=int(n_records), + ) + + def create_publish_fingerprint( + self, + *, + weights_path: str | Path, + dataset_path: str | Path, + calibration_package_path: str | Path, + n_clones: int, + seed: int, + ) -> FingerprintRecord: + inputs = self.build_inputs( + weights_path=weights_path, + dataset_path=dataset_path, + calibration_package_path=calibration_package_path, + n_clones=n_clones, + seed=seed, + ) + return self.create_from_inputs(inputs) + + def create_from_inputs(self, inputs: FingerprintInputs) -> FingerprintRecord: + resolved = 
self.loader.resolve_for_weights( + package_path=inputs.calibration_package_path, + weights_length=inputs.weights_length, + n_records=inputs.n_records, + n_clones=inputs.n_clones, + seed=inputs.seed, + allow_seed_fallback=False, + ) + # Compatibility note: v1 still includes n_clones and seed in the + # digest so existing staged run directories keep their current + # resume semantics. Long-term, package-backed publishing should not + # need either field in the fingerprint because geography_sha256 + # already encodes the exact clone layout and the upstream random + # outcome that actually drives H5 output. + components = FingerprintComponents( + weights_sha256=self._sha256_file(inputs.weights_path), + dataset_sha256=self._sha256_file(inputs.dataset_path), + geography_sha256=self._sha256_geography_payload( + self.loader.serialize_geography(resolved.geography) + ), + n_clones=inputs.n_clones, + seed=inputs.seed, + ) + digest = self._compute_digest(components) + return FingerprintRecord( + schema_version=self.SCHEMA_VERSION, + algorithm=self.ALGORITHM, + digest=digest, + components=components, + inputs=inputs.to_dict(), + ) + + def serialize(self, record: FingerprintRecord) -> dict[str, Any]: + return record.to_dict() + + def deserialize(self, payload: Mapping[str, Any]) -> FingerprintRecord: + digest = payload.get("digest") or payload.get("fingerprint") + if not digest: + raise ValueError("Fingerprint payload is missing digest/fingerprint") + + components_payload = payload.get("components") + components = None + if isinstance(components_payload, Mapping): + components = FingerprintComponents.from_dict(components_payload) + + inputs_payload = payload.get("inputs") + inputs: Mapping[str, str | int] + if isinstance(inputs_payload, Mapping): + inputs = { + str(key): value for key, value in inputs_payload.items() + } + else: + inputs = {} + + return FingerprintRecord( + schema_version=str(payload.get("schema_version", "legacy")), + algorithm=str(payload.get("algorithm", 
self.ALGORITHM)), + digest=str(digest), + components=components, + inputs=inputs, + ) + + def write_record(self, path: str | Path, record: FingerprintRecord) -> None: + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + json.dump(self.serialize(record), f, indent=2, sort_keys=True) + + def read_record(self, path: str | Path) -> FingerprintRecord: + with open(path) as f: + payload = json.load(f) + return self.deserialize(payload) + + def matches( + self, + stored: FingerprintRecord, + current: FingerprintRecord, + ) -> bool: + return stored.digest == current.digest + + def legacy_record(self, digest: str) -> FingerprintRecord: + return FingerprintRecord( + schema_version="legacy", + algorithm=self.ALGORITHM, + digest=str(digest), + ) + + def _compute_digest(self, components: FingerprintComponents) -> str: + payload = { + "schema_version": self.SCHEMA_VERSION, + **components.to_dict(), + } + raw = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode( + "utf-8" + ) + return hashlib.sha256(raw).hexdigest()[: self.DIGEST_HEX_CHARS] + + def _sha256_file(self, path: Path) -> str: + h = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(1 << 20), b""): + h.update(chunk) + return h.hexdigest() + + def _weight_length(self, path: Path) -> int: + values = np.load(path, mmap_mode="r") + return int(np.asarray(values).size) + + def _sha256_geography_payload(self, payload: Mapping[str, Any]) -> str: + h = hashlib.sha256() + for field in self._GEOGRAPHY_FIELDS: + h.update(field.encode("utf-8")) + h.update(b"\0") + value = payload[field] + if field in ("n_records", "n_clones"): + h.update(str(int(value)).encode("utf-8")) + h.update(b"\0") + continue + + arr = np.asarray(value) + h.update(str(arr.shape).encode("utf-8")) + h.update(b"\0") + if arr.dtype.kind in "iufb": + h.update(str(arr.dtype).encode("utf-8")) + h.update(b"\0") + h.update(np.ascontiguousarray(arr).tobytes()) + else: + for item 
in arr.reshape(-1): + h.update(str(item).encode("utf-8")) + h.update(b"\0") + return h.hexdigest() diff --git a/policyengine_us_data/calibration/local_h5/package_geography.py b/policyengine_us_data/calibration/local_h5/package_geography.py new file mode 100644 index 000000000..c8c32c287 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/package_geography.py @@ -0,0 +1,345 @@ +"""Load and serialize calibration-package geography for H5 publishing.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Mapping + +import numpy as np + + +def require_calibration_package_path(package_path: str | Path) -> Path: + """Require that a calibration package file exists.""" + + package_path = Path(package_path) + if not package_path.is_file(): + raise FileNotFoundError( + "Required calibration package not found at " + f"{package_path}. H5 publishing now requires the exact " + "calibration_package.pkl so geography is not regenerated from seed." + ) + return package_path + + +@dataclass(frozen=True) +class LoadedPackageGeography: + """Resolved geography plus provenance for publisher logging/tests.""" + + geography: Any + source: str + warnings: tuple[str, ...] = () + + +class CalibrationPackageGeographyLoader: + """Read exact geography assignments from calibration packages. + + This loader prefers the newer serialized ``geography`` payload, + but can still reconstruct a ``GeographyAssignment`` from the older + top-level ``block_geoid``/``cd_geoid`` arrays. 
+ """ + + def serialize_geography(self, geography: Any) -> dict[str, Any]: + return { + "block_geoid": np.asarray(geography.block_geoid, dtype=str), + "cd_geoid": np.asarray(geography.cd_geoid, dtype=str), + "county_fips": np.asarray(geography.county_fips, dtype=str), + "state_fips": np.asarray(geography.state_fips, dtype=np.int64), + "n_records": int(geography.n_records), + "n_clones": int(geography.n_clones), + } + + def load( + self, + package_path: str | Path, + *, + fallback_n_records: int | None = None, + fallback_n_clones: int | None = None, + ) -> LoadedPackageGeography | None: + import pickle + + with open(package_path, "rb") as f: + package = pickle.load(f) + return self.load_from_package_dict( + package, + fallback_n_records=fallback_n_records, + fallback_n_clones=fallback_n_clones, + ) + + def load_from_package_dict( + self, + package: Mapping[str, Any], + *, + fallback_n_records: int | None = None, + fallback_n_clones: int | None = None, + ) -> LoadedPackageGeography | None: + payload = package.get("geography") + if isinstance(payload, Mapping): + geography = self._build_from_serialized_payload( + payload, + fallback_n_records=fallback_n_records, + fallback_n_clones=fallback_n_clones, + ) + return LoadedPackageGeography( + geography=geography, + source="serialized_package", + ) + + if package.get("block_geoid") is not None and package.get("cd_geoid") is not None: + geography = self._build_from_legacy_package( + package, + fallback_n_records=fallback_n_records, + fallback_n_clones=fallback_n_clones, + ) + return LoadedPackageGeography( + geography=geography, + source="legacy_package", + warnings=( + "Calibration package does not include serialized geography; " + "reconstructed from legacy arrays.", + ), + ) + + return None + + def resolve_for_weights( + self, + *, + package_path: str | Path | None, + weights_length: int, + n_records: int, + n_clones: int, + seed: int, + allow_seed_fallback: bool = True, + ) -> LoadedPackageGeography: + warnings: 
list[str] = [] + + if package_path: + package_path = Path(package_path) + if package_path.exists(): + load_error = None + try: + loaded = self.load( + package_path, + fallback_n_records=n_records, + fallback_n_clones=n_clones, + ) + except ValueError as error: + loaded = None + load_error = str(error) + if loaded is not None: + actual_len = len(np.asarray(loaded.geography.block_geoid)) + actual_records = int(getattr(loaded.geography, "n_records", -1)) + actual_clones = int(getattr(loaded.geography, "n_clones", -1)) + if ( + actual_len == weights_length + and actual_records == n_records + and actual_clones == n_clones + ): + if loaded.warnings: + warnings.extend(loaded.warnings) + return LoadedPackageGeography( + geography=loaded.geography, + source=loaded.source, + warnings=tuple(warnings), + ) + dimension_error = ( + "Calibration package geography is incompatible with " + "the requested publish shape: " + f"length={actual_len} records={actual_records} " + f"clones={actual_clones}, expected length={weights_length} " + f"records={n_records} clones={n_clones}." + ) + if not allow_seed_fallback: + raise ValueError(dimension_error) + warnings.append(f"{dimension_error} Regenerating from seed.") + else: + if load_error is not None: + if not allow_seed_fallback: + raise ValueError( + "Calibration package geography could not be loaded " + f"({load_error})." + ) + warnings.append( + "Calibration package geography could not be loaded " + f"({load_error}); regenerating from seed." + ) + else: + if not allow_seed_fallback: + raise ValueError( + "Calibration package does not include usable geography." + ) + warnings.append( + "Calibration package does not include usable geography; " + "regenerating from seed." + ) + else: + if not allow_seed_fallback: + raise FileNotFoundError( + f"Calibration package not found at {package_path}." + ) + warnings.append( + f"Calibration package not found at {package_path}; regenerating from seed." 
+ ) + + elif not allow_seed_fallback: + raise ValueError( + "Calibration package path is required for strict geography resolution." + ) + + return LoadedPackageGeography( + geography=self._generate_geography( + n_records=n_records, + n_clones=n_clones, + seed=seed, + ), + source="generated", + warnings=tuple(warnings), + ) + + def _build_from_serialized_payload( + self, + payload: Mapping[str, Any], + *, + fallback_n_records: int | None, + fallback_n_clones: int | None, + ) -> Any: + blocks = self._string_array(payload["block_geoid"]) + cds = self._string_array(payload["cd_geoid"]) + n_records, n_clones = self._infer_dimensions( + total_length=len(blocks), + n_records=payload.get("n_records"), + n_clones=payload.get("n_clones"), + fallback_n_records=fallback_n_records, + fallback_n_clones=fallback_n_clones, + ) + county_fips = payload.get("county_fips") + if county_fips is None: + county_fips = self._derive_county_fips(blocks) + else: + county_fips = self._string_array(county_fips) + state_fips = payload.get("state_fips") + if state_fips is None: + state_fips = self._derive_state_fips(blocks) + else: + state_fips = np.asarray(state_fips, dtype=np.int64) + return self._build_assignment( + block_geoid=blocks, + cd_geoid=cds, + county_fips=county_fips, + state_fips=state_fips, + n_records=n_records, + n_clones=n_clones, + ) + + def _build_from_legacy_package( + self, + package: Mapping[str, Any], + *, + fallback_n_records: int | None, + fallback_n_clones: int | None, + ) -> Any: + blocks = self._string_array(package["block_geoid"]) + cds = self._string_array(package["cd_geoid"]) + metadata = package.get("metadata") or {} + n_records, n_clones = self._infer_dimensions( + total_length=len(blocks), + n_records=metadata.get("base_n_records"), + n_clones=metadata.get("n_clones"), + fallback_n_records=fallback_n_records, + fallback_n_clones=fallback_n_clones, + ) + return self._build_assignment( + block_geoid=blocks, + cd_geoid=cds, + 
county_fips=self._derive_county_fips(blocks), + state_fips=self._derive_state_fips(blocks), + n_records=n_records, + n_clones=n_clones, + ) + + def _infer_dimensions( + self, + *, + total_length: int, + n_records: int | None, + n_clones: int | None, + fallback_n_records: int | None, + fallback_n_clones: int | None, + ) -> tuple[int, int]: + resolved_records = self._as_int(n_records) or self._as_int(fallback_n_records) + resolved_clones = self._as_int(n_clones) or self._as_int(fallback_n_clones) + + if resolved_records is None and resolved_clones is not None: + if total_length % resolved_clones != 0: + raise ValueError( + "Cannot infer base record count from package geometry length " + f"{total_length} and n_clones={resolved_clones}" + ) + resolved_records = total_length // resolved_clones + if resolved_clones is None and resolved_records is not None: + if total_length % resolved_records != 0: + raise ValueError( + "Cannot infer clone count from package geometry length " + f"{total_length} and n_records={resolved_records}" + ) + resolved_clones = total_length // resolved_records + + if resolved_records is None or resolved_clones is None: + raise ValueError( + "Calibration package geography is missing n_records/n_clones metadata" + ) + if resolved_records * resolved_clones != total_length: + raise ValueError( + "Calibration package geography dimensions do not match array length: " + f"{resolved_records} x {resolved_clones} != {total_length}" + ) + return resolved_records, resolved_clones + + def _generate_geography(self, *, n_records: int, n_clones: int, seed: int) -> Any: + from policyengine_us_data.calibration.clone_and_assign import ( + assign_random_geography, + ) + + return assign_random_geography( + n_records=n_records, + n_clones=n_clones, + seed=seed, + ) + + def _build_assignment( + self, + *, + block_geoid: np.ndarray, + cd_geoid: np.ndarray, + county_fips: np.ndarray, + state_fips: np.ndarray, + n_records: int, + n_clones: int, + ) -> Any: + from 
policyengine_us_data.calibration.clone_and_assign import ( + GeographyAssignment, + ) + + return GeographyAssignment( + block_geoid=block_geoid, + cd_geoid=cd_geoid, + county_fips=county_fips, + state_fips=state_fips, + n_records=n_records, + n_clones=n_clones, + ) + + def _derive_county_fips(self, blocks: np.ndarray) -> np.ndarray: + return np.asarray([str(block)[:5] for block in blocks], dtype=str) + + def _derive_state_fips(self, blocks: np.ndarray) -> np.ndarray: + return np.asarray([int(str(block)[:2]) for block in blocks], dtype=np.int64) + + def _string_array(self, values: Any) -> np.ndarray: + return np.asarray(values, dtype=str) + + def _as_int(self, value: Any) -> int | None: + if value is None: + return None + return int(value) diff --git a/policyengine_us_data/calibration/local_h5/partitioning.py b/policyengine_us_data/calibration/local_h5/partitioning.py new file mode 100644 index 000000000..230f464fb --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/partitioning.py @@ -0,0 +1,54 @@ +"""Pure helpers for assigning weighted work items to worker chunks.""" + +from __future__ import annotations + +import heapq +from collections.abc import Mapping, Sequence +from typing import Any + + +def work_item_key(item: Mapping[str, Any]) -> str: + """Return the stable completion key used by the current H5 workers.""" + + return f"{item['type']}:{item['id']}" + + +def partition_weighted_work_items( + work_items: Sequence[Mapping[str, Any]], + num_workers: int, + completed: set[str] | None = None, +) -> list[list[Mapping[str, Any]]]: + """Partition work items across workers using longest-processing-time first. + + The current H5 pipeline represents work items as mappings with: + - `type` + - `id` + - `weight` + + This helper stays compatible with that existing shape so the orchestration + layer can adopt a tested partitioning seam before the request model is + migrated in later commits. 
+ """ + + if num_workers <= 0: + return [] + + completed = completed or set() + remaining = [ + item for item in work_items if work_item_key(item) not in completed + ] + remaining.sort(key=lambda item: -item["weight"]) + + n_workers = min(num_workers, len(remaining)) + if n_workers == 0: + return [] + + heap: list[tuple[int | float, int]] = [(0, idx) for idx in range(n_workers)] + chunks: list[list[Mapping[str, Any]]] = [[] for _ in range(n_workers)] + + for item in remaining: + load, idx = heapq.heappop(heap) + chunks[idx].append(item) + heapq.heappush(heap, (load + item["weight"], idx)) + + return [chunk for chunk in chunks if chunk] diff --git a/policyengine_us_data/calibration/local_h5/reindexing.py b/policyengine_us_data/calibration/local_h5/reindexing.py new file mode 100644 index 000000000..decb397f8 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/reindexing.py @@ -0,0 +1,153 @@ +"""Pure entity reindexing for local H5 publishing.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Mapping + +import numpy as np + +from .selection import CloneSelection +from .source_dataset import SourceDatasetSnapshot + + +@dataclass(frozen=True) +class ReindexedEntities: + household_source_indices: np.ndarray + person_source_indices: np.ndarray + entity_source_indices: Mapping[str, np.ndarray] + persons_per_clone: np.ndarray + entities_per_clone: Mapping[str, np.ndarray] + new_household_ids: np.ndarray + new_person_ids: np.ndarray + new_person_household_ids: np.ndarray + new_entity_ids: Mapping[str, np.ndarray] + new_person_entity_ids: Mapping[str, np.ndarray] + + +class EntityReindexer: + """Build output IDs and cross-references from a clone selection.""" + + def reindex( + self, + source: SourceDatasetSnapshot, + selection: CloneSelection, + ) -> ReindexedEntities: + entity_graph = source.entity_graph + household_source_indices = np.asarray( + selection.active_household_indices, + dtype=np.int64, + ) + 
n_household_clones = len(household_source_indices) + + persons_per_clone = np.asarray( + [ + len(entity_graph.hh_to_persons.get(int(household_idx), ())) + for household_idx in household_source_indices + ], + dtype=np.int64, + ) + person_parts = [ + np.asarray( + entity_graph.hh_to_persons.get(int(household_idx), ()), + dtype=np.int64, + ) + for household_idx in household_source_indices + ] + person_source_indices = ( + np.concatenate(person_parts) + if person_parts + else np.asarray([], dtype=np.int64) + ) + + entity_source_indices: dict[str, np.ndarray] = {} + entities_per_clone: dict[str, np.ndarray] = {} + for entity_key in entity_graph.entity_id_arrays: + per_clone_counts = np.asarray( + [ + len(entity_graph.hh_to_entity[entity_key].get(int(household_idx), ())) + for household_idx in household_source_indices + ], + dtype=np.int64, + ) + entities_per_clone[entity_key] = per_clone_counts + entity_parts = [ + np.asarray( + entity_graph.hh_to_entity[entity_key].get(int(household_idx), ()), + dtype=np.int64, + ) + for household_idx in household_source_indices + ] + entity_source_indices[entity_key] = ( + np.concatenate(entity_parts) + if entity_parts + else np.asarray([], dtype=np.int64) + ) + + n_persons = len(person_source_indices) + new_household_ids = np.arange(n_household_clones, dtype=np.int32) + new_person_ids = np.arange(n_persons, dtype=np.int32) + new_person_household_ids = np.repeat(new_household_ids, persons_per_clone) + clone_ids_for_persons = np.repeat( + np.arange(n_household_clones, dtype=np.int64), + persons_per_clone, + ) + + new_entity_ids: dict[str, np.ndarray] = {} + new_person_entity_ids: dict[str, np.ndarray] = {} + + for entity_key, source_indices in entity_source_indices.items(): + entity_count = len(source_indices) + new_entity_ids[entity_key] = np.arange(entity_count, dtype=np.int32) + + if entity_count == 0: + if n_persons != 0: + raise ValueError( + f"No source {entity_key} entities for selected persons" + ) + 
new_person_entity_ids[entity_key] = np.asarray([], dtype=np.int32) + continue + + old_entity_ids = entity_graph.entity_id_arrays[entity_key][ + source_indices + ].astype(np.int64) + clone_ids_for_entities = np.repeat( + np.arange(n_household_clones, dtype=np.int64), + entities_per_clone[entity_key], + ) + + offset = int(old_entity_ids.max()) + 1 if old_entity_ids.size else 1 + entity_keys = clone_ids_for_entities * offset + old_entity_ids + + sorted_order = np.argsort(entity_keys) + sorted_keys = entity_keys[sorted_order] + sorted_new_ids = new_entity_ids[entity_key][sorted_order] + + old_person_entity_ids = entity_graph.person_entity_id_arrays[entity_key][ + person_source_indices + ].astype(np.int64) + person_keys = clone_ids_for_persons * offset + old_person_entity_ids + + positions = np.searchsorted(sorted_keys, person_keys) + if np.any(positions >= len(sorted_keys)): + raise ValueError( + f"Could not map selected persons to new {entity_key} IDs" + ) + if np.any(sorted_keys[positions] != person_keys): + raise ValueError( + f"Inconsistent selected {entity_key} mappings for persons" + ) + new_person_entity_ids[entity_key] = sorted_new_ids[positions] + + return ReindexedEntities( + household_source_indices=household_source_indices, + person_source_indices=person_source_indices, + entity_source_indices=entity_source_indices, + persons_per_clone=persons_per_clone, + entities_per_clone=entities_per_clone, + new_household_ids=new_household_ids, + new_person_ids=new_person_ids, + new_person_household_ids=new_person_household_ids, + new_entity_ids=new_entity_ids, + new_person_entity_ids=new_person_entity_ids, + ) diff --git a/policyengine_us_data/calibration/local_h5/selection.py b/policyengine_us_data/calibration/local_h5/selection.py new file mode 100644 index 000000000..ea78b4d00 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/selection.py @@ -0,0 +1,140 @@ +"""Pure area-selection helpers for local H5 publishing.""" + +from __future__ import annotations 
+ +from dataclasses import dataclass +from typing import Any + +import numpy as np + +from .contracts import AreaFilter +from .weights import CloneWeightMatrix + + +@dataclass(frozen=True) +class CloneSelection: + active_clone_indices: np.ndarray + active_household_indices: np.ndarray + active_weights: np.ndarray + active_block_geoids: np.ndarray + active_cd_geoids: np.ndarray + active_county_fips: np.ndarray + active_state_fips: np.ndarray + + @property + def n_household_clones(self) -> int: + return int(len(self.active_household_indices)) + + @property + def is_empty(self) -> bool: + return self.n_household_clones == 0 + + +class AreaSelector: + """Select active clone-household cells for an area's geography filters.""" + + _SUPPORTED_FIELDS = ( + "block_geoid", + "cd_geoid", + "county_fips", + "state_fips", + ) + + def select( + self, + weights: CloneWeightMatrix, + geography: Any, + *, + filters: tuple[AreaFilter, ...] = (), + ) -> CloneSelection: + self._validate_geography_shape(weights, geography) + + weight_matrix = weights.as_matrix() + shape = weight_matrix.shape + + block_matrix = self._field_matrix(geography, "block_geoid", shape) + cd_matrix = self._field_matrix(geography, "cd_geoid", shape) + county_matrix = self._field_matrix(geography, "county_fips", shape) + state_matrix = self._field_matrix(geography, "state_fips", shape) + + active_mask = weight_matrix > 0 + for area_filter in filters: + active_mask &= self._apply_filter( + values=self._field_matrix( + geography, + area_filter.geography_field, + shape, + ), + area_filter=area_filter, + ) + + active_clone_indices, active_household_indices = np.where(active_mask) + + return CloneSelection( + active_clone_indices=active_clone_indices.astype(np.int64), + active_household_indices=active_household_indices.astype(np.int64), + active_weights=weight_matrix[ + active_clone_indices, active_household_indices + ], + active_block_geoids=block_matrix[ + active_clone_indices, active_household_indices + ], + 
active_cd_geoids=cd_matrix[active_clone_indices, active_household_indices], + active_county_fips=county_matrix[ + active_clone_indices, active_household_indices + ], + active_state_fips=state_matrix[ + active_clone_indices, active_household_indices + ], + ) + + def _validate_geography_shape( + self, + weights: CloneWeightMatrix, + geography: Any, + ) -> None: + if getattr(geography, "n_records", weights.n_records) != weights.n_records: + raise ValueError( + "Geography n_records does not match weight matrix " + f"({getattr(geography, 'n_records', None)} != {weights.n_records})" + ) + if getattr(geography, "n_clones", weights.n_clones) != weights.n_clones: + raise ValueError( + "Geography n_clones does not match weight matrix " + f"({getattr(geography, 'n_clones', None)} != {weights.n_clones})" + ) + + def _field_matrix( + self, + geography: Any, + field_name: str, + shape: tuple[int, int], + ) -> np.ndarray: + if field_name not in self._SUPPORTED_FIELDS: + raise ValueError( + f"Unsupported geography field {field_name!r}; " + f"supported fields: {', '.join(self._SUPPORTED_FIELDS)}" + ) + if not hasattr(geography, field_name): + raise ValueError(f"Geography is missing field {field_name!r}") + + values = np.asarray(getattr(geography, field_name)) + expected_size = shape[0] * shape[1] + if values.size != expected_size: + raise ValueError( + f"Geography field {field_name!r} has length {values.size}; " + f"expected {expected_size}" + ) + return values.reshape(shape) + + def _apply_filter( + self, + *, + values: np.ndarray, + area_filter: AreaFilter, + ) -> np.ndarray: + if area_filter.op == "eq": + return values == area_filter.value + if area_filter.op == "in": + return np.isin(values, list(area_filter.value)) + raise ValueError(f"Unsupported filter op {area_filter.op!r}") diff --git a/policyengine_us_data/calibration/local_h5/source_dataset.py b/policyengine_us_data/calibration/local_h5/source_dataset.py new file mode 100644 index 000000000..b45c65739 --- /dev/null +++ 
b/policyengine_us_data/calibration/local_h5/source_dataset.py @@ -0,0 +1,97 @@ +"""Worker-scoped source dataset loading with lazy variable access.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Protocol + +import numpy as np + +from .entity_graph import EntityGraph, EntityGraphExtractor + + +class VariableArrayProvider(Protocol): + def list_variables(self) -> tuple[str, ...]: ... + + def get_known_periods(self, variable: str) -> tuple[int | str, ...]: ... + + def get_array(self, variable: str, period: int | str) -> np.ndarray: ... + + def get_variable_definition(self, variable: str) -> Any: ... + + def calculate(self, variable: str, *, map_to: str | None = None) -> Any: ... + + +class PolicyEngineVariableArrayProvider: + """Lazy access to source arrays through a single Microsimulation.""" + + def __init__(self, simulation: Any): + self.simulation = simulation + self._holder_cache: dict[str, Any] = {} + + def list_variables(self) -> tuple[str, ...]: + return tuple(self.simulation.tax_benefit_system.variables.keys()) + + def get_known_periods(self, variable: str) -> tuple[int | str, ...]: + return tuple(self._get_holder(variable).get_known_periods()) + + def get_array(self, variable: str, period: int | str) -> np.ndarray: + return self._get_holder(variable).get_array(period) + + def get_variable_definition(self, variable: str) -> Any: + return self.simulation.tax_benefit_system.variables.get(variable) + + def calculate(self, variable: str, *, map_to: str | None = None) -> Any: + if map_to is None: + return self.simulation.calculate(variable) + return self.simulation.calculate(variable, map_to=map_to) + + def _get_holder(self, variable: str) -> Any: + holder = self._holder_cache.get(variable) + if holder is None: + holder = self.simulation.get_holder(variable) + self._holder_cache[variable] = holder + return holder + + +@dataclass(frozen=True) +class SourceDatasetSnapshot: + 
dataset_path: Path + time_period: int + household_ids: np.ndarray + entity_graph: EntityGraph + input_variables: frozenset[str] + variable_provider: VariableArrayProvider + + @property + def n_households(self) -> int: + return int(len(self.household_ids)) + + +class PolicyEngineDatasetReader: + """Load worker-scoped source dataset structure once.""" + + def __init__(self, sub_entities: tuple[str, ...]): + self.sub_entities = tuple(sub_entities) + self.entity_graph_extractor = EntityGraphExtractor(self.sub_entities) + + def load(self, dataset_path: str | Path) -> SourceDatasetSnapshot: + from policyengine_us import Microsimulation + + dataset_path = Path(dataset_path) + simulation = Microsimulation(dataset=str(dataset_path)) + household_ids = np.asarray( + simulation.calculate("household_id", map_to="household").values + ) + entity_graph = self.entity_graph_extractor.extract(simulation, household_ids) + variable_provider = PolicyEngineVariableArrayProvider(simulation) + + return SourceDatasetSnapshot( + dataset_path=dataset_path, + time_period=int(simulation.default_calculation_period), + household_ids=household_ids, + entity_graph=entity_graph, + input_variables=frozenset(simulation.input_variables), + variable_provider=variable_provider, + ) diff --git a/policyengine_us_data/calibration/local_h5/us_augmentations.py b/policyengine_us_data/calibration/local_h5/us_augmentations.py new file mode 100644 index 000000000..89b5508f7 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/us_augmentations.py @@ -0,0 +1,300 @@ +"""US-specific payload augmentation for local H5 publishing.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Mapping, Sequence + +import numpy as np + +from policyengine_us_data.calibration.block_assignment import ( + derive_geography_from_blocks, +) +from policyengine_us_data.calibration.calibration_utils import ( + calculate_spm_thresholds_vectorized, + load_cd_geoadj_values, +) 
+from policyengine_us_data.calibration.local_h5.reindexing import ( + ReindexedEntities, +) +from policyengine_us_data.calibration.local_h5.selection import CloneSelection +from policyengine_us_data.calibration.local_h5.source_dataset import ( + SourceDatasetSnapshot, +) +from policyengine_us_data.utils.takeup import ( + apply_block_takeup_to_arrays, + reported_subsidized_marketplace_by_tax_unit, +) + + +def _default_county_name_lookup(county_indices: np.ndarray) -> np.ndarray: + from policyengine_us.variables.household.demographic.geographic.county.county_enum import ( + County, + ) + + return np.asarray( + [County._member_names_[int(index)] for index in county_indices], + dtype="S", + ) + + +def build_reported_takeup_anchors( + data: dict[str, dict[int | str, np.ndarray]], + time_period: int | str, +) -> dict[str, np.ndarray]: + reported_anchors: dict[str, np.ndarray] = {} + if ( + "reported_has_subsidized_marketplace_health_coverage_at_interview" in data + and "person_tax_unit_id" in data + and "tax_unit_id" in data + and time_period + in data["reported_has_subsidized_marketplace_health_coverage_at_interview"] + and time_period in data["person_tax_unit_id"] + and time_period in data["tax_unit_id"] + ): + reported_anchors["takes_up_aca_if_eligible"] = ( + reported_subsidized_marketplace_by_tax_unit( + data["person_tax_unit_id"][time_period], + data["tax_unit_id"][time_period], + data[ + "reported_has_subsidized_marketplace_health_coverage_at_interview" + ][time_period], + ) + ) + if ( + "has_medicaid_health_coverage_at_interview" in data + and time_period in data["has_medicaid_health_coverage_at_interview"] + ): + reported_anchors["takes_up_medicaid_if_eligible"] = data[ + "has_medicaid_health_coverage_at_interview" + ][time_period].astype(bool) + return reported_anchors + + +@dataclass(frozen=True) +class USAugmentationService: + geography_lookup: Callable[[np.ndarray], Mapping[str, np.ndarray]] = ( + derive_geography_from_blocks + ) + county_name_lookup: 
Callable[[np.ndarray], np.ndarray] = ( + _default_county_name_lookup + ) + cd_geoadj_loader: Callable[[Sequence[str]], Mapping[str, float]] = ( + load_cd_geoadj_values + ) + threshold_calculator: Callable[..., np.ndarray] = ( + calculate_spm_thresholds_vectorized + ) + takeup_fn: Callable[..., Mapping[str, np.ndarray]] = ( + apply_block_takeup_to_arrays + ) + + def apply_geography( + self, + data: dict[str, dict[int | str, np.ndarray]], + *, + time_period: int | str, + active_blocks: np.ndarray, + active_clone_cds: np.ndarray, + ) -> Mapping[str, np.ndarray]: + unique_blocks, block_inv = np.unique(active_blocks, return_inverse=True) + unique_geo = self.geography_lookup(unique_blocks) + clone_geo = { + key: np.asarray(values)[block_inv] + for key, values in unique_geo.items() + } + + data["state_fips"] = { + time_period: clone_geo["state_fips"].astype(np.int32) + } + data["county"] = { + time_period: self.county_name_lookup(clone_geo["county_index"]) + } + data["county_fips"] = { + time_period: clone_geo["county_fips"].astype(np.int32) + } + + for variable in ( + "block_geoid", + "tract_geoid", + "cbsa_code", + "sldu", + "sldl", + "place_fips", + "vtd", + "puma", + "zcta", + ): + if variable in clone_geo: + data[variable] = { + time_period: clone_geo[variable].astype("S") + } + + data["congressional_district_geoid"] = { + time_period: np.asarray( + [int(cd) for cd in active_clone_cds], + dtype=np.int32, + ) + } + return clone_geo + + def apply_zip_code_patch( + self, + data: dict[str, dict[int | str, np.ndarray]], + *, + time_period: int | str, + county_fips: np.ndarray, + ) -> None: + la_mask = county_fips.astype(str) == "06037" + if not la_mask.any(): + return + zip_codes = np.full(len(la_mask), "UNKNOWN") + zip_codes[la_mask] = "90001" + data["zip_code"] = {time_period: zip_codes.astype("S")} + + def apply_spm_thresholds( + self, + data: dict[str, dict[int | str, np.ndarray]], + *, + time_period: int, + active_clone_cds: np.ndarray, + source: 
SourceDatasetSnapshot, + reindexed: ReindexedEntities, + ) -> None: + provider = source.variable_provider + unique_cds_list = sorted(set(active_clone_cds)) + cd_geoadj_values = self.cd_geoadj_loader(unique_cds_list) + + spm_entities_per_clone = reindexed.entities_per_clone["spm_unit"] + spm_clone_ids = np.repeat( + np.arange(len(spm_entities_per_clone), dtype=np.int64), + spm_entities_per_clone, + ) + spm_unit_geoadj = np.asarray( + [ + cd_geoadj_values[str(active_clone_cds[clone_id])] + for clone_id in spm_clone_ids + ], + dtype=np.float64, + ) + + person_ages = provider.calculate("age", map_to="person").values[ + reindexed.person_source_indices + ] + spm_tenure_periods = provider.get_known_periods("spm_unit_tenure_type") + if spm_tenure_periods: + raw_tenure = provider.get_array( + "spm_unit_tenure_type", + spm_tenure_periods[0], + ) + if hasattr(raw_tenure, "decode_to_str"): + raw_tenure = raw_tenure.decode_to_str().astype("S") + else: + raw_tenure = np.asarray(raw_tenure).astype("S") + spm_tenure_cloned = raw_tenure[ + reindexed.entity_source_indices["spm_unit"] + ] + else: + spm_tenure_cloned = np.full( + len(reindexed.entity_source_indices["spm_unit"]), + b"RENTER", + dtype="S30", + ) + + data["spm_unit_spm_threshold"] = { + time_period: self.threshold_calculator( + person_ages=person_ages, + person_spm_unit_ids=reindexed.new_person_entity_ids["spm_unit"], + spm_unit_tenure_types=spm_tenure_cloned, + spm_unit_geoadj=spm_unit_geoadj, + year=time_period, + ) + } + + def apply_takeup( + self, + data: dict[str, dict[int | str, np.ndarray]], + *, + time_period: int | str, + takeup_filter: Sequence[str] | None, + selection: CloneSelection, + source: SourceDatasetSnapshot, + reindexed: ReindexedEntities, + clone_geo: Mapping[str, np.ndarray], + ) -> None: + entity_hh_indices = { + "person": np.repeat( + np.arange(selection.n_household_clones, dtype=np.int64), + reindexed.persons_per_clone, + ).astype(np.int64), + "tax_unit": np.repeat( + 
np.arange(selection.n_household_clones, dtype=np.int64), + reindexed.entities_per_clone["tax_unit"], + ).astype(np.int64), + "spm_unit": np.repeat( + np.arange(selection.n_household_clones, dtype=np.int64), + reindexed.entities_per_clone["spm_unit"], + ).astype(np.int64), + } + entity_counts = { + "person": len(reindexed.person_source_indices), + "tax_unit": len(reindexed.entity_source_indices["tax_unit"]), + "spm_unit": len(reindexed.entity_source_indices["spm_unit"]), + } + original_hh_ids = source.household_ids[ + selection.active_household_indices + ].astype(np.int64) + reported_anchors = build_reported_takeup_anchors(data, time_period) + + takeup_results = self.takeup_fn( + hh_blocks=selection.active_block_geoids, + hh_state_fips=clone_geo["state_fips"].astype(np.int32), + hh_ids=original_hh_ids, + hh_clone_indices=selection.active_clone_indices.astype(np.int64), + entity_hh_indices=entity_hh_indices, + entity_counts=entity_counts, + time_period=time_period, + takeup_filter=takeup_filter, + reported_anchors=reported_anchors, + ) + for variable, values in takeup_results.items(): + data[variable] = {time_period: values} + + def apply_all( + self, + data: dict[str, dict[int | str, np.ndarray]], + *, + time_period: int, + selection: CloneSelection, + source: SourceDatasetSnapshot, + reindexed: ReindexedEntities, + takeup_filter: Sequence[str] | None, + ) -> dict[str, dict[int | str, np.ndarray]]: + clone_geo = self.apply_geography( + data, + time_period=time_period, + active_blocks=selection.active_block_geoids, + active_clone_cds=selection.active_cd_geoids, + ) + self.apply_zip_code_patch( + data, + time_period=time_period, + county_fips=clone_geo["county_fips"], + ) + self.apply_spm_thresholds( + data, + time_period=time_period, + active_clone_cds=selection.active_cd_geoids, + source=source, + reindexed=reindexed, + ) + self.apply_takeup( + data, + time_period=time_period, + takeup_filter=takeup_filter, + selection=selection, + source=source, + 
reindexed=reindexed, + clone_geo=clone_geo, + ) + return data diff --git a/policyengine_us_data/calibration/local_h5/validation.py b/policyengine_us_data/calibration/local_h5/validation.py new file mode 100644 index 000000000..d5174ae17 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/validation.py @@ -0,0 +1,65 @@ +"""Pure helpers for H5 validation semantics.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Any + + +def validation_geo_level_for_area_type(area_type: str) -> str: + """Map current worker/validator area types onto sanity-check geo levels.""" + + if area_type == "states": + return "state" + if area_type == "national": + return "national" + return "district" + + +def summarize_validation_rows( + validation_rows: Sequence[Mapping[str, Any]], +) -> dict[str, int | float]: + """Summarize per-target validation rows for worker reporting.""" + + n_fail = sum(1 for row in validation_rows if row.get("sanity_check") == "FAIL") + rel_abs_errors = [ + row["rel_abs_error"] + for row in validation_rows + if isinstance(row.get("rel_abs_error"), (int, float)) + and row["rel_abs_error"] != float("inf") + ] + mean_rae = sum(rel_abs_errors) / len(rel_abs_errors) if rel_abs_errors else 0.0 + return { + "n_targets": len(validation_rows), + "n_sanity_fail": n_fail, + "mean_rel_abs_error": round(mean_rae, 4), + } + + +def make_validation_error( + item_key: str, + error: Exception | str, + traceback_text: str | None = None, +) -> dict[str, str]: + """Build a structured validation error record for worker JSON output.""" + + return { + "item": item_key, + "error": str(error), + "traceback": traceback_text or "", + } + + +def tag_validation_errors( + validation_errors: Sequence[Mapping[str, Any]], + *, + source: str, +) -> list[dict[str, Any]]: + """Attach a diagnostics source label to structured validation errors.""" + + tagged = [] + for error in validation_errors: + item = dict(error) + item["source"] = 
source + tagged.append(item) + return tagged diff --git a/policyengine_us_data/calibration/local_h5/variables.py b/policyengine_us_data/calibration/local_h5/variables.py new file mode 100644 index 000000000..d1761650f --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/variables.py @@ -0,0 +1,107 @@ +"""Generic variable export for local H5 publishing.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Mapping + +import numpy as np + +from .reindexing import ReindexedEntities +from .source_dataset import SourceDatasetSnapshot + + +@dataclass(frozen=True) +class VariableExportPolicy: + include_input_variables: bool = True + required_variables: frozenset[str] = frozenset() + excluded_variables: frozenset[str] = frozenset() + + def variable_names(self, source: SourceDatasetSnapshot) -> tuple[str, ...]: + selected = set() + if self.include_input_variables: + selected.update(source.input_variables) + selected.update(self.required_variables) + selected.difference_update(self.excluded_variables) + return tuple(sorted(selected)) + + +@dataclass(frozen=True) +class H5Payload: + variables: Mapping[str, Mapping[int | str, np.ndarray]] + + @property + def dataset_count(self) -> int: + return sum(len(periods) for periods in self.variables.values()) + + +class VariableCloner: + """Clone source arrays for the selected entities and periods.""" + + def clone( + self, + source: SourceDatasetSnapshot, + reindexed: ReindexedEntities, + policy: VariableExportPolicy, + ) -> H5Payload: + provider = source.variable_provider + clone_index_map = { + "household": reindexed.household_source_indices, + "person": reindexed.person_source_indices, + **reindexed.entity_source_indices, + } + + payload: dict[str, dict[int | str, np.ndarray]] = {} + for variable in policy.variable_names(source): + var_def = provider.get_variable_definition(variable) + if var_def is None: + continue + + entity_key = var_def.entity.key + if entity_key not in 
clone_index_map: + continue + + periods = provider.get_known_periods(variable) + if not periods: + continue + + clone_indices = clone_index_map[entity_key] + var_data: dict[int | str, np.ndarray] = {} + for period in periods: + values = provider.get_array(variable, period) + coerced = self._coerce_output_array( + variable=variable, + values=values, + value_type=var_def.value_type, + ) + var_data[period] = coerced[clone_indices] + + if var_data: + payload[variable] = var_data + + return H5Payload(variables=payload) + + def _coerce_output_array( + self, + *, + variable: str, + values, + value_type, + ) -> np.ndarray: + if hasattr(values, "_pa_array") or hasattr(values, "_ndarray"): + values = np.asarray(values) + + if variable == "county_fips": + return np.asarray(values).astype("int32") + + if self._is_string_like_value_type(value_type): + if hasattr(values, "decode_to_str"): + return values.decode_to_str().astype("S") + return np.asarray(values).astype("S") + + return np.asarray(values) + + def _is_string_like_value_type(self, value_type) -> bool: + if value_type is str: + return True + return getattr(value_type, "__name__", None) == "Enum" diff --git a/policyengine_us_data/calibration/local_h5/weights.py b/policyengine_us_data/calibration/local_h5/weights.py new file mode 100644 index 000000000..6b8ff6a26 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/weights.py @@ -0,0 +1,57 @@ +"""US-local clone-by-household weight layout helpers.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + + +def infer_clone_count_from_weight_length( + weight_length: int, + n_records: int, +) -> int: + """Derive canonical clone count from weight length and record count.""" + + if n_records <= 0: + raise ValueError("n_records must be positive") + if weight_length <= 0: + raise ValueError("weight_length must be positive") + if weight_length % n_records != 0: + raise ValueError( + "Weight vector length " + f"{weight_length} 
is not divisible by n_records={n_records}" + ) + return int(weight_length // n_records) + + +@dataclass(frozen=True) +class CloneWeightMatrix: + """A US clone-by-household weight vector with validated shape.""" + + values: np.ndarray + n_records: int + n_clones: int + + @classmethod + def from_vector( + cls, + values: np.ndarray, + n_records: int, + ) -> "CloneWeightMatrix": + arr = np.asarray(values) + if arr.ndim != 1: + raise ValueError("weight vector must be one-dimensional") + + return cls( + values=arr, + n_records=int(n_records), + n_clones=infer_clone_count_from_weight_length(arr.size, n_records), + ) + + @property + def shape(self) -> tuple[int, int]: + return (self.n_clones, self.n_records) + + def as_matrix(self) -> np.ndarray: + return self.values.reshape(self.shape) diff --git a/policyengine_us_data/calibration/local_h5/worker_service.py b/policyengine_us_data/calibration/local_h5/worker_service.py new file mode 100644 index 000000000..0ae7be1e6 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/worker_service.py @@ -0,0 +1,593 @@ +"""Worker-scoped session loading and per-request H5 execution.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Mapping, Sequence +import traceback + +import numpy as np + +from .builder import LocalAreaDatasetBuilder +from .contracts import ( + AreaBuildRequest, + AreaBuildResult, + AreaFilter, + ValidationIssue, + ValidationPolicy, + ValidationResult, + WorkerResult, +) +from .package_geography import CalibrationPackageGeographyLoader +from .source_dataset import PolicyEngineDatasetReader, SourceDatasetSnapshot +from .validation import summarize_validation_rows +from .weights import infer_clone_count_from_weight_length +from .writer import H5Writer + + +@dataclass(frozen=True) +class ValidationContext: + validation_targets: Any + training_mask_full: np.ndarray + constraints_map: Mapping[int, Sequence[Any]] + db_path: 
Path + period: int + + +@dataclass(frozen=True) +class WorkerSession: + source_snapshot: SourceDatasetSnapshot + weights: np.ndarray + geography: Any + output_dir: Path + takeup_filter: tuple[str, ...] = () + validation_policy: ValidationPolicy = field(default_factory=ValidationPolicy) + validation_context: ValidationContext | None = None + dataset_path: Path | None = None + weights_path: Path | None = None + db_path: Path | None = None + calibration_package_path: Path | None = None + seed: int = 42 + requested_n_clones: int | None = None + geography_source: str = "" + geography_warnings: tuple[str, ...] = () + + @property + def n_records(self) -> int: + return self.source_snapshot.n_households + + @property + def n_clones(self) -> int: + return infer_clone_count_from_weight_length(len(self.weights), self.n_records) + + @classmethod + def load( + cls, + *, + weights_path: Path, + dataset_path: Path, + output_dir: Path, + calibration_package_path: Path | None = None, + requested_n_clones: int | None = None, + seed: int = 42, + takeup_filter: Sequence[str] = (), + validation_policy: ValidationPolicy | None = None, + validation_context: ValidationContext | None = None, + source_reader: PolicyEngineDatasetReader | None = None, + geography_loader: CalibrationPackageGeographyLoader | None = None, + allow_seed_fallback: bool = True, + ) -> "WorkerSession": + weights = np.load(weights_path) + source_reader = source_reader or PolicyEngineDatasetReader(()) + source_snapshot = source_reader.load(dataset_path) + n_records = source_snapshot.n_households + n_clones = infer_clone_count_from_weight_length(len(weights), n_records) + + geography_loader = geography_loader or CalibrationPackageGeographyLoader() + geography_resolution = geography_loader.resolve_for_weights( + package_path=calibration_package_path, + weights_length=len(weights), + n_records=n_records, + n_clones=n_clones, + seed=seed, + allow_seed_fallback=allow_seed_fallback, + ) + + return cls( + 
source_snapshot=source_snapshot, + weights=weights, + geography=geography_resolution.geography, + output_dir=Path(output_dir), + takeup_filter=tuple(takeup_filter), + validation_policy=validation_policy or ValidationPolicy(), + validation_context=validation_context, + dataset_path=Path(dataset_path), + weights_path=Path(weights_path), + db_path=(validation_context.db_path if validation_context else None), + calibration_package_path=( + Path(calibration_package_path) + if calibration_package_path is not None + else None + ), + seed=seed, + requested_n_clones=requested_n_clones, + geography_source=geography_resolution.source, + geography_warnings=tuple(geography_resolution.warnings), + ) + + +def load_validation_context( + *, + db_path: Path, + period: int, + target_config_path: str | Path | None = None, + validation_config_path: str | Path | None = None, + policy: ValidationPolicy | None = None, +) -> ValidationContext | None: + policy = policy or ValidationPolicy() + if not policy.enabled: + return None + + from sqlalchemy import create_engine + from policyengine_us_data.calibration.validate_staging import ( + _batch_stratum_constraints, + _query_all_active_targets, + ) + from policyengine_us_data.calibration.unified_calibration import ( + _match_rules, + load_target_config, + ) + + engine = create_engine(f"sqlite:///{db_path}") + validation_targets = _query_all_active_targets(engine, period) + + if validation_config_path: + val_cfg = load_target_config(str(validation_config_path)) + exclude_rules = val_cfg.get("exclude", []) + if exclude_rules: + exclude_mask = _match_rules(validation_targets, exclude_rules) + validation_targets = validation_targets[~exclude_mask].reset_index(drop=True) + include_rules = val_cfg.get("include", []) + if include_rules: + include_mask = _match_rules(validation_targets, include_rules) + validation_targets = validation_targets[include_mask].reset_index(drop=True) + + if target_config_path: + target_cfg = 
load_target_config(str(target_config_path)) + include_rules = target_cfg.get("include", []) + if include_rules: + training_mask_full = np.asarray( + _match_rules(validation_targets, include_rules), + dtype=bool, + ) + else: + training_mask_full = np.ones(len(validation_targets), dtype=bool) + else: + training_mask_full = np.ones(len(validation_targets), dtype=bool) + + stratum_ids = validation_targets["stratum_id"].unique().tolist() + constraints_map = _batch_stratum_constraints(engine, stratum_ids) + + return ValidationContext( + validation_targets=validation_targets, + training_mask_full=training_mask_full, + constraints_map=constraints_map, + db_path=Path(db_path), + period=period, + ) + + +def validate_in_subprocess( + h5_path, + area_type, + area_id, + display_id, + area_targets, + area_training, + constraints_map, + db_path, + period, +): + """Run validation for one area inside a subprocess.""" + import logging + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", + ) + from policyengine_us import Microsimulation + from sqlalchemy import create_engine as _ce + from policyengine_us_data.calibration.validate_staging import ( + _build_variable_entity_map, + validate_area, + ) + + engine = _ce(f"sqlite:///{db_path}") + sim = Microsimulation(dataset=h5_path) + variable_entity_map = _build_variable_entity_map(sim) + + return validate_area( + sim=sim, + targets_df=area_targets, + engine=engine, + area_type=area_type, + area_id=area_id, + display_id=display_id, + dataset_path=h5_path, + period=period, + training_mask=area_training, + variable_entity_map=variable_entity_map, + constraints_map=constraints_map, + ) + + +def validate_output_subprocess( + output_path: Path, + request: AreaBuildRequest, + session: WorkerSession, +) -> ValidationResult: + if session.validation_context is None: + return ValidationResult(status="not_run") + + validation_targets = session.validation_context.validation_targets + geographic_ids = 
tuple(str(item) for item in request.validation_geographic_ids) + + if request.validation_geo_level is None: + return ValidationResult(status="passed", summary=summarize_validation_rows(())) + + mask = validation_targets["geo_level"] == request.validation_geo_level + if geographic_ids: + mask &= validation_targets["geographic_id"].astype(str).isin(geographic_ids) + + area_targets = validation_targets[mask].reset_index(drop=True) + area_training = session.validation_context.training_mask_full[mask.values] + + if len(area_targets) == 0: + summary = summarize_validation_rows(()) + return ValidationResult(status="passed", rows=(), summary=summary) + + area_strata = area_targets["stratum_id"].unique().tolist() + area_constraints = { + int(stratum_id): session.validation_context.constraints_map.get( + int(stratum_id), [] + ) + for stratum_id in area_strata + } + + import multiprocessing as _mp + + with _mp.get_context("spawn").Pool(1) as pool: + rows = pool.apply( + validate_in_subprocess, + ( + str(output_path), + _validation_area_type(request), + request.area_id, + request.display_name, + area_targets, + area_training, + area_constraints, + str(session.validation_context.db_path), + session.validation_context.period, + ), + ) + + summary = summarize_validation_rows(rows) + status = "failed" if summary["n_sanity_fail"] > 0 else "passed" + return ValidationResult( + status=status, + rows=tuple(dict(row) for row in rows), + summary=summary, + ) + + +def build_request_from_work_item( + item: Mapping[str, Any], + *, + geography, + state_codes: Mapping[int, str], + at_large_districts: set[int], + nyc_county_fips: set[str], +) -> AreaBuildRequest | None: + item_type = item["type"] + item_id = item["id"] + geo_labels = sorted(set(np.asarray(geography.cd_geoid).astype(str))) + + if item_type == "state": + state_fips = _state_fips_for_code(item_id, state_codes) + cd_subset = [cd for cd in geo_labels if int(cd) // 100 == state_fips] + if not cd_subset: + return None + return 
AreaBuildRequest( + area_type="state", + area_id=item_id, + display_name=item_id, + output_relative_path=f"states/{item_id}.h5", + filters=( + AreaFilter( + geography_field="cd_geoid", + op="in", + value=tuple(cd_subset), + ), + ), + validation_geo_level="state", + validation_geographic_ids=(str(state_fips),), + ) + + if item_type == "district": + state_code, dist_num = item_id.split("-") + state_fips = _state_fips_for_code(state_code, state_codes) + candidate = f"{state_fips}{int(dist_num):02d}" + if candidate in geo_labels: + geoid = candidate + else: + state_cds = [cd for cd in geo_labels if int(cd) // 100 == state_fips] + if len(state_cds) == 1: + geoid = state_cds[0] + else: + raise ValueError( + f"CD {candidate} not found and state {state_code} has " + f"{len(state_cds)} CDs" + ) + + cd_int = int(geoid) + district_num = cd_int % 100 + if district_num in at_large_districts: + district_num = 1 + friendly_name = f"{state_code}-{district_num:02d}" + return AreaBuildRequest( + area_type="district", + area_id=friendly_name, + display_name=friendly_name, + output_relative_path=f"districts/{friendly_name}.h5", + filters=( + AreaFilter( + geography_field="cd_geoid", + op="in", + value=(geoid,), + ), + ), + validation_geo_level="district", + validation_geographic_ids=(geoid,), + ) + + if item_type == "city": + county_values = np.asarray(geography.county_fips).astype(str) + city_mask = np.isin(county_values, sorted(nyc_county_fips)) + city_cds = tuple(sorted(set(np.asarray(geography.cd_geoid).astype(str)[city_mask]))) + return AreaBuildRequest( + area_type="city", + area_id=item_id, + display_name=item_id, + output_relative_path=f"cities/{item_id}.h5", + filters=( + AreaFilter( + geography_field="county_fips", + op="in", + value=tuple(sorted(nyc_county_fips)), + ), + ), + validation_geo_level="district", + validation_geographic_ids=city_cds, + ) + + if item_type == "national": + return AreaBuildRequest.national() + + raise ValueError(f"Unknown item type: {item_type}") + 
+ +def build_requests_from_work_items( + work_items: Sequence[Mapping[str, Any]], + *, + geography, + state_codes: Mapping[int, str], + at_large_districts: set[int], + nyc_county_fips: set[str], +) -> tuple[tuple[AreaBuildRequest, ...], tuple[AreaBuildResult, ...]]: + requests: list[AreaBuildRequest] = [] + failures: list[AreaBuildResult] = [] + + for item in work_items: + try: + request = build_request_from_work_item( + item, + geography=geography, + state_codes=state_codes, + at_large_districts=at_large_districts, + nyc_county_fips=nyc_county_fips, + ) + except Exception as error: + failures.append( + AreaBuildResult( + request=_fallback_request(item), + build_status="failed", + build_error=str(error), + ) + ) + continue + + if request is not None: + requests.append(request) + + return tuple(requests), tuple(failures) + + +class LocalH5WorkerService: + """Execute one worker chunk against a shared worker session.""" + + def __init__( + self, + *, + builder: LocalAreaDatasetBuilder | None = None, + writer: H5Writer | None = None, + validator: Callable[[Path, AreaBuildRequest, WorkerSession], ValidationResult] + | None = None, + ) -> None: + self.builder = builder or LocalAreaDatasetBuilder() + self.writer = writer or H5Writer() + self.validator = validator or validate_output_subprocess + + def run( + self, + session: WorkerSession, + requests: Sequence[AreaBuildRequest], + *, + initial_failures: Sequence[AreaBuildResult] = (), + ) -> WorkerResult: + completed: list[AreaBuildResult] = [] + failed: list[AreaBuildResult] = list(initial_failures) + + for request in requests: + result = self.build_one(session, request) + if result.build_status == "completed": + completed.append(result) + else: + failed.append(result) + + return WorkerResult( + completed=tuple(completed), + failed=tuple(failed), + ) + + def build_one( + self, + session: WorkerSession, + request: AreaBuildRequest, + ) -> AreaBuildResult: + output_path = session.output_dir / request.output_relative_path + 
+ try: + built = self.builder.build( + weights=session.weights, + geography=session.geography, + source=session.source_snapshot, + filters=request.filters, + takeup_filter=session.takeup_filter, + ) + written_path = self.writer.write_payload(built.payload, output_path) + self.writer.verify_output(written_path, time_period=built.time_period) + except Exception as error: + return AreaBuildResult( + request=request, + build_status="failed", + build_error=str(error), + ) + + validation = self._validate_output(written_path, request, session) + return AreaBuildResult( + request=request, + build_status="completed", + output_path=written_path, + validation=validation, + ) + + def _validate_output( + self, + output_path: Path, + request: AreaBuildRequest, + session: WorkerSession, + ) -> ValidationResult: + if not session.validation_policy.enabled or session.validation_context is None: + return ValidationResult(status="not_run") + + try: + return self.validator(output_path, request, session) + except Exception as error: + return ValidationResult( + status="error", + issues=( + ValidationIssue( + code="validation_exception", + message=str(error), + severity="error", + details={"traceback": traceback.format_exc()}, + ), + ), + ) + + +def worker_result_to_legacy_dict(worker_result: WorkerResult) -> dict[str, Any]: + completed = [] + failed = [] + errors = [] + validation_rows: list[dict[str, Any]] = [] + validation_errors: list[dict[str, Any]] = [] + validation_summary: dict[str, Any] = {} + + for result in worker_result.completed: + item_key = _result_item_key(result.request) + completed.append(item_key) + if result.validation.status in ("passed", "failed"): + validation_rows.extend(dict(row) for row in result.validation.rows) + if result.validation.summary: + validation_summary[item_key] = dict(result.validation.summary) + elif result.validation.status == "error": + for issue in result.validation.issues: + validation_errors.append( + { + "item": item_key, + "error": 
issue.message, + "code": issue.code, + "details": dict(issue.details), + } + ) + + for result in worker_result.failed: + item_key = _result_item_key(result.request) + failed.append(item_key) + errors.append({"item": item_key, "error": result.build_error}) + + for issue in worker_result.worker_issues: + errors.append( + { + "item": "worker", + "error": issue.message, + "code": issue.code, + "details": dict(issue.details), + } + ) + + return { + "completed": completed, + "failed": failed, + "errors": errors, + "validation_errors": validation_errors, + "validation_rows": validation_rows, + "validation_summary": validation_summary, + } + + +def _validation_area_type(request: AreaBuildRequest) -> str: + if request.area_type == "state": + return "states" + if request.area_type == "district": + return "districts" + if request.area_type == "city": + return "cities" + return "national" + + +def _result_item_key(request: AreaBuildRequest) -> str: + return f"{request.area_type}:{request.area_id}" + + +def _fallback_request(item: Mapping[str, Any]) -> AreaBuildRequest: + area_type = item.get("type", "custom") + if area_type not in {"national", "state", "district", "city", "custom"}: + area_type = "custom" + area_id = str(item.get("id", "unknown")) + return AreaBuildRequest( + area_type=area_type, + area_id=area_id, + display_name=area_id, + output_relative_path=f"invalid/{area_id}.h5", + ) + + +def _state_fips_for_code(state_code: str, state_codes: Mapping[int, str]) -> int: + for fips, code in state_codes.items(): + if code == state_code: + return int(fips) + raise ValueError(f"Unknown state code: {state_code}") diff --git a/policyengine_us_data/calibration/local_h5/writer.py b/policyengine_us_data/calibration/local_h5/writer.py new file mode 100644 index 000000000..a52a40ad9 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/writer.py @@ -0,0 +1,74 @@ +"""H5 file writing and basic verification for local H5 publishing.""" + +from __future__ import annotations + 
from pathlib import Path

from .variables import H5Payload


class H5Writer:
    """Persist local-H5 payloads and provide lightweight output verification."""

    def write_payload(self, payload: H5Payload, output_path: str | Path) -> Path:
        """Write every (variable, period) array in *payload* under *output_path*.

        Layout: one HDF5 group per variable, one dataset per period
        (period keys are stringified). Parent directories are created as
        needed. Returns the resolved output path.
        """
        # h5py is imported lazily so importing this module stays cheap.
        import h5py

        target = Path(output_path)
        target.parent.mkdir(parents=True, exist_ok=True)

        with h5py.File(str(target), "w") as h5_file:
            for variable_name, period_arrays in payload.variables.items():
                variable_group = h5_file.create_group(variable_name)
                for period_key, array in period_arrays.items():
                    variable_group.create_dataset(str(period_key), data=array)

        return target

    def verify_output(
        self,
        output_path: str | Path,
        *,
        time_period: int | str,
    ) -> dict[str, int | float]:
        """Read back headline counts and weight sums for *time_period*.

        Missing variables/periods are simply omitted from the returned
        summary rather than raising.
        """
        import h5py

        period_key = str(time_period)
        summary: dict[str, int | float] = {}

        with h5py.File(str(Path(output_path)), "r") as h5_file:
            dataset = self._get_dataset(h5_file, "household_id", period_key)
            if dataset is not None:
                summary["household_count"] = int(len(dataset[:]))

            dataset = self._get_dataset(h5_file, "person_id", period_key)
            if dataset is not None:
                summary["person_count"] = int(len(dataset[:]))

            dataset = self._get_dataset(h5_file, "household_weight", period_key)
            if dataset is not None:
                summary["household_weight_sum"] = float(dataset[:].sum())

            dataset = self._get_dataset(h5_file, "person_weight", period_key)
            if dataset is not None:
                summary["person_weight_sum"] = float(dataset[:].sum())

        return summary

    def _get_dataset(self, h5_file, variable: str, period: str):
        """Return the dataset for (variable, period), or None if absent."""
        if variable not in h5_file:
            return None
        variable_group = h5_file[variable]
        if period not in variable_group:
            return None
        return variable_group[period]
b/policyengine_us_data/calibration/publish_local_area.py index 2a017668c..3648f29f9 100644 --- a/policyengine_us_data/calibration/publish_local_area.py +++ b/policyengine_us_data/calibration/publish_local_area.py @@ -8,7 +8,6 @@ python publish_local_area.py [--skip-download] [--states-only] [--upload] """ -import hashlib import json import shutil @@ -25,21 +24,27 @@ ) from policyengine_us_data.calibration.calibration_utils import ( STATE_CODES, - load_cd_geoadj_values, - calculate_spm_thresholds_vectorized, -) -from policyengine_us_data.calibration.block_assignment import ( - derive_geography_from_blocks, ) from policyengine_us_data.calibration.clone_and_assign import ( GeographyAssignment, assign_random_geography, ) -from policyengine_us_data.utils.takeup import ( - SIMPLE_TAKEUP_VARS, - apply_block_takeup_to_arrays, - reported_subsidized_marketplace_by_tax_unit, +from policyengine_us_data.calibration.local_h5.builder import ( + LocalAreaDatasetBuilder, +) +from policyengine_us_data.calibration.local_h5.contracts import AreaFilter +from policyengine_us_data.calibration.local_h5.source_dataset import ( + PolicyEngineDatasetReader, + SourceDatasetSnapshot, +) +from policyengine_us_data.calibration.local_h5.us_augmentations import ( + build_reported_takeup_anchors, ) +from policyengine_us_data.calibration.local_h5.weights import ( + infer_clone_count_from_weight_length, +) +from policyengine_us_data.calibration.local_h5.writer import H5Writer +from policyengine_us_data.utils.takeup import SIMPLE_TAKEUP_VARS CHECKPOINT_FILE = Path("completed_states.txt") CHECKPOINT_FILE_DISTRICTS = Path("completed_districts.txt") @@ -52,31 +57,88 @@ META_FILE = WORK_DIR / "checkpoint_meta.json" +def _build_selection_filters( + *, + cd_subset: List[str] | None = None, + county_fips_filter: set[str] | None = None, +) -> tuple[AreaFilter, ...]: + filters = [] + if cd_subset is not None: + filters.append( + AreaFilter( + geography_field="cd_geoid", + op="in", + value=tuple(str(cd) for cd 
in cd_subset), + ) + ) + if county_fips_filter is not None: + filters.append( + AreaFilter( + geography_field="county_fips", + op="in", + value=tuple(str(fips) for fips in sorted(county_fips_filter)), + ) + ) + return tuple(filters) + + def compute_input_fingerprint( - weights_path: Path, dataset_path: Path, n_clones: int, seed: int + weights_path: Path, + dataset_path: Path, + n_clones: int, + seed: int, + calibration_package_path: Path | None = None, ) -> str: - h = hashlib.sha256() - for p in [weights_path, dataset_path]: - with open(p, "rb") as f: - while chunk := f.read(8192): - h.update(chunk) - h.update(f"{n_clones}:{seed}".encode()) - return h.hexdigest()[:16] + if calibration_package_path is None: + import hashlib + + h = hashlib.sha256() + for p in [weights_path, dataset_path]: + with open(p, "rb") as f: + while chunk := f.read(8192): + h.update(chunk) + h.update(f"{n_clones}:{seed}".encode()) + return h.hexdigest()[:16] + + from policyengine_us_data.calibration.local_h5.fingerprinting import ( + FingerprintService, + ) + + service = FingerprintService() + record = service.create_publish_fingerprint( + weights_path=weights_path, + dataset_path=dataset_path, + calibration_package_path=calibration_package_path, + n_clones=n_clones, + seed=seed, + ) + return record.digest -def validate_or_clear_checkpoints(fingerprint: str): +def validate_or_clear_checkpoints(fingerprint): + from policyengine_us_data.calibration.local_h5.fingerprinting import ( + FingerprintRecord, + FingerprintService, + ) + + service = FingerprintService() + if isinstance(fingerprint, FingerprintRecord): + record = fingerprint + else: + record = service.legacy_record(str(fingerprint)) + if META_FILE.exists(): - stored = json.loads(META_FILE.read_text()) - if stored.get("fingerprint") == fingerprint: - print(f"Inputs unchanged ({fingerprint}), resuming...") + stored = service.read_record(META_FILE) + if service.matches(stored, record): + print(f"Inputs unchanged ({record.digest}), 
resuming...") return print( f"Inputs changed " - f"({stored.get('fingerprint')} -> {fingerprint}), " + f"({stored.digest} -> {record.digest}), " f"clearing..." ) else: - print(f"No checkpoint metadata, starting fresh ({fingerprint})") + print(f"No checkpoint metadata, starting fresh ({record.digest})") h5_count = sum( 1 for subdir in ["states", "districts", "cities"] @@ -108,7 +170,7 @@ def validate_or_clear_checkpoints(fingerprint: str): if d.exists(): shutil.rmtree(d) META_FILE.parent.mkdir(parents=True, exist_ok=True) - META_FILE.write_text(json.dumps({"fingerprint": fingerprint})) + service.write_record(META_FILE, record) SUB_ENTITIES = [ @@ -161,29 +223,7 @@ def record_completed_city(city_name: str): def _build_reported_takeup_anchors( data: dict, time_period: int ) -> dict[str, np.ndarray]: - reported_anchors = {} - if ( - "reported_has_subsidized_marketplace_health_coverage_at_interview" in data - and time_period - in data["reported_has_subsidized_marketplace_health_coverage_at_interview"] - ): - reported_anchors["takes_up_aca_if_eligible"] = ( - reported_subsidized_marketplace_by_tax_unit( - data["person_tax_unit_id"][time_period], - data["tax_unit_id"][time_period], - data[ - "reported_has_subsidized_marketplace_health_coverage_at_interview" - ][time_period], - ) - ) - if ( - "has_medicaid_health_coverage_at_interview" in data - and time_period in data["has_medicaid_health_coverage_at_interview"] - ): - reported_anchors["takes_up_medicaid_if_eligible"] = data[ - "has_medicaid_health_coverage_at_interview" - ][time_period].astype(bool) - return reported_anchors + return build_reported_takeup_anchors(data, time_period) def build_h5( @@ -194,6 +234,7 @@ def build_h5( cd_subset: List[str] = None, county_fips_filter: set = None, takeup_filter: List[str] = None, + source_snapshot: SourceDatasetSnapshot | None = None, ) -> Path: """Build an H5 file by cloning records for each nonzero weight. 
@@ -210,415 +251,73 @@ def build_h5( Returns: Path to the output H5 file. """ - import h5py - from collections import defaultdict - from policyengine_core.enums import Enum - from policyengine_us.variables.household.demographic.geographic.county.county_enum import ( - County, - ) - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - blocks = np.asarray(geography.block_geoid) - clone_cds = np.asarray(geography.cd_geoid, dtype=str) - # === Load base simulation === - sim = Microsimulation(dataset=str(dataset_path)) - time_period = int(sim.default_calculation_period) - household_ids = sim.calculate("household_id", map_to="household").values - n_hh = len(household_ids) - - if weights.shape[0] % n_hh != 0: + if source_snapshot is None: + source_snapshot = PolicyEngineDatasetReader(tuple(SUB_ENTITIES)).load( + dataset_path + ) + elif source_snapshot.dataset_path != Path(dataset_path): raise ValueError( - f"Weight vector length {weights.shape[0]} is not divisible by n_hh={n_hh}" + "source_snapshot.dataset_path does not match dataset_path " + f"({source_snapshot.dataset_path} != {Path(dataset_path)})" ) - n_clones_total = weights.shape[0] // n_hh - - # === Reshape and filter weight matrix === - W = weights.reshape(n_clones_total, n_hh).copy() - clone_cds_matrix = clone_cds.reshape(n_clones_total, n_hh) - # CD subset filtering: zero out cells whose CD isn't in subset - if cd_subset is not None: - cd_subset_set = set(cd_subset) - cd_mask = np.vectorize(lambda cd: cd in cd_subset_set)(clone_cds_matrix) - W[~cd_mask] = 0 + n_hh = source_snapshot.n_households + n_clones = infer_clone_count_from_weight_length(len(weights), n_hh) - # County FIPS filtering: zero out clones not in target counties - if county_fips_filter is not None: - fips_array = np.asarray(geography.county_fips).reshape(n_clones_total, n_hh) - fips_mask = np.isin(fips_array, list(county_fips_filter)) - W[~fips_mask] = 0 + builder = LocalAreaDatasetBuilder() + writer = 
H5Writer() label = ( f"CD subset {cd_subset}" if cd_subset is not None - else f"{n_clones_total} clone rows" + else f"{n_clones} clone rows" ) print(f"\n{'=' * 60}") print(f"Building {output_path.name} ({label}, {n_hh} households)") print(f"{'=' * 60}") - # === Identify active clones === - active_geo, active_hh = np.where(W > 0) - n_clones = len(active_geo) - if n_clones == 0: - raise ValueError( - f"No active clones after filtering. " - f"cd_subset={cd_subset}, county_fips_filter={county_fips_filter}" - ) - clone_weights = W[active_geo, active_hh] - active_blocks = blocks.reshape(n_clones_total, n_hh)[active_geo, active_hh] - active_clone_cds = clone_cds.reshape(n_clones_total, n_hh)[active_geo, active_hh] - - empty_count = np.sum(active_blocks == "") - if empty_count > 0: - raise ValueError(f"{empty_count} active clones have empty block GEOIDs") - - print(f"Active clones: {n_clones:,}") - print(f"Total weight: {clone_weights.sum():,.0f}") - - # === Build entity membership maps === - hh_id_to_idx = {int(hid): i for i, hid in enumerate(household_ids)} - person_hh_ids = sim.calculate("household_id", map_to="person").values - - hh_to_persons = defaultdict(list) - for p_idx, p_hh_id in enumerate(person_hh_ids): - hh_to_persons[hh_id_to_idx[int(p_hh_id)]].append(p_idx) - - hh_to_entity = {} - entity_id_arrays = {} - person_entity_id_arrays = {} - - for ek in SUB_ENTITIES: - eids = sim.calculate(f"{ek}_id", map_to=ek).values - peids = sim.calculate(f"person_{ek}_id", map_to="person").values - entity_id_arrays[ek] = eids - person_entity_id_arrays[ek] = peids - eid_to_idx = {int(eid): i for i, eid in enumerate(eids)} - - mapping = defaultdict(list) - seen = defaultdict(set) - for p_idx in range(len(person_hh_ids)): - hh_idx = hh_id_to_idx[int(person_hh_ids[p_idx])] - e_idx = eid_to_idx[int(peids[p_idx])] - if e_idx not in seen[hh_idx]: - seen[hh_idx].add(e_idx) - mapping[hh_idx].append(e_idx) - for hh_idx in mapping: - mapping[hh_idx].sort() - hh_to_entity[ek] = mapping - 
- # === Build clone index arrays === - hh_clone_idx = active_hh - - persons_per_clone = np.array([len(hh_to_persons.get(h, [])) for h in active_hh]) - person_parts = [ - np.array(hh_to_persons.get(h, []), dtype=np.int64) for h in active_hh - ] - person_clone_idx = ( - np.concatenate(person_parts) if person_parts else np.array([], dtype=np.int64) + built = builder.build( + weights=weights, + geography=geography, + source=source_snapshot, + filters=_build_selection_filters( + cd_subset=cd_subset, + county_fips_filter=county_fips_filter, + ), + takeup_filter=takeup_filter, ) - - entity_clone_idx = {} - entities_per_clone = {} - for ek in SUB_ENTITIES: - epc = np.array([len(hh_to_entity[ek].get(h, [])) for h in active_hh]) - entities_per_clone[ek] = epc - parts = [ - np.array(hh_to_entity[ek].get(h, []), dtype=np.int64) for h in active_hh - ] - entity_clone_idx[ek] = ( - np.concatenate(parts) if parts else np.array([], dtype=np.int64) - ) - - n_persons = len(person_clone_idx) - print(f"Cloned persons: {n_persons:,}") - for ek in SUB_ENTITIES: - print(f"Cloned {ek}s: {len(entity_clone_idx[ek]):,}") - - # === Build new entity IDs and cross-references === - new_hh_ids = np.arange(n_clones, dtype=np.int32) - new_person_ids = np.arange(n_persons, dtype=np.int32) - new_person_hh_ids = np.repeat(new_hh_ids, persons_per_clone) - - new_entity_ids = {} - new_person_entity_ids = {} - clone_ids_for_persons = np.repeat( - np.arange(n_clones, dtype=np.int64), persons_per_clone - ) - - for ek in SUB_ENTITIES: - n_ents = len(entity_clone_idx[ek]) - new_entity_ids[ek] = np.arange(n_ents, dtype=np.int32) - - old_eids = entity_id_arrays[ek][entity_clone_idx[ek]].astype(np.int64) - clone_ids_e = np.repeat( - np.arange(n_clones, dtype=np.int64), - entities_per_clone[ek], + selection = built.selection + reindexed = built.reindexed + print(f"Active clones: {selection.n_household_clones:,}") + print(f"Total weight: {selection.active_weights.sum():,.0f}") + print(f"Cloned persons: 
{len(reindexed.person_source_indices):,}") + for entity_key in SUB_ENTITIES: + print( + f"Cloned {entity_key}s: " + f"{len(reindexed.entity_source_indices[entity_key]):,}" ) + print(f"Variables cloned: {built.payload.dataset_count}") - offset = int(old_eids.max()) + 1 if len(old_eids) > 0 else 1 - entity_keys = clone_ids_e * offset + old_eids - - sorted_order = np.argsort(entity_keys) - sorted_keys = entity_keys[sorted_order] - sorted_new = new_entity_ids[ek][sorted_order] - - p_old_eids = person_entity_id_arrays[ek][person_clone_idx].astype(np.int64) - person_keys = clone_ids_for_persons * offset + p_old_eids - - positions = np.searchsorted(sorted_keys, person_keys) - positions = np.clip(positions, 0, len(sorted_keys) - 1) - new_person_entity_ids[ek] = sorted_new[positions] - - # === Derive geography from blocks (dedup optimization) === - print("Deriving geography from blocks...") - unique_blocks, block_inv = np.unique(active_blocks, return_inverse=True) - print(f" {n_clones:,} blocks -> {len(unique_blocks):,} unique") - unique_geo = derive_geography_from_blocks(unique_blocks) - clone_geo = {k: v[block_inv] for k, v in unique_geo.items()} - - # === Determine variables to save === - vars_to_save = set(sim.input_variables) - vars_to_save.add("county") - vars_to_save.add("spm_unit_spm_threshold") - vars_to_save.add("congressional_district_geoid") - for gv in [ - "block_geoid", - "tract_geoid", - "cbsa_code", - "sldu", - "sldl", - "place_fips", - "vtd", - "puma", - "zcta", - ]: - vars_to_save.add(gv) - - # === Clone variable arrays === - clone_idx_map = { - "household": hh_clone_idx, - "person": person_clone_idx, - } - for ek in SUB_ENTITIES: - clone_idx_map[ek] = entity_clone_idx[ek] - - data = {} - variables_saved = 0 - - for variable in sim.tax_benefit_system.variables: - if variable not in vars_to_save: - continue + output_path = writer.write_payload(built.payload, output_path) + summary = writer.verify_output(output_path, time_period=built.time_period) - holder = 
sim.get_holder(variable) - periods = holder.get_known_periods() - if not periods: - continue - - var_def = sim.tax_benefit_system.variables.get(variable) - entity_key = var_def.entity.key - if entity_key not in clone_idx_map: - continue - - cidx = clone_idx_map[entity_key] - var_data = {} - - for period in periods: - values = holder.get_array(period) - - if hasattr(values, "_pa_array") or hasattr(values, "_ndarray"): - values = np.asarray(values) - - if var_def.value_type in (Enum, str) and variable != "county_fips": - if hasattr(values, "decode_to_str"): - values = values.decode_to_str().astype("S") - else: - values = np.asarray(values).astype("S") - elif variable == "county_fips": - values = np.asarray(values).astype("int32") - else: - values = np.asarray(values) - - var_data[period] = values[cidx] - variables_saved += 1 - - if var_data: - data[variable] = var_data - - print(f"Variables cloned: {variables_saved}") - - # === Override entity IDs === - data["household_id"] = {time_period: new_hh_ids} - data["person_id"] = {time_period: new_person_ids} - data["person_household_id"] = { - time_period: new_person_hh_ids, - } - for ek in SUB_ENTITIES: - data[f"{ek}_id"] = { - time_period: new_entity_ids[ek], - } - data[f"person_{ek}_id"] = { - time_period: new_person_entity_ids[ek], - } - - # === Override weights === - # Only write household_weight; sub-entity weights (tax_unit_weight, - # spm_unit_weight, person_weight, etc.) are formula variables in - # policyengine-us that derive from household_weight at runtime. 
- data["household_weight"] = { - time_period: clone_weights.astype(np.float32), - } - - # === Override geography === - data["state_fips"] = { - time_period: clone_geo["state_fips"].astype(np.int32), - } - county_names = np.array( - [County._member_names_[i] for i in clone_geo["county_index"]] - ).astype("S") - data["county"] = {time_period: county_names} - data["county_fips"] = { - time_period: clone_geo["county_fips"].astype(np.int32), - } - for gv in [ - "block_geoid", - "tract_geoid", - "cbsa_code", - "sldu", - "sldl", - "place_fips", - "vtd", - "puma", - "zcta", - ]: - if gv in clone_geo: - data[gv] = { - time_period: clone_geo[gv].astype("S"), - } - - # === Set zip_code for LA County clones (ACA rating area fix) === - la_mask = clone_geo["county_fips"].astype(str) == "06037" - if la_mask.any(): - zip_codes = np.full(len(la_mask), "UNKNOWN") - zip_codes[la_mask] = "90001" - data["zip_code"] = {time_period: zip_codes.astype("S")} - - # === Congressional district GEOID === - clone_cd_geoids = np.array([int(cd) for cd in active_clone_cds], dtype=np.int32) - data["congressional_district_geoid"] = { - time_period: clone_cd_geoids, - } - - # === SPM threshold recalculation === - print("Recalculating SPM thresholds...") - unique_cds_list = sorted(set(active_clone_cds)) - cd_geoadj_values = load_cd_geoadj_values(unique_cds_list) - # Build per-SPM-unit geoadj from clone's CD - spm_clone_ids = np.repeat( - np.arange(n_clones, dtype=np.int64), - entities_per_clone["spm_unit"], - ) - spm_unit_geoadj = np.array( - [cd_geoadj_values[active_clone_cds[c]] for c in spm_clone_ids], - dtype=np.float64, - ) - - # Get cloned person ages and SPM tenure types - person_ages = sim.calculate("age", map_to="person").values[person_clone_idx] - - spm_tenure_holder = sim.get_holder("spm_unit_tenure_type") - spm_tenure_periods = spm_tenure_holder.get_known_periods() - if spm_tenure_periods: - raw_tenure = spm_tenure_holder.get_array(spm_tenure_periods[0]) - if hasattr(raw_tenure, 
"decode_to_str"): - raw_tenure = raw_tenure.decode_to_str().astype("S") - else: - raw_tenure = np.array(raw_tenure).astype("S") - spm_tenure_cloned = raw_tenure[entity_clone_idx["spm_unit"]] - else: - spm_tenure_cloned = np.full( - len(entity_clone_idx["spm_unit"]), - b"RENTER", - dtype="S30", + print(f"\nH5 saved to {output_path}") + if "household_count" in summary: + print(f"Verified: {summary['household_count']:,} households in output") + if "person_count" in summary: + print(f"Verified: {summary['person_count']:,} persons in output") + if "household_weight_sum" in summary: + print( + "Total population (HH weights): " + f"{summary['household_weight_sum']:,.0f}" ) - - new_spm_thresholds = calculate_spm_thresholds_vectorized( - person_ages=person_ages, - person_spm_unit_ids=new_person_entity_ids["spm_unit"], - spm_unit_tenure_types=spm_tenure_cloned, - spm_unit_geoadj=spm_unit_geoadj, - year=time_period, - ) - data["spm_unit_spm_threshold"] = { - time_period: new_spm_thresholds, - } - - # === Apply calibration takeup draws === - if blocks is not None: - print("Applying calibration takeup draws...") - entity_hh_indices = { - "person": np.repeat( - np.arange(n_clones, dtype=np.int64), - persons_per_clone, - ).astype(np.int64), - "tax_unit": np.repeat( - np.arange(n_clones, dtype=np.int64), - entities_per_clone["tax_unit"], - ).astype(np.int64), - "spm_unit": np.repeat( - np.arange(n_clones, dtype=np.int64), - entities_per_clone["spm_unit"], - ).astype(np.int64), - } - entity_counts = { - "person": n_persons, - "tax_unit": len(entity_clone_idx["tax_unit"]), - "spm_unit": len(entity_clone_idx["spm_unit"]), - } - hh_state_fips = clone_geo["state_fips"].astype(np.int32) - original_hh_ids = household_ids[active_hh].astype(np.int64) - reported_anchors = _build_reported_takeup_anchors(data, time_period) - - takeup_results = apply_block_takeup_to_arrays( - hh_blocks=active_blocks, - hh_state_fips=hh_state_fips, - hh_ids=original_hh_ids, - 
hh_clone_indices=active_geo.astype(np.int64), - entity_hh_indices=entity_hh_indices, - entity_counts=entity_counts, - time_period=time_period, - takeup_filter=takeup_filter, - reported_anchors=reported_anchors, + if "person_weight_sum" in summary: + print( + "Total population (person weights): " + f"{summary['person_weight_sum']:,.0f}" ) - for var_name, bools in takeup_results.items(): - data[var_name] = {time_period: bools} - - # === Write H5 === - with h5py.File(str(output_path), "w") as f: - for variable, periods in data.items(): - grp = f.create_group(variable) - for period, values in periods.items(): - grp.create_dataset(str(period), data=values) - - print(f"\nH5 saved to {output_path}") - - with h5py.File(str(output_path), "r") as f: - tp = str(time_period) - if "household_id" in f and tp in f["household_id"]: - n = len(f["household_id"][tp][:]) - print(f"Verified: {n:,} households in output") - if "person_id" in f and tp in f["person_id"]: - n = len(f["person_id"][tp][:]) - print(f"Verified: {n:,} persons in output") - if "household_weight" in f and tp in f["household_weight"]: - hw = f["household_weight"][tp][:] - print(f"Total population (HH weights): {hw.sum():,.0f}") - if "person_weight" in f and tp in f["person_weight"]: - pw = f["person_weight"][tp][:] - print(f"Total population (person weights): {pw.sum():,.0f}") return output_path @@ -881,6 +580,11 @@ def main(): action="store_true", help="Upload to GCP and HuggingFace (default: build locally only)", ) + parser.add_argument( + "--calibration-package-path", + type=str, + help="Optional calibration package path for exact geography reuse", + ) args = parser.parse_args() WORK_DIR.mkdir(parents=True, exist_ok=True) @@ -911,14 +615,11 @@ def main(): print(f"Using dataset: {inputs['dataset']}") - print("Computing input fingerprint...") - fingerprint = compute_input_fingerprint( - inputs["weights"], - inputs["dataset"], - args.n_clones, - args.seed, + calibration_package_path = ( + 
Path(args.calibration_package_path) + if args.calibration_package_path + else None ) - validate_or_clear_checkpoints(fingerprint) print("Loading base simulation to get household count...") _sim = Microsimulation(dataset=str(inputs["dataset"])) @@ -926,8 +627,68 @@ def main(): del _sim print(f"\nBase dataset has {n_hh:,} households") - geo_cache = WORK_DIR / f"geography_{n_hh}x{args.n_clones}_s{args.seed}.npz" - if geo_cache.exists(): + weights = np.load(inputs["weights"], mmap_mode="r") + canonical_n_clones = infer_clone_count_from_weight_length( + weights.shape[0], + n_hh, + ) + if canonical_n_clones != args.n_clones: + print( + f"WARNING: requested n_clones={args.n_clones} but " + f"weights imply {canonical_n_clones}; using weights-derived value" + ) + + print("Computing input fingerprint...") + if calibration_package_path is not None: + from policyengine_us_data.calibration.local_h5.fingerprinting import ( + FingerprintService, + ) + from policyengine_us_data.calibration.local_h5.package_geography import ( + require_calibration_package_path, + ) + + calibration_package_path = require_calibration_package_path( + calibration_package_path + ) + fingerprint_service = FingerprintService() + fingerprint_record = fingerprint_service.create_publish_fingerprint( + weights_path=inputs["weights"], + dataset_path=inputs["dataset"], + calibration_package_path=calibration_package_path, + n_clones=canonical_n_clones, + seed=args.seed, + ) + fingerprint = fingerprint_record.digest + validate_or_clear_checkpoints(fingerprint_record) + else: + fingerprint = compute_input_fingerprint( + inputs["weights"], + inputs["dataset"], + canonical_n_clones, + args.seed, + calibration_package_path=calibration_package_path, + ) + validate_or_clear_checkpoints(fingerprint) + + geo_cache = WORK_DIR / f"geography_{n_hh}x{canonical_n_clones}_s{args.seed}.npz" + if calibration_package_path is not None and calibration_package_path.exists(): + from 
policyengine_us_data.calibration.local_h5.package_geography import ( + CalibrationPackageGeographyLoader, + ) + + loader = CalibrationPackageGeographyLoader() + resolved = loader.resolve_for_weights( + package_path=calibration_package_path, + weights_length=weights.shape[0], + n_records=n_hh, + n_clones=canonical_n_clones, + seed=args.seed, + ) + geography = resolved.geography + print(f"Loaded geography from {resolved.source}") + for warning in resolved.warnings: + print(f"WARNING: {warning}") + elif geo_cache.exists(): print(f"Loading cached geography from {geo_cache}") npz = np.load(geo_cache, allow_pickle=True) geography = GeographyAssignment( @@ -936,16 +697,16 @@ def main(): county_fips=npz["county_fips"], state_fips=npz["state_fips"], n_records=n_hh, - n_clones=args.n_clones, + n_clones=canonical_n_clones, ) else: print( f"Generating geography: {n_hh} records x " - f"{args.n_clones} clones, seed={args.seed}" + f"{canonical_n_clones} clones, seed={args.seed}" ) geography = assign_random_geography( n_records=n_hh, - n_clones=args.n_clones, + n_clones=canonical_n_clones, seed=args.seed, ) np.savez_compressed( diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index e449cea4d..211447443 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -416,6 +416,7 @@ def save_calibration_package( targets_df: "pd.DataFrame", target_names: list, metadata: dict, + geography=None, initial_weights: np.ndarray = None, cd_geoid: np.ndarray = None, block_geoid: np.ndarray = None, @@ -428,18 +429,33 @@ def save_calibration_package( targets_df: Targets DataFrame. target_names: Target name list. metadata: Run metadata dict. + geography: Optional GeographyAssignment to serialize. initial_weights: Pre-computed initial weight array. cd_geoid: CD GEOID array from geography assignment. 
block_geoid: Block GEOID array from geography assignment. """ import pickle + serialized_geography = None + if geography is not None: + from policyengine_us_data.calibration.local_h5.package_geography import ( + CalibrationPackageGeographyLoader, + ) + + loader = CalibrationPackageGeographyLoader() + serialized_geography = loader.serialize_geography(geography) + if cd_geoid is None: + cd_geoid = geography.cd_geoid + if block_geoid is None: + block_geoid = geography.block_geoid + package = { "X_sparse": X_sparse, "targets_df": targets_df, "target_names": target_names, "metadata": metadata, "initial_weights": initial_weights, + "geography": serialized_geography, "cd_geoid": cd_geoid, "block_geoid": block_geoid, } @@ -462,6 +478,18 @@ def load_calibration_package(path: str) -> dict: with open(path, "rb") as f: package = pickle.load(f) + if package.get("geography") is None: + from policyengine_us_data.calibration.local_h5.package_geography import ( + CalibrationPackageGeographyLoader, + ) + + loader = CalibrationPackageGeographyLoader() + try: + loaded = loader.load_from_package_dict(package) + except ValueError: + loaded = None + if loaded is not None: + package["geography"] = loader.serialize_geography(loaded.geography) logger.info( "Loaded package: %d targets, %d records", package["X_sparse"].shape[0], @@ -1076,9 +1104,8 @@ def run_calibration( targets_df, target_names, metadata, + geography=geography, initial_weights=full_initial_weights, - cd_geoid=geography.cd_geoid, - block_geoid=geography.block_geoid, ) # Step 6c: Apply target config filtering (for fit or validation) diff --git a/policyengine_us_data/calibration/validate_staging.py b/policyengine_us_data/calibration/validate_staging.py index 1862fbbdc..b108d9d12 100644 --- a/policyengine_us_data/calibration/validate_staging.py +++ b/policyengine_us_data/calibration/validate_staging.py @@ -38,6 +38,9 @@ from policyengine_us_data.calibration.calibration_utils import ( STATE_CODES, ) +from 
policyengine_us_data.calibration.local_h5.validation import ( + validation_geo_level_for_area_type, +) from policyengine_us_data.calibration.sanity_checks import ( run_sanity_checks, ) @@ -325,7 +328,7 @@ def validate_area( training_arr = np.asarray(training_mask, dtype=bool) - geo_level = "state" if area_type == "states" else "district" + geo_level = validation_geo_level_for_area_type(area_type) results = [] for i, (idx, row) in enumerate(targets_df.iterrows()): @@ -1059,7 +1062,7 @@ def main(argv=None): all_results = [] for area_type in area_types: - geo_level = "state" if area_type == "states" else "district" + geo_level = validation_geo_level_for_area_type(area_type) geo_mask = (all_targets["geo_level"] == geo_level).values level_targets = all_targets[geo_mask].reset_index(drop=True) level_training = training_mask[geo_mask] diff --git a/tests/integration/test_build_h5.py b/tests/integration/test_build_h5.py index 339dec4e6..e2e435387 100644 --- a/tests/integration/test_build_h5.py +++ b/tests/integration/test_build_h5.py @@ -2,6 +2,7 @@ import os import tempfile +import h5py import numpy as np import pandas as pd import pytest @@ -9,11 +10,15 @@ from pathlib import Path from policyengine_us import Microsimulation from policyengine_us_data.calibration.publish_local_area import ( + SUB_ENTITIES, build_h5, ) from policyengine_us_data.calibration.clone_and_assign import ( GeographyAssignment, ) +from policyengine_us_data.calibration.local_h5.source_dataset import ( + PolicyEngineDatasetReader, +) FIXTURE_PATH = os.path.join(os.path.dirname(__file__), "test_fixture_50hh.h5") TEST_CDS = ["3701", "200"] # NC-01 and AK at-large @@ -78,6 +83,11 @@ def test_weights(n_households): return w +@pytest.fixture(scope="module") +def source_snapshot(): + return PolicyEngineDatasetReader(tuple(SUB_ENTITIES)).load(Path(FIXTURE_PATH)) + + @pytest.fixture(scope="module") def stacked_result(test_weights, n_households): """Run stacked dataset builder and return results.""" @@ -166,6 
+176,41 @@ def test_household_count_matches_weights(self, stacked_result, test_weights): expected_households = (test_weights > 0).sum() assert len(hh_df) == expected_households + def test_build_h5_accepts_preloaded_source_snapshot( + self, + test_weights, + n_households, + source_snapshot, + ): + """A worker-scoped source snapshot should produce a valid augmented H5.""" + geography = _make_geography(n_households, TEST_CDS) + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) / "test_snapshot_output.h5" + expected_households = int((test_weights > 0).sum()) + + build_h5( + weights=np.array(test_weights), + geography=geography, + dataset_path=Path(FIXTURE_PATH), + output_path=output_path, + cd_subset=TEST_CDS, + source_snapshot=source_snapshot, + ) + + with h5py.File(output_path, "r") as f: + assert "congressional_district_geoid" in f + assert "spm_unit_spm_threshold" in f + tp = str(source_snapshot.time_period) + hh_ids = f["household_id"][tp][:] + hh_weights = f["household_weight"][tp][:] + districts = f["congressional_district_geoid"][tp][:] + spm_thresholds = f["spm_unit_spm_threshold"][tp][:] + + assert len(hh_ids) == expected_households + assert len(hh_weights) == expected_households + assert len(districts) == expected_households + assert len(spm_thresholds) > 0 + @pytest.fixture(scope="module") def stacked_sim(test_weights, n_households): diff --git a/tests/integration/test_build_h5_minimal.py b/tests/integration/test_build_h5_minimal.py new file mode 100644 index 000000000..099b7cf46 --- /dev/null +++ b/tests/integration/test_build_h5_minimal.py @@ -0,0 +1,444 @@ +"""Minimal integration coverage for the build_h5 publishing seam.""" + +from __future__ import annotations + +import importlib +from pathlib import Path +import sys +import types +from types import SimpleNamespace + +import h5py +import numpy as np + + +FIXTURE_PATH = Path(__file__).with_name("test_fixture_50hh.h5") +SUB_ENTITIES = ("tax_unit", "spm_unit", "family", 
"marital_unit") +TEST_CDS = ("0200", "3701") +_CD_COUNTY = { + "0200": "02020", + "3701": "37183", +} + + +def _install_stub_packages(monkeypatch): + repo_root = Path(__file__).resolve().parents[2] + + for name in list(sys.modules): + if name == "policyengine_us" or name.startswith("policyengine_us."): + sys.modules.pop(name, None) + if name == "policyengine_us_data" or name.startswith("policyengine_us_data."): + sys.modules.pop(name, None) + + policyengine_us = types.ModuleType("policyengine_us") + policyengine_us.Microsimulation = object + monkeypatch.setitem(sys.modules, "policyengine_us", policyengine_us) + + variables_mod = types.ModuleType("policyengine_us.variables") + household_mod = types.ModuleType("policyengine_us.variables.household") + demographic_mod = types.ModuleType( + "policyengine_us.variables.household.demographic" + ) + geographic_mod = types.ModuleType( + "policyengine_us.variables.household.demographic.geographic" + ) + county_mod = types.ModuleType( + "policyengine_us.variables.household.demographic.geographic.county" + ) + county_enum_mod = types.ModuleType( + "policyengine_us.variables.household.demographic.geographic.county.county_enum" + ) + + class County: + _member_names_ = ["UNKNOWN", "ANCHORAGE_AK", "WAKE_NC"] + + county_enum_mod.County = County + + monkeypatch.setitem(sys.modules, "policyengine_us.variables", variables_mod) + monkeypatch.setitem( + sys.modules, + "policyengine_us.variables.household", + household_mod, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us.variables.household.demographic", + demographic_mod, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us.variables.household.demographic.geographic", + geographic_mod, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us.variables.household.demographic.geographic.county", + county_mod, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us.variables.household.demographic.geographic.county.county_enum", + county_enum_mod, + ) + + 
spm_calculator = types.ModuleType("spm_calculator") + spm_calculator.__path__ = [] + spm_calculator_geoadj = types.ModuleType("spm_calculator.geoadj") + + class FakeSPMCalculator: + def __init__(self, year): + self.year = year + + def get_base_thresholds(self): + return { + "owner_with_mortgage": 10000.0, + "owner_without_mortgage": 9000.0, + "renter": 8000.0, + } + + def spm_equivalence_scale(num_adults: int, num_children: int) -> float: + return 1.0 + 0.1 * num_adults + 0.05 * num_children + + spm_calculator.SPMCalculator = FakeSPMCalculator + spm_calculator.spm_equivalence_scale = spm_equivalence_scale + spm_calculator_geoadj.calculate_geoadj_from_rent = lambda rent: 1.0 + monkeypatch.setitem(sys.modules, "spm_calculator", spm_calculator) + monkeypatch.setitem( + sys.modules, + "spm_calculator.geoadj", + spm_calculator_geoadj, + ) + + package = types.ModuleType("policyengine_us_data") + package.__path__ = [str(repo_root / "policyengine_us_data")] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [ + str(repo_root / "policyengine_us_data" / "calibration") + ] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [ + str(repo_root / "policyengine_us_data" / "calibration" / "local_h5") + ] + utils_package = types.ModuleType("policyengine_us_data.utils") + utils_package.__path__ = [str(repo_root / "policyengine_us_data" / "utils")] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + monkeypatch.setitem(sys.modules, "policyengine_us_data.utils", utils_package) + + hf_module = types.ModuleType("policyengine_us_data.utils.huggingface") + hf_module.download_calibration_inputs = lambda *_a, **_k: None + monkeypatch.setitem( + 
sys.modules, + "policyengine_us_data.utils.huggingface", + hf_module, + ) + + upload_module = types.ModuleType("policyengine_us_data.utils.data_upload") + upload_module.upload_local_area_file = lambda *_a, **_k: None + upload_module.upload_local_area_batch_to_hf = lambda *_a, **_k: None + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.utils.data_upload", + upload_module, + ) + + calibration_utils = types.ModuleType( + "policyengine_us_data.calibration.calibration_utils" + ) + calibration_utils.STATE_CODES = {2: "AK", 37: "NC"} + calibration_utils.load_cd_geoadj_values = ( + lambda cds: {str(cd): 1.0 for cd in cds} + ) + + def calculate_spm_thresholds_vectorized( + *, + person_ages, + person_spm_unit_ids, + spm_unit_tenure_types, + spm_unit_geoadj, + year, + ): + n_units = len(spm_unit_geoadj) + thresholds = np.full(n_units, 8000.0, dtype=np.float32) + if n_units and len(person_spm_unit_ids): + adults = np.zeros(n_units, dtype=np.int32) + np.add.at(adults, person_spm_unit_ids, (person_ages >= 18).astype(np.int32)) + thresholds += adults.astype(np.float32) * 250.0 + return thresholds * spm_unit_geoadj.astype(np.float32) + + calibration_utils.calculate_spm_thresholds_vectorized = ( + calculate_spm_thresholds_vectorized + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.calibration_utils", + calibration_utils, + ) + + +def _load_runtime_modules(monkeypatch): + _install_stub_packages(monkeypatch) + importlib.invalidate_caches() + + publish_local_area = importlib.import_module( + "policyengine_us_data.calibration.publish_local_area" + ) + clone_and_assign = importlib.import_module( + "policyengine_us_data.calibration.clone_and_assign" + ) + entity_graph = importlib.import_module( + "policyengine_us_data.calibration.local_h5.entity_graph" + ) + source_dataset = importlib.import_module( + "policyengine_us_data.calibration.local_h5.source_dataset" + ) + builder_module = importlib.import_module( + 
"policyengine_us_data.calibration.local_h5.builder" + ) + us_augmentations = importlib.import_module( + "policyengine_us_data.calibration.local_h5.us_augmentations" + ) + + return ( + publish_local_area, + clone_and_assign, + entity_graph, + source_dataset, + builder_module, + us_augmentations, + sys.modules["policyengine_us_data.calibration.calibration_utils"], + ) + + +class FixtureVariableProvider: + def __init__(self, arrays, var_defs): + self.arrays = arrays + self.var_defs = var_defs + + def list_variables(self) -> tuple[str, ...]: + return tuple(sorted(self.var_defs)) + + def get_known_periods(self, variable: str) -> tuple[int | str, ...]: + periods = self.arrays.get(variable, {}) + return tuple(int(period) for period in sorted(periods)) + + def get_array(self, variable: str, period: int | str) -> np.ndarray: + return self.arrays[variable][str(period)] + + def get_variable_definition(self, variable: str): + return self.var_defs.get(variable) + + def calculate(self, variable: str, *, map_to: str | None = None): + if variable != "age" or map_to != "person": + raise KeyError(f"Unsupported calculate call: {variable}, {map_to}") + return SimpleNamespace(values=self.get_array("age", 2023)) + + +def _load_fixture_arrays(): + arrays: dict[str, dict[str, np.ndarray]] = {} + with h5py.File(FIXTURE_PATH, "r") as fixture: + for variable in fixture.keys(): + arrays[variable] = { + period: fixture[variable][period][:] + for period in fixture[variable].keys() + } + return arrays + + +def _make_fixture_snapshot(source_dataset_module, entity_graph_module): + arrays = _load_fixture_arrays() + extractor = entity_graph_module.EntityGraphExtractor(SUB_ENTITIES) + household_ids = arrays["household_id"]["2023"] + entity_graph = extractor.extract_from_arrays( + household_ids=household_ids, + person_household_ids=arrays["person_household_id"]["2023"], + entity_id_arrays={ + entity_key: arrays[f"{entity_key}_id"]["2023"] + for entity_key in SUB_ENTITIES + }, + 
person_entity_id_arrays={ + entity_key: arrays[f"person_{entity_key}_id"]["2023"] + for entity_key in SUB_ENTITIES + }, + ) + + def _var_def(entity_key: str, value_type): + return SimpleNamespace(entity=SimpleNamespace(key=entity_key), value_type=value_type) + + variable_provider = FixtureVariableProvider( + arrays=arrays, + var_defs={ + "age": _var_def("person", int), + "employment_income": _var_def("person", float), + "person_weight": _var_def("person", float), + "state_fips": _var_def("household", int), + "family_weight": _var_def("family", float), + "spm_unit_weight": _var_def("spm_unit", float), + "tax_unit_weight": _var_def("tax_unit", float), + "marital_unit_weight": _var_def("marital_unit", float), + }, + ) + + snapshot = source_dataset_module.SourceDatasetSnapshot( + dataset_path=FIXTURE_PATH, + time_period=2023, + household_ids=household_ids, + entity_graph=entity_graph, + input_variables=frozenset(variable_provider.list_variables()), + variable_provider=variable_provider, + ) + return snapshot, arrays + + +def _make_geography(geography_assignment_cls, n_households: int): + cd_geoid = np.repeat(np.asarray(TEST_CDS, dtype=str), n_households) + block_geoid = np.asarray( + [ + f"{_CD_COUNTY[cd]}{idx:06d}{idx:04d}"[:15] + for idx, cd in enumerate(cd_geoid) + ], + dtype="U15", + ) + county_fips = np.asarray([block[:5] for block in block_geoid], dtype="U5") + state_fips = np.asarray([int(block[:2]) for block in block_geoid], dtype=np.int32) + return geography_assignment_cls( + block_geoid=block_geoid, + cd_geoid=cd_geoid, + county_fips=county_fips, + state_fips=state_fips, + n_records=n_households, + n_clones=len(TEST_CDS), + ) + + +def _fake_geography_lookup(blocks: np.ndarray): + blocks = np.asarray(blocks).astype(str) + county_fips = np.asarray([int(block[:5]) for block in blocks], dtype=np.int32) + state_fips = np.asarray([int(block[:2]) for block in blocks], dtype=np.int32) + county_index = np.asarray( + [1 if int(block[:2]) == 2 else 2 for block in 
blocks], + dtype=np.int32, + ) + return { + "state_fips": state_fips, + "county_fips": county_fips, + "county_index": county_index, + "block_geoid": blocks.astype("S"), + "tract_geoid": np.asarray([block[:11] for block in blocks], dtype="S11"), + "cbsa_code": np.asarray(["99999"] * len(blocks), dtype="S5"), + "sldu": np.asarray(["000"] * len(blocks), dtype="S3"), + "sldl": np.asarray(["000"] * len(blocks), dtype="S3"), + "place_fips": np.asarray(["00000"] * len(blocks), dtype="S5"), + "vtd": np.asarray(["000000"] * len(blocks), dtype="S6"), + "puma": np.asarray(["00000"] * len(blocks), dtype="S5"), + "zcta": np.asarray(["00000"] * len(blocks), dtype="S5"), + } + + +def _fake_county_name_lookup(county_indices: np.ndarray) -> np.ndarray: + return np.asarray( + [f"COUNTY_{int(idx)}" for idx in np.asarray(county_indices)], + dtype="S16", + ) + + +def test_build_h5_writes_structural_output_from_real_fixture(monkeypatch, tmp_path): + ( + publish_local_area, + clone_and_assign, + entity_graph_module, + source_dataset_module, + builder_module, + us_augmentations, + calibration_utils, + ) = _load_runtime_modules(monkeypatch) + + snapshot, arrays = _make_fixture_snapshot( + source_dataset_module, + entity_graph_module, + ) + geography = _make_geography( + clone_and_assign.GeographyAssignment, + len(snapshot.household_ids), + ) + + weights = np.zeros(len(snapshot.household_ids) * len(TEST_CDS), dtype=float) + positive_households = [(0, 0, 1.25), (0, 1, 2.5), (1, 2, 1.75), (1, 3, 3.0)] + for clone_idx, household_idx, weight in positive_households: + weights[clone_idx * len(snapshot.household_ids) + household_idx] = weight + + expected_household_ids = snapshot.household_ids[[0, 1, 2, 3]] + expected_household_count = len(expected_household_ids) + expected_person_count = int( + np.isin(arrays["person_household_id"]["2023"], expected_household_ids).sum() + ) + + real_builder = builder_module.LocalAreaDatasetBuilder + augmentation_service = us_augmentations.USAugmentationService( + 
geography_lookup=_fake_geography_lookup, + county_name_lookup=_fake_county_name_lookup, + cd_geoadj_loader=lambda cds: {str(cd): 1.0 for cd in cds}, + threshold_calculator=calibration_utils.calculate_spm_thresholds_vectorized, + ) + monkeypatch.setattr( + publish_local_area, + "LocalAreaDatasetBuilder", + lambda: real_builder(us_augmentations=augmentation_service), + ) + + output_path = tmp_path / "minimal_build_h5_output.h5" + publish_local_area.build_h5( + weights=weights, + geography=geography, + dataset_path=FIXTURE_PATH, + output_path=output_path, + cd_subset=list(TEST_CDS), + takeup_filter=[], + source_snapshot=snapshot, + ) + + assert output_path.exists() + + with h5py.File(output_path, "r") as h5_file: + period = "2023" + + household_ids = h5_file["household_id"][period][:] + person_ids = h5_file["person_id"][period][:] + household_weights = h5_file["household_weight"][period][:] + districts = h5_file["congressional_district_geoid"][period][:] + state_fips = h5_file["state_fips"][period][:] + ages = h5_file["age"][period][:] + county = h5_file["county"][period][:] + spm_thresholds = h5_file["spm_unit_spm_threshold"][period][:] + snap_takeup = h5_file["takes_up_snap_if_eligible"][period][:] + + assert len(household_ids) == expected_household_count + assert len(person_ids) == expected_person_count + assert len(household_weights) == expected_household_count + assert len(districts) == expected_household_count + assert len(state_fips) == expected_household_count + assert len(ages) == expected_person_count + assert len(county) == expected_household_count + assert len(spm_thresholds) > 0 + assert len(snap_takeup) > 0 + + assert np.array_equal(np.unique(districts), np.asarray([200, 3701], dtype=np.int32)) + assert np.array_equal(np.unique(state_fips), np.asarray([2, 37], dtype=np.int32)) + assert np.all(household_weights > 0) + assert np.all(np.isfinite(spm_thresholds)) + assert set(np.unique(snap_takeup)).issubset({False, True}) + assert county.dtype.kind == "S" + 
assert np.array_equal( + np.sort(household_weights), + np.sort(np.asarray([weight for _, _, weight in positive_households], dtype=np.float32)), + ) diff --git a/tests/unit/calibration/test_local_h5_area_catalog.py b/tests/unit/calibration/test_local_h5_area_catalog.py new file mode 100644 index 000000000..648c2d280 --- /dev/null +++ b/tests/unit/calibration/test_local_h5_area_catalog.py @@ -0,0 +1,122 @@ +import importlib.util +from pathlib import Path +import sys +import types + + +def _module_path(*parts: str) -> Path: + return Path(__file__).resolve().parents[3].joinpath(*parts) + + +def _load_module(module_name: str, path: Path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _install_fake_package_hierarchy(monkeypatch): + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + + calibration_utils = types.ModuleType( + "policyengine_us_data.calibration.calibration_utils" + ) + calibration_utils.STATE_CODES = {1: "AL", 2: "AK", 36: "NY"} + calibration_utils.get_all_cds_from_database = lambda _db_uri: ["0101", "0200", "3607"] + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.calibration_utils", + calibration_utils, + ) + + contracts = _load_module( + 
"policyengine_us_data.calibration.local_h5.contracts", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "contracts.py", + ), + ) + area_catalog = _load_module( + "policyengine_us_data.calibration.local_h5.area_catalog", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "area_catalog.py", + ), + ) + return contracts, area_catalog + + +def test_us_area_catalog_constructs_weighted_regional_entries(monkeypatch): + _, area_catalog_module = _install_fake_package_hierarchy(monkeypatch) + USAreaCatalog = area_catalog_module.USAreaCatalog + + geography = types.SimpleNamespace( + county_fips=["36061", "01001", "36047"], + cd_geoid=["3607", "0101", "3607"], + ) + catalog = USAreaCatalog() + + entries = catalog.resolved_regional_entries( + "sqlite:////tmp/policy_data.db", + geography=geography, + ) + + state_entries = [e for e in entries if e.request.area_type == "state"] + district_entries = [e for e in entries if e.request.area_type == "district"] + city_entries = [e for e in entries if e.request.area_type == "city"] + + assert [e.request.area_id for e in district_entries] == ["AL-01", "AK-01", "NY-07"] + assert city_entries[0].request.output_relative_path == "cities/NYC.h5" + assert city_entries[0].request.validation_geographic_ids == ("3607",) + assert state_entries[0].request.filters[0].value == ("0101",) + assert state_entries[1].weight == 1 + assert city_entries[0].weight == 11 + + +def test_us_area_catalog_constructs_national_request(monkeypatch): + _, area_catalog_module = _install_fake_package_hierarchy(monkeypatch) + USAreaCatalog = area_catalog_module.USAreaCatalog + + entry = USAreaCatalog().resolved_national_entry() + + assert entry.request.area_type == "national" + assert entry.request.output_relative_path == "national/US.h5" + assert entry.request.validation_geo_level == "national" + assert entry.request.validation_geographic_ids == ("US",) + assert entry.weight == 1 + + +def 
test_us_area_catalog_skips_states_without_any_cds(monkeypatch): + _, area_catalog_module = _install_fake_package_hierarchy(monkeypatch) + USAreaCatalog = area_catalog_module.USAreaCatalog + + catalog = USAreaCatalog() + entries = catalog.regional_entries_from_cds(["0101", "3607"]) + + state_ids = [e.request.area_id for e in entries if e.request.area_type == "state"] + + assert state_ids == ["AL", "NY"] diff --git a/tests/unit/calibration/test_local_h5_build_h5_facade.py b/tests/unit/calibration/test_local_h5_build_h5_facade.py new file mode 100644 index 000000000..430116e76 --- /dev/null +++ b/tests/unit/calibration/test_local_h5_build_h5_facade.py @@ -0,0 +1,313 @@ +import importlib.util +from dataclasses import dataclass +from pathlib import Path +import sys +import types + +import numpy as np +import pytest + + +def _module_path(*parts: str) -> Path: + return Path(__file__).resolve().parents[3].joinpath(*parts) + + +def _load_module(module_name: str, path: Path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _load_publish_local_area_module(monkeypatch): + fake_policyengine_us = types.ModuleType("policyengine_us") + fake_policyengine_us.Microsimulation = object + monkeypatch.setitem(sys.modules, "policyengine_us", fake_policyengine_us) + + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + utils_package = types.ModuleType("policyengine_us_data.utils") + utils_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + 
"policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + monkeypatch.setitem(sys.modules, "policyengine_us_data.utils", utils_package) + + @dataclass(frozen=True) + class FakeAreaFilter: + geography_field: str + op: str + value: tuple[str, ...] + + contracts = types.ModuleType("policyengine_us_data.calibration.local_h5.contracts") + contracts.AreaFilter = FakeAreaFilter + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.contracts", + contracts, + ) + + source_dataset = types.ModuleType( + "policyengine_us_data.calibration.local_h5.source_dataset" + ) + + @dataclass(frozen=True) + class FakeSourceDatasetSnapshot: + dataset_path: Path + time_period: int + n_households: int + + class FakeReader: + instances = [] + snapshot = FakeSourceDatasetSnapshot( + dataset_path=Path("/tmp/source.h5"), + time_period=2024, + n_households=2, + ) + + def __init__(self, sub_entities): + self.sub_entities = tuple(sub_entities) + self.load_calls = [] + FakeReader.instances.append(self) + + def load(self, dataset_path): + self.load_calls.append(Path(dataset_path)) + return self.snapshot + + source_dataset.PolicyEngineDatasetReader = FakeReader + source_dataset.SourceDatasetSnapshot = FakeSourceDatasetSnapshot + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.source_dataset", + source_dataset, + ) + + builder_module = types.ModuleType( + "policyengine_us_data.calibration.local_h5.builder" + ) + writer_module = types.ModuleType( + "policyengine_us_data.calibration.local_h5.writer" + ) + + class FakeBuilder: + instances = [] + + def __init__(self): + self.calls = [] + FakeBuilder.instances.append(self) + + def build( + self, + *, + weights, + geography, + source, + filters=(), + takeup_filter=None, + ): + self.calls.append( + { + "weights": np.asarray(weights), + "geography": geography, + "source": source, + 
"filters": filters, + "takeup_filter": takeup_filter, + } + ) + selection = types.SimpleNamespace( + n_household_clones=2, + active_weights=np.asarray([1.0, 2.0], dtype=float), + ) + reindexed = types.SimpleNamespace( + person_source_indices=np.asarray([0, 1, 2], dtype=np.int64), + entity_source_indices={ + "tax_unit": np.asarray([0], dtype=np.int64), + "spm_unit": np.asarray([0], dtype=np.int64), + "family": np.asarray([0], dtype=np.int64), + "marital_unit": np.asarray([0], dtype=np.int64), + }, + ) + payload = types.SimpleNamespace(dataset_count=7) + return types.SimpleNamespace( + payload=payload, + selection=selection, + reindexed=reindexed, + time_period=2024, + ) + + class FakeWriter: + instances = [] + + def __init__(self): + self.write_calls = [] + self.verify_calls = [] + FakeWriter.instances.append(self) + + def write_payload(self, payload, output_path): + self.write_calls.append((payload, Path(output_path))) + return Path(output_path) + + def verify_output(self, output_path, *, time_period): + self.verify_calls.append((Path(output_path), time_period)) + return { + "household_count": 2, + "person_count": 3, + "household_weight_sum": 3.0, + } + + builder_module.LocalAreaDatasetBuilder = FakeBuilder + writer_module.H5Writer = FakeWriter + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.builder", + builder_module, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.writer", + writer_module, + ) + + us_augmentations_module = types.ModuleType( + "policyengine_us_data.calibration.local_h5.us_augmentations" + ) + us_augmentations_module.build_reported_takeup_anchors = ( + lambda data, time_period: {} + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.us_augmentations", + us_augmentations_module, + ) + + weights_module = types.ModuleType( + "policyengine_us_data.calibration.local_h5.weights" + ) + weights_module.infer_clone_count_from_weight_length = ( + lambda 
length, n_households: length // max(n_households, 1) + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.weights", + weights_module, + ) + + calibration_utils = types.ModuleType( + "policyengine_us_data.calibration.calibration_utils" + ) + calibration_utils.STATE_CODES = {1: "AL"} + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.calibration_utils", + calibration_utils, + ) + + clone_and_assign = types.ModuleType( + "policyengine_us_data.calibration.clone_and_assign" + ) + clone_and_assign.GeographyAssignment = object + clone_and_assign.assign_random_geography = lambda *_a, **_k: object() + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.clone_and_assign", + clone_and_assign, + ) + + hf_module = types.ModuleType("policyengine_us_data.utils.huggingface") + hf_module.download_calibration_inputs = lambda *_a, **_k: None + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.utils.huggingface", + hf_module, + ) + + upload_module = types.ModuleType("policyengine_us_data.utils.data_upload") + upload_module.upload_local_area_file = lambda *_a, **_k: None + upload_module.upload_local_area_batch_to_hf = lambda *_a, **_k: None + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.utils.data_upload", + upload_module, + ) + + takeup_module = types.ModuleType("policyengine_us_data.utils.takeup") + takeup_module.SIMPLE_TAKEUP_VARS = [] + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.utils.takeup", + takeup_module, + ) + + module = _load_module( + "publish_local_area_under_test", + _module_path( + "policyengine_us_data", + "calibration", + "publish_local_area.py", + ), + ) + return module, FakeReader, FakeBuilder, FakeWriter + + +def test_build_h5_delegates_to_builder_and_writer(monkeypatch, tmp_path): + publish_local_area, FakeReader, FakeBuilder, FakeWriter = ( + _load_publish_local_area_module(monkeypatch) + ) + + output_path = tmp_path / "states" / "AL.h5" + 
dataset_path = Path("/tmp/source.h5") + geography = types.SimpleNamespace(name="geography") + + result = publish_local_area.build_h5( + weights=np.asarray([1.0, 0.0, 0.0, 2.0], dtype=float), + geography=geography, + dataset_path=dataset_path, + output_path=output_path, + cd_subset=["0101"], + takeup_filter=["snap"], + ) + + assert result == output_path + assert FakeReader.instances[0].load_calls == [dataset_path] + builder_call = FakeBuilder.instances[0].calls[0] + assert builder_call["geography"] is geography + assert builder_call["source"] == FakeReader.snapshot + assert builder_call["takeup_filter"] == ["snap"] + assert len(builder_call["filters"]) == 1 + assert builder_call["filters"][0].geography_field == "cd_geoid" + assert builder_call["filters"][0].value == ("0101",) + assert FakeWriter.instances[0].write_calls[0][1] == output_path + assert FakeWriter.instances[0].verify_calls[0] == (output_path, 2024) + + +def test_build_h5_rejects_mismatched_source_snapshot(monkeypatch, tmp_path): + publish_local_area, _, _, _ = _load_publish_local_area_module(monkeypatch) + + with pytest.raises(ValueError, match="source_snapshot.dataset_path does not match"): + publish_local_area.build_h5( + weights=np.asarray([1.0], dtype=float), + geography=types.SimpleNamespace(), + dataset_path=tmp_path / "expected.h5", + output_path=tmp_path / "out.h5", + source_snapshot=types.SimpleNamespace( + dataset_path=tmp_path / "other.h5", + time_period=2024, + n_households=1, + ), + ) diff --git a/tests/unit/calibration/test_local_h5_builder.py b/tests/unit/calibration/test_local_h5_builder.py new file mode 100644 index 000000000..565cc8cda --- /dev/null +++ b/tests/unit/calibration/test_local_h5_builder.py @@ -0,0 +1,341 @@ +import importlib.util +from pathlib import Path +import sys +import types + +import numpy as np +import pytest + + +def _module_path(*parts: str) -> Path: + return Path(__file__).resolve().parents[3].joinpath(*parts) + + +def _load_module(module_name: str, path: Path): 
+ spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _install_fake_package_hierarchy(monkeypatch): + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + + contracts = _load_module( + "policyengine_us_data.calibration.local_h5.contracts", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "contracts.py", + ), + ) + weights = _load_module( + "policyengine_us_data.calibration.local_h5.weights", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "weights.py", + ), + ) + selection = _load_module( + "policyengine_us_data.calibration.local_h5.selection", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "selection.py", + ), + ) + entity_graph = _load_module( + "policyengine_us_data.calibration.local_h5.entity_graph", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "entity_graph.py", + ), + ) + source_dataset = _load_module( + "policyengine_us_data.calibration.local_h5.source_dataset", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "source_dataset.py", + ), + ) + reindexing = _load_module( + "policyengine_us_data.calibration.local_h5.reindexing", + _module_path( + "policyengine_us_data", + "calibration", + 
"local_h5", + "reindexing.py", + ), + ) + + fake_us_augmentations = types.ModuleType( + "policyengine_us_data.calibration.local_h5.us_augmentations" + ) + fake_us_augmentations.USAugmentationService = type( + "USAugmentationService", (), {} + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.us_augmentations", + fake_us_augmentations, + ) + + variables = _load_module( + "policyengine_us_data.calibration.local_h5.variables", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "variables.py", + ), + ) + builder = _load_module( + "policyengine_us_data.calibration.local_h5.builder", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "builder.py", + ), + ) + return contracts, selection, reindexing, variables, builder + + +def _make_selection(selection_module, *, active_blocks): + CloneSelection = selection_module.CloneSelection + count = len(active_blocks) + return CloneSelection( + active_clone_indices=np.arange(count, dtype=np.int64), + active_household_indices=np.arange(count, dtype=np.int64), + active_weights=np.asarray([1.5, 2.5][:count], dtype=float), + active_block_geoids=np.asarray(active_blocks, dtype=str), + active_cd_geoids=np.asarray(["0101", "0200"][:count], dtype=str), + active_county_fips=np.asarray(["01001", "02020"][:count], dtype=str), + active_state_fips=np.asarray([1, 2][:count], dtype=np.int64), + ) + + +def test_local_area_dataset_builder_builds_payload_and_injects_ids(monkeypatch): + contracts, selection_module, reindexing_module, variables, builder_module = ( + _install_fake_package_hierarchy(monkeypatch) + ) + + AreaFilter = contracts.AreaFilter + H5Payload = variables.H5Payload + ReindexedEntities = reindexing_module.ReindexedEntities + LocalAreaDatasetBuilder = builder_module.LocalAreaDatasetBuilder + LocalAreaBuildArtifacts = builder_module.LocalAreaBuildArtifacts + + selection = _make_selection(selection_module, active_blocks=["block-a", "block-b"]) + reindexed 
= ReindexedEntities( + household_source_indices=np.asarray([1, 0], dtype=np.int64), + person_source_indices=np.asarray([2, 1, 0], dtype=np.int64), + entity_source_indices={ + "tax_unit": np.asarray([1, 0], dtype=np.int64), + "spm_unit": np.asarray([0], dtype=np.int64), + }, + persons_per_clone=np.asarray([2, 1], dtype=np.int64), + entities_per_clone={ + "tax_unit": np.asarray([1, 1], dtype=np.int64), + "spm_unit": np.asarray([1, 0], dtype=np.int64), + }, + new_household_ids=np.asarray([10, 11], dtype=np.int32), + new_person_ids=np.asarray([20, 21, 22], dtype=np.int32), + new_person_household_ids=np.asarray([10, 10, 11], dtype=np.int32), + new_entity_ids={ + "tax_unit": np.asarray([30, 31], dtype=np.int32), + "spm_unit": np.asarray([40], dtype=np.int32), + }, + new_person_entity_ids={ + "tax_unit": np.asarray([30, 30, 31], dtype=np.int32), + "spm_unit": np.asarray([40, 40, 40], dtype=np.int32), + }, + ) + + class FakeSelector: + def __init__(self): + self.calls = [] + + def select(self, weights, geography, *, filters=()): + self.calls.append((weights, geography, filters)) + return selection + + class FakeReindexer: + def __init__(self): + self.calls = [] + + def reindex(self, source, selected): + self.calls.append((source, selected)) + return reindexed + + class FakeVariableCloner: + def __init__(self): + self.calls = [] + + def clone(self, source, reindexed_input, policy): + self.calls.append((source, reindexed_input, policy)) + return H5Payload( + variables={ + "source_income": {2024: np.asarray([100.0, 200.0])}, + } + ) + + class FakeAugmenter: + def __init__(self): + self.calls = [] + + def apply_all( + self, + data, + *, + time_period, + selection, + source, + reindexed, + takeup_filter, + ): + self.calls.append( + (time_period, selection, source, reindexed, tuple(takeup_filter or ())) + ) + data["augmented"] = {time_period: np.asarray([1, 1], dtype=np.int8)} + + selector = FakeSelector() + reindexer = FakeReindexer() + cloner = FakeVariableCloner() + augmenter 
= FakeAugmenter() + builder = LocalAreaDatasetBuilder( + selector=selector, + reindexer=reindexer, + variable_cloner=cloner, + us_augmentations=augmenter, + ) + + source = types.SimpleNamespace( + n_households=2, + time_period=2024, + ) + geography = types.SimpleNamespace() + filters = ( + AreaFilter( + geography_field="cd_geoid", + op="in", + value=("0101",), + ), + ) + + built = builder.build( + weights=np.asarray([1.0, 0.0, 0.0, 2.0], dtype=float), + geography=geography, + source=source, + filters=filters, + takeup_filter=("snap",), + ) + + assert isinstance(built, LocalAreaBuildArtifacts) + assert built.selection is selection + assert built.reindexed is reindexed + assert built.time_period == 2024 + np.testing.assert_array_equal( + built.payload.variables["household_id"][2024], + np.asarray([10, 11], dtype=np.int32), + ) + np.testing.assert_array_equal( + built.payload.variables["person_id"][2024], + np.asarray([20, 21, 22], dtype=np.int32), + ) + np.testing.assert_array_equal( + built.payload.variables["household_weight"][2024], + np.asarray([1.5, 2.5], dtype=np.float32), + ) + np.testing.assert_array_equal( + built.payload.variables["augmented"][2024], + np.asarray([1, 1], dtype=np.int8), + ) + assert selector.calls[0][2] == filters + assert reindexer.calls[0] == (source, selection) + assert cloner.calls[0][0] == source + assert cloner.calls[0][1] is reindexed + assert augmenter.calls[0][0] == 2024 + assert augmenter.calls[0][-1] == ("snap",) + + +def test_local_area_dataset_builder_rejects_empty_selection(monkeypatch): + _, selection_module, _, _, builder_module = _install_fake_package_hierarchy( + monkeypatch + ) + LocalAreaDatasetBuilder = builder_module.LocalAreaDatasetBuilder + + class FakeSelector: + def select(self, *_args, **_kwargs): + return _make_selection(selection_module, active_blocks=[]) + + builder = LocalAreaDatasetBuilder( + selector=FakeSelector(), + reindexer=types.SimpleNamespace(), + variable_cloner=types.SimpleNamespace(), + 
us_augmentations=types.SimpleNamespace(), + ) + + with pytest.raises(ValueError, match="No active clones after filtering"): + builder.build( + weights=np.asarray([0.0], dtype=float), + geography=types.SimpleNamespace(), + source=types.SimpleNamespace(n_households=1, time_period=2024), + filters=(), + takeup_filter=None, + ) + + +def test_local_area_dataset_builder_rejects_empty_block_geoids(monkeypatch): + _, selection_module, _, _, builder_module = _install_fake_package_hierarchy( + monkeypatch + ) + LocalAreaDatasetBuilder = builder_module.LocalAreaDatasetBuilder + + class FakeSelector: + def select(self, *_args, **_kwargs): + return _make_selection(selection_module, active_blocks=[""]) + + builder = LocalAreaDatasetBuilder( + selector=FakeSelector(), + reindexer=types.SimpleNamespace(), + variable_cloner=types.SimpleNamespace(), + us_augmentations=types.SimpleNamespace(), + ) + + with pytest.raises(ValueError, match="empty block GEOIDs"): + builder.build( + weights=np.asarray([1.0], dtype=float), + geography=types.SimpleNamespace(), + source=types.SimpleNamespace(n_households=1, time_period=2024), + filters=(), + takeup_filter=None, + ) diff --git a/tests/unit/calibration/test_local_h5_contracts.py b/tests/unit/calibration/test_local_h5_contracts.py new file mode 100644 index 000000000..504087b0d --- /dev/null +++ b/tests/unit/calibration/test_local_h5_contracts.py @@ -0,0 +1,247 @@ +import json +import importlib.util +from pathlib import Path +import sys + +import pytest + + +def _load_contracts_module(): + module_path = ( + Path(__file__).resolve().parents[3] + / "policyengine_us_data" + / "calibration" + / "local_h5" + / "contracts.py" + ) + spec = importlib.util.spec_from_file_location("local_h5_contracts", module_path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +contracts = _load_contracts_module() 
+AreaBuildRequest = contracts.AreaBuildRequest +AreaBuildResult = contracts.AreaBuildResult +AreaFilter = contracts.AreaFilter +PublishingInputBundle = contracts.PublishingInputBundle +ValidationIssue = contracts.ValidationIssue +ValidationPolicy = contracts.ValidationPolicy +ValidationResult = contracts.ValidationResult +WorkerResult = contracts.WorkerResult + + +def test_area_filter_validates_eq_vs_in_shape(): + AreaFilter(geography_field="state_fips", op="eq", value=6) + AreaFilter(geography_field="county_fips", op="in", value=("06037", "06059")) + + with pytest.raises(ValueError, match="must be a tuple"): + AreaFilter(geography_field="county_fips", op="in", value="06037") + + with pytest.raises(ValueError, match="must not be a tuple"): + AreaFilter(geography_field="state_fips", op="eq", value=(6, 12)) + + +def test_area_build_request_national_defaults(): + request = AreaBuildRequest.national() + + assert request.area_type == "national" + assert request.area_id == "US" + assert request.output_relative_path == "national/US.h5" + assert request.validation_geo_level == "national" + assert request.validation_geographic_ids == ("US",) + assert request.filters == () + + +def test_area_build_request_requires_validation_level_if_ids_provided(): + with pytest.raises(ValueError, match="validation_geo_level"): + AreaBuildRequest( + area_type="district", + area_id="CA-12", + display_name="CA-12", + output_relative_path="districts/CA-12.h5", + validation_geographic_ids=("612",), + ) + + +def test_publishing_input_bundle_required_paths_and_json_dict(): + bundle = PublishingInputBundle( + weights_path=Path("/tmp/weights.npy"), + source_dataset_path=Path("/tmp/source.h5"), + target_db_path=Path("/tmp/policy_data.db"), + calibration_package_path=Path("/tmp/calibration_package.pkl"), + run_config_path=Path("/tmp/config.json"), + run_id="1.0.0_abc", + version="1.0.0", + n_clones=430, + seed=42, + ) + + assert bundle.required_paths() == ( + Path("/tmp/weights.npy"), + 
Path("/tmp/source.h5"), + Path("/tmp/policy_data.db"), + Path("/tmp/calibration_package.pkl"), + ) + assert json.loads(json.dumps(bundle.to_dict()))["version"] == "1.0.0" + + +def test_validation_policy_defaults_are_conservative(): + policy = ValidationPolicy() + + assert policy.enabled is True + assert policy.fail_on_exception is False + assert policy.fail_on_validation_failure is False + assert policy.run_sanity_checks is True + assert policy.run_target_validation is True + assert policy.run_national_validation is True + + +def test_completed_area_build_result_requires_output_path_and_no_build_error(): + request = AreaBuildRequest.national() + + with pytest.raises(ValueError, match="requires output_path"): + AreaBuildResult( + request=request, + build_status="completed", + ) + + with pytest.raises(ValueError, match="must not include build_error"): + AreaBuildResult( + request=request, + build_status="completed", + output_path=Path("/tmp/US.h5"), + build_error="should not be here", + ) + + +def test_failed_area_build_result_requires_build_error(): + request = AreaBuildRequest.national() + + with pytest.raises(ValueError, match="requires build_error"): + AreaBuildResult( + request=request, + build_status="failed", + ) + + +def test_worker_result_enforces_completed_and_failed_buckets(): + request = AreaBuildRequest.national() + completed = AreaBuildResult( + request=request, + build_status="completed", + output_path=Path("/tmp/US.h5"), + ) + failed = AreaBuildResult( + request=request, + build_status="failed", + build_error="boom", + ) + + result = WorkerResult(completed=(completed,), failed=(failed,)) + assert result.all_results() == (completed, failed) + + with pytest.raises(ValueError, match="completed"): + WorkerResult(completed=(failed,), failed=()) + + with pytest.raises(ValueError, match="failed"): + WorkerResult(completed=(), failed=(completed,)) + + +def test_worker_result_and_validation_result_are_json_serializable(): + request = AreaBuildRequest( + 
area_type="state", + area_id="CA", + display_name="California", + output_relative_path="states/CA.h5", + filters=(AreaFilter(geography_field="state_fips", op="eq", value=6),), + validation_geo_level="state", + validation_geographic_ids=("6",), + metadata={"takeup_filter": "snap,ssi"}, + ) + validation = ValidationResult( + status="failed", + rows=({"target_name": "population", "rel_abs_error": 0.12},), + issues=( + ValidationIssue( + code="sanity_fail", + message="population exceeded ceiling", + severity="error", + details={"target_name": "population"}, + ), + ), + summary={"n_targets": 1, "n_fail": 1}, + ) + completed = AreaBuildResult( + request=request, + build_status="completed", + output_path=Path("/tmp/states/CA.h5"), + validation=validation, + ) + worker_result = WorkerResult( + completed=(completed,), + failed=(), + worker_issues=( + ValidationIssue( + code="partial_validation", + message="one validator retried", + severity="warning", + ), + ), + ) + + payload = worker_result.to_dict() + roundtrip = json.loads(json.dumps(payload)) + + assert roundtrip["completed"][0]["request"]["area_id"] == "CA" + assert roundtrip["completed"][0]["validation"]["status"] == "failed" + assert roundtrip["completed"][0]["output_path"] == "/tmp/states/CA.h5" + assert roundtrip["worker_issues"][0]["severity"] == "warning" + + +def test_contracts_round_trip_from_dict(): + request = AreaBuildRequest( + area_type="district", + area_id="CA-12", + display_name="CA-12", + output_relative_path="districts/CA-12.h5", + filters=( + AreaFilter(geography_field="cd_geoid", op="in", value=("612",)), + ), + validation_geo_level="district", + validation_geographic_ids=("612",), + metadata={"source": "catalog"}, + ) + validation = ValidationResult( + status="error", + issues=( + ValidationIssue( + code="validation_exception", + message="validator crashed", + severity="error", + details={"traceback": "boom"}, + ), + ), + summary={"n_targets": 0}, + ) + result = WorkerResult( + completed=( + 
AreaBuildResult( + request=request, + build_status="completed", + output_path=Path("/tmp/districts/CA-12.h5"), + validation=validation, + ), + ), + failed=(), + ) + + restored_request = AreaBuildRequest.from_dict(request.to_dict()) + restored_result = WorkerResult.from_dict(result.to_dict()) + + assert restored_request == request + assert restored_result.completed[0].request == request + assert restored_result.completed[0].validation.issues[0].code == "validation_exception" diff --git a/tests/unit/calibration/test_local_h5_fingerprinting.py b/tests/unit/calibration/test_local_h5_fingerprinting.py new file mode 100644 index 000000000..cc748f04a --- /dev/null +++ b/tests/unit/calibration/test_local_h5_fingerprinting.py @@ -0,0 +1,328 @@ +import importlib.util +import pickle +from dataclasses import dataclass +from pathlib import Path +import sys +import types + +import numpy as np + + +def _module_path(*parts: str) -> Path: + return Path(__file__).resolve().parents[3].joinpath(*parts) + + +def _load_module(module_name: str, path: Path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _install_fake_package_hierarchy(monkeypatch): + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + + clone_module = 
types.ModuleType("policyengine_us_data.calibration.clone_and_assign") + + @dataclass(frozen=True) + class FakeGeographyAssignment: + block_geoid: np.ndarray + cd_geoid: np.ndarray + county_fips: np.ndarray + state_fips: np.ndarray + n_records: int + n_clones: int + + def fake_assign_random_geography(*, n_records, n_clones, seed): + total = n_records * n_clones + return FakeGeographyAssignment( + block_geoid=np.asarray(["990010000000001"] * total, dtype=str), + cd_geoid=np.asarray(["9901"] * total, dtype=str), + county_fips=np.asarray(["99001"] * total, dtype=str), + state_fips=np.asarray([99] * total, dtype=np.int64), + n_records=n_records, + n_clones=n_clones, + ) + + clone_module.GeographyAssignment = FakeGeographyAssignment + clone_module.assign_random_geography = fake_assign_random_geography + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.clone_and_assign", + clone_module, + ) + + package_geography = _load_module( + "policyengine_us_data.calibration.local_h5.package_geography", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "package_geography.py", + ), + ) + fingerprinting = _load_module( + "policyengine_us_data.calibration.local_h5.fingerprinting", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "fingerprinting.py", + ), + ) + return package_geography, fingerprinting + + +def _write_bytes(path: Path, data: bytes) -> Path: + path.write_bytes(data) + return path + + +def _write_weights(path: Path, values: np.ndarray) -> Path: + np.save(path, values) + return path + + +def _write_package(path: Path, *, geography: dict, metadata: dict | None = None) -> Path: + with open(path, "wb") as f: + pickle.dump( + { + "geography": geography, + "metadata": metadata or {}, + "X_sparse": types.SimpleNamespace(shape=(1, 1)), + }, + f, + protocol=pickle.HIGHEST_PROTOCOL, + ) + return path + + +def _sample_geography(*, suffix: str) -> dict: + return { + "block_geoid": np.asarray( + 
[f"06001000100100{suffix}", f"36061000100100{suffix}"], + dtype=str, + ), + "cd_geoid": np.asarray(["601", "1208"], dtype=str), + "county_fips": np.asarray(["06001", "36061"], dtype=str), + "state_fips": np.asarray([6, 36], dtype=np.int64), + "n_records": 2, + "n_clones": 1, + } + + +def test_create_publish_fingerprint_round_trips_record(monkeypatch, tmp_path): + _, fingerprinting = _install_fake_package_hierarchy(monkeypatch) + FingerprintService = fingerprinting.FingerprintService + + weights_path = _write_weights( + tmp_path / "weights.npy", + np.asarray([1.0, 2.0], dtype=np.float64), + ) + dataset_path = _write_bytes(tmp_path / "dataset.h5", b"dataset-one") + package_path = _write_package( + tmp_path / "package.pkl", + geography=_sample_geography(suffix="1"), + metadata={"git_commit": "abc"}, + ) + + service = FingerprintService() + record = service.create_publish_fingerprint( + weights_path=weights_path, + dataset_path=dataset_path, + calibration_package_path=package_path, + n_clones=1, + seed=42, + ) + + assert len(record.digest) == 16 + assert record.components is not None + assert record.inputs["weights_path"] == str(weights_path) + + payload = service.serialize(record) + restored = service.deserialize(payload) + + assert restored.digest == record.digest + assert restored.components == record.components + assert service.matches(record, restored) + + +def test_write_and_read_record_round_trip(monkeypatch, tmp_path): + _, fingerprinting = _install_fake_package_hierarchy(monkeypatch) + FingerprintService = fingerprinting.FingerprintService + + weights_path = _write_weights( + tmp_path / "weights.npy", + np.asarray([1.0, 2.0], dtype=np.float64), + ) + dataset_path = _write_bytes(tmp_path / "dataset.h5", b"dataset-one") + package_path = _write_package( + tmp_path / "package.pkl", + geography=_sample_geography(suffix="1"), + ) + + service = FingerprintService() + record = service.create_publish_fingerprint( + weights_path=weights_path, + 
dataset_path=dataset_path, + calibration_package_path=package_path, + n_clones=1, + seed=42, + ) + + record_path = tmp_path / "fingerprint.json" + service.write_record(record_path, record) + restored = service.read_record(record_path) + + assert restored == record + + +def test_publish_fingerprint_changes_when_geography_changes(monkeypatch, tmp_path): + _, fingerprinting = _install_fake_package_hierarchy(monkeypatch) + FingerprintService = fingerprinting.FingerprintService + + weights_path = _write_weights( + tmp_path / "weights.npy", + np.asarray([1.0, 2.0], dtype=np.float64), + ) + dataset_path = _write_bytes(tmp_path / "dataset.h5", b"dataset-one") + package_a = _write_package( + tmp_path / "package-a.pkl", + geography=_sample_geography(suffix="1"), + ) + package_b = _write_package( + tmp_path / "package-b.pkl", + geography=_sample_geography(suffix="2"), + ) + + service = FingerprintService() + record_a = service.create_publish_fingerprint( + weights_path=weights_path, + dataset_path=dataset_path, + calibration_package_path=package_a, + n_clones=1, + seed=42, + ) + record_b = service.create_publish_fingerprint( + weights_path=weights_path, + dataset_path=dataset_path, + calibration_package_path=package_b, + n_clones=1, + seed=42, + ) + + assert record_a.components is not None + assert record_b.components is not None + assert record_a.components.geography_sha256 != record_b.components.geography_sha256 + assert record_a.digest != record_b.digest + + +def test_publish_fingerprint_ignores_non_geography_package_metadata( + monkeypatch, tmp_path +): + _, fingerprinting = _install_fake_package_hierarchy(monkeypatch) + FingerprintService = fingerprinting.FingerprintService + + weights_path = _write_weights( + tmp_path / "weights.npy", + np.asarray([1.0, 2.0], dtype=np.float64), + ) + dataset_path = _write_bytes(tmp_path / "dataset.h5", b"dataset-one") + geography = _sample_geography(suffix="1") + package_a = _write_package( + tmp_path / "package-a.pkl", + 
geography=geography, + metadata={"git_commit": "abc", "created_at": "one"}, + ) + package_b = _write_package( + tmp_path / "package-b.pkl", + geography=geography, + metadata={"git_commit": "def", "created_at": "two"}, + ) + + service = FingerprintService() + record_a = service.create_publish_fingerprint( + weights_path=weights_path, + dataset_path=dataset_path, + calibration_package_path=package_a, + n_clones=1, + seed=42, + ) + record_b = service.create_publish_fingerprint( + weights_path=weights_path, + dataset_path=dataset_path, + calibration_package_path=package_b, + n_clones=1, + seed=42, + ) + + assert record_a.components == record_b.components + assert record_a.digest == record_b.digest + + +def test_publish_fingerprint_rejects_incompatible_package_shape(monkeypatch, tmp_path): + _, fingerprinting = _install_fake_package_hierarchy(monkeypatch) + FingerprintService = fingerprinting.FingerprintService + + weights_path = _write_weights( + tmp_path / "weights.npy", + np.asarray([1.0, 2.0], dtype=np.float64), + ) + dataset_path = _write_bytes(tmp_path / "dataset.h5", b"dataset-one") + package_path = _write_package( + tmp_path / "package.pkl", + geography={ + "block_geoid": np.asarray(["060010001001001"] * 4, dtype=str), + "cd_geoid": np.asarray(["601"] * 4, dtype=str), + "county_fips": np.asarray(["06001"] * 4, dtype=str), + "state_fips": np.asarray([6] * 4, dtype=np.int64), + "n_records": 2, + "n_clones": 2, + }, + ) + + service = FingerprintService() + try: + service.create_publish_fingerprint( + weights_path=weights_path, + dataset_path=dataset_path, + calibration_package_path=package_path, + n_clones=1, + seed=42, + ) + except ValueError as error: + assert "incompatible with the requested publish shape" in str(error) + else: + raise AssertionError("Expected incompatible package shape to fail") + + +def test_deserialize_legacy_fingerprint_payload(monkeypatch): + _, fingerprinting = _install_fake_package_hierarchy(monkeypatch) + FingerprintService = 
fingerprinting.FingerprintService + + service = FingerprintService() + record = service.deserialize({"fingerprint": "deadbeefdeadbeef"}) + + assert record.schema_version == "legacy" + assert record.digest == "deadbeefdeadbeef" + assert record.components is None diff --git a/tests/unit/calibration/test_local_h5_package_geography.py b/tests/unit/calibration/test_local_h5_package_geography.py new file mode 100644 index 000000000..90447ffff --- /dev/null +++ b/tests/unit/calibration/test_local_h5_package_geography.py @@ -0,0 +1,231 @@ +import importlib.util +import pickle +from dataclasses import dataclass +from pathlib import Path +import sys +import types + +import numpy as np + + +def _load_package_geography_module(): + module_path = ( + Path(__file__).resolve().parents[3] + / "policyengine_us_data" + / "calibration" + / "local_h5" + / "package_geography.py" + ) + spec = importlib.util.spec_from_file_location( + "local_h5_package_geography", + module_path, + ) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def _install_fake_clone_and_assign(monkeypatch): + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + clone_module = types.ModuleType("policyengine_us_data.calibration.clone_and_assign") + + @dataclass(frozen=True) + class FakeGeographyAssignment: + block_geoid: np.ndarray + cd_geoid: np.ndarray + county_fips: np.ndarray + state_fips: np.ndarray + n_records: int + n_clones: int + + def fake_assign_random_geography(*, n_records, n_clones, seed): + total = n_records * n_clones + block_geoid = np.asarray(["990010000000001"] * total, dtype=str) + cd_geoid = np.asarray(["9901"] * total, dtype=str) + county_fips = np.asarray(["99001"] * total, dtype=str) + state_fips = np.asarray([99] * 
total, dtype=np.int64) + return FakeGeographyAssignment( + block_geoid=block_geoid, + cd_geoid=cd_geoid, + county_fips=county_fips, + state_fips=state_fips, + n_records=n_records, + n_clones=n_clones, + ) + + clone_module.GeographyAssignment = FakeGeographyAssignment + clone_module.assign_random_geography = fake_assign_random_geography + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.clone_and_assign", + clone_module, + ) + + return FakeGeographyAssignment + + +package_geography = _load_package_geography_module() +CalibrationPackageGeographyLoader = package_geography.CalibrationPackageGeographyLoader +require_calibration_package_path = package_geography.require_calibration_package_path + + +def test_serialize_and_load_serialized_package_geography(monkeypatch): + FakeGeographyAssignment = _install_fake_clone_and_assign(monkeypatch) + loader = CalibrationPackageGeographyLoader() + geography = FakeGeographyAssignment( + block_geoid=np.asarray(["060010001001001", "060010001001002"], dtype=str), + cd_geoid=np.asarray(["601", "601"], dtype=str), + county_fips=np.asarray(["06001", "06001"], dtype=str), + state_fips=np.asarray([6, 6], dtype=np.int64), + n_records=2, + n_clones=1, + ) + + payload = loader.serialize_geography(geography) + loaded = loader.load_from_package_dict({"geography": payload}) + + assert loaded is not None + assert loaded.source == "serialized_package" + np.testing.assert_array_equal(loaded.geography.block_geoid, geography.block_geoid) + np.testing.assert_array_equal(loaded.geography.cd_geoid, geography.cd_geoid) + np.testing.assert_array_equal(loaded.geography.county_fips, geography.county_fips) + np.testing.assert_array_equal(loaded.geography.state_fips, geography.state_fips) + assert loaded.geography.n_records == 2 + assert loaded.geography.n_clones == 1 
+ + +def test_load_legacy_package_geography_derives_county_and_state(monkeypatch): + _install_fake_clone_and_assign(monkeypatch) + loader = CalibrationPackageGeographyLoader() + + loaded = loader.load_from_package_dict( + { + "block_geoid": np.asarray( + [ + "060010001001001", + "060010001001002", + "360610001001001", + "360610001001002", + ], + dtype=str, + ), + "cd_geoid": np.asarray(["601", "601", "1208", "1208"], dtype=str), + "metadata": {"base_n_records": 2, "n_clones": 2}, + } + ) + + assert loaded is not None + assert loaded.source == "legacy_package" + np.testing.assert_array_equal( + loaded.geography.county_fips, + np.asarray(["06001", "06001", "36061", "36061"], dtype=str), + ) + np.testing.assert_array_equal( + loaded.geography.state_fips, + np.asarray([6, 6, 36, 36], dtype=np.int64), + ) + assert loaded.geography.n_records == 2 + assert loaded.geography.n_clones == 2 + assert loaded.warnings + + +def test_resolve_for_weights_falls_back_when_package_geography_length_mismatches( + monkeypatch, tmp_path +): + _install_fake_clone_and_assign(monkeypatch) + loader = CalibrationPackageGeographyLoader() + + package_path = tmp_path / "package.pkl" + with open(package_path, "wb") as f: + pickle.dump( + { + "geography": { + "block_geoid": np.asarray(["060010001001001"] * 4, dtype=str), + "cd_geoid": np.asarray(["601"] * 4, dtype=str), + "county_fips": np.asarray(["06001"] * 4, dtype=str), + "state_fips": np.asarray([6] * 4, dtype=np.int64), + "n_records": 2, + "n_clones": 2, + } + }, + f, + protocol=pickle.HIGHEST_PROTOCOL, + ) + + resolved = loader.resolve_for_weights( + package_path=package_path, + weights_length=6, + n_records=2, + n_clones=3, + seed=42, + ) + + assert resolved.source == "generated" + assert resolved.geography.n_records == 2 + assert resolved.geography.n_clones == 3 + assert any( + "incompatible with the requested publish shape" in warning + for warning in resolved.warnings + ) + + +def 
test_resolve_for_weights_strict_raises_when_package_geography_mismatches( + monkeypatch, tmp_path +): + _install_fake_clone_and_assign(monkeypatch) + loader = CalibrationPackageGeographyLoader() + + package_path = tmp_path / "package.pkl" + with open(package_path, "wb") as f: + pickle.dump( + { + "geography": { + "block_geoid": np.asarray(["060010001001001"] * 4, dtype=str), + "cd_geoid": np.asarray(["601"] * 4, dtype=str), + "county_fips": np.asarray(["06001"] * 4, dtype=str), + "state_fips": np.asarray([6] * 4, dtype=np.int64), + "n_records": 2, + "n_clones": 2, + } + }, + f, + protocol=pickle.HIGHEST_PROTOCOL, + ) + + try: + loader.resolve_for_weights( + package_path=package_path, + weights_length=6, + n_records=2, + n_clones=3, + seed=42, + allow_seed_fallback=False, + ) + except ValueError as error: + assert "incompatible with the requested publish shape" in str(error) + else: + raise AssertionError("Expected strict geography resolution to fail") + + +def test_require_calibration_package_path_raises_for_missing_file(tmp_path): + missing = tmp_path / "missing.pkl" + + try: + require_calibration_package_path(missing) + except FileNotFoundError as error: + assert "Required calibration package not found" in str(error) + else: + raise AssertionError("Expected FileNotFoundError for missing package") diff --git a/tests/unit/calibration/test_local_h5_partitioning.py b/tests/unit/calibration/test_local_h5_partitioning.py new file mode 100644 index 000000000..0f4cf62ad --- /dev/null +++ b/tests/unit/calibration/test_local_h5_partitioning.py @@ -0,0 +1,95 @@ +import importlib.util +from pathlib import Path +import sys + + +def _load_partitioning_module(): + module_path = ( + Path(__file__).resolve().parents[3] + / "policyengine_us_data" + / "calibration" + / "local_h5" + / "partitioning.py" + ) + spec = importlib.util.spec_from_file_location("local_h5_partitioning", module_path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert 
spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +partitioning = _load_partitioning_module() +partition_weighted_work_items = partitioning.partition_weighted_work_items +work_item_key = partitioning.work_item_key + + +def _flatten(chunks): + return [item for chunk in chunks for item in chunk] + + +def test_work_item_key_uses_existing_completion_shape(): + item = {"type": "district", "id": "CA-12", "weight": 1} + assert work_item_key(item) == "district:CA-12" + + +def test_partition_filters_completed_items(): + work_items = [ + {"type": "state", "id": "CA", "weight": 3}, + {"type": "district", "id": "CA-12", "weight": 1}, + {"type": "city", "id": "NYC", "weight": 2}, + ] + + chunks = partition_weighted_work_items( + work_items, + num_workers=2, + completed={"district:CA-12"}, + ) + + flattened = _flatten(chunks) + assert all(item["id"] != "CA-12" for item in flattened) + assert {item["id"] for item in flattened} == {"CA", "NYC"} + + +def test_partition_returns_empty_for_zero_workers_or_zero_remaining(): + work_items = [{"type": "state", "id": "CA", "weight": 1}] + + assert partition_weighted_work_items(work_items, num_workers=0) == [] + assert ( + partition_weighted_work_items( + work_items, + num_workers=3, + completed={"state:CA"}, + ) + == [] + ) + + +def test_partition_uses_no_more_workers_than_remaining_items(): + work_items = [ + {"type": "state", "id": "CA", "weight": 5}, + {"type": "state", "id": "NY", "weight": 4}, + ] + + chunks = partition_weighted_work_items(work_items, num_workers=10) + + assert len(chunks) == 2 + assert all(len(chunk) == 1 for chunk in chunks) + + +def test_partition_is_weight_balancing_and_deterministic_for_equal_weights(): + work_items = [ + {"type": "district", "id": "A", "weight": 5}, + {"type": "district", "id": "B", "weight": 5}, + {"type": "district", "id": "C", "weight": 2}, + {"type": "district", "id": "D", "weight": 2}, + ] + + chunks = 
partition_weighted_work_items(work_items, num_workers=2) + + ids_by_chunk = [[item["id"] for item in chunk] for chunk in chunks] + loads = [sum(item["weight"] for item in chunk) for chunk in chunks] + + assert ids_by_chunk == [["A", "C"], ["B", "D"]] + assert loads == [7, 7] diff --git a/tests/unit/calibration/test_local_h5_reindexing.py b/tests/unit/calibration/test_local_h5_reindexing.py new file mode 100644 index 000000000..de953a437 --- /dev/null +++ b/tests/unit/calibration/test_local_h5_reindexing.py @@ -0,0 +1,237 @@ +import importlib.util +from pathlib import Path +import sys +import types + +import numpy as np + + +def _module_path(*parts: str) -> Path: + return Path(__file__).resolve().parents[3].joinpath(*parts) + + +def _load_module(module_name: str, path: Path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _install_fake_package_hierarchy(monkeypatch): + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + + contracts = _load_module( + "policyengine_us_data.calibration.local_h5.contracts", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "contracts.py", + ), + ) + weights = _load_module( + "policyengine_us_data.calibration.local_h5.weights", + _module_path( + "policyengine_us_data", 
+ "calibration", + "local_h5", + "weights.py", + ), + ) + selection = _load_module( + "policyengine_us_data.calibration.local_h5.selection", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "selection.py", + ), + ) + entity_graph = _load_module( + "policyengine_us_data.calibration.local_h5.entity_graph", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "entity_graph.py", + ), + ) + source_dataset = _load_module( + "policyengine_us_data.calibration.local_h5.source_dataset", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "source_dataset.py", + ), + ) + reindexing = _load_module( + "policyengine_us_data.calibration.local_h5.reindexing", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "reindexing.py", + ), + ) + return contracts, selection, entity_graph, source_dataset, reindexing + + +def _make_snapshot(source_dataset_module, entity_graph_module): + EntityGraph = entity_graph_module.EntityGraph + SourceDatasetSnapshot = source_dataset_module.SourceDatasetSnapshot + + graph = EntityGraph( + household_ids=np.asarray([10, 20]), + person_household_ids=np.asarray([10, 10, 20, 20, 20]), + hh_id_to_index={10: 0, 20: 1}, + hh_to_persons={0: (0, 1), 1: (2, 3, 4)}, + entity_id_arrays={ + "tax_unit": np.asarray([100, 200, 300]), + "spm_unit": np.asarray([900, 901]), + }, + person_entity_id_arrays={ + "tax_unit": np.asarray([100, 100, 200, 300, 300]), + "spm_unit": np.asarray([900, 900, 901, 901, 901]), + }, + hh_to_entity={ + "tax_unit": {0: (0,), 1: (1, 2)}, + "spm_unit": {0: (0,), 1: (1,)}, + }, + ) + return SourceDatasetSnapshot( + dataset_path=Path("/tmp/source.h5"), + time_period=2024, + household_ids=np.asarray([10, 20]), + entity_graph=graph, + input_variables=frozenset({"household_id"}), + variable_provider=types.SimpleNamespace(), + ) + + +def _make_selection(selection_module, household_indices): + CloneSelection = selection_module.CloneSelection + 
household_indices = np.asarray(household_indices, dtype=np.int64) + n = len(household_indices) + return CloneSelection( + active_clone_indices=np.arange(n, dtype=np.int64), + active_household_indices=household_indices, + active_weights=np.ones(n, dtype=float), + active_block_geoids=np.asarray([f"block-{i}" for i in range(n)], dtype=str), + active_cd_geoids=np.asarray([f"cd-{i}" for i in range(n)], dtype=str), + active_county_fips=np.asarray([f"county-{i}" for i in range(n)], dtype=str), + active_state_fips=np.asarray([i for i in range(n)], dtype=np.int64), + ) + + +def test_entity_reindexer_assigns_unique_output_ids(monkeypatch): + _, selection_module, entity_graph_module, source_dataset_module, reindexing = ( + _install_fake_package_hierarchy(monkeypatch) + ) + EntityReindexer = reindexing.EntityReindexer + + snapshot = _make_snapshot(source_dataset_module, entity_graph_module) + selection = _make_selection(selection_module, [0, 1]) + result = EntityReindexer().reindex(snapshot, selection) + + np.testing.assert_array_equal(result.new_household_ids, np.asarray([0, 1])) + np.testing.assert_array_equal(result.new_person_ids, np.asarray([0, 1, 2, 3, 4])) + np.testing.assert_array_equal( + result.new_entity_ids["tax_unit"], + np.asarray([0, 1, 2]), + ) + np.testing.assert_array_equal( + result.new_entity_ids["spm_unit"], + np.asarray([0, 1]), + ) + + +def test_entity_reindexer_maps_people_and_entities_for_repeated_households( + monkeypatch, +): + _, selection_module, entity_graph_module, source_dataset_module, reindexing = ( + _install_fake_package_hierarchy(monkeypatch) + ) + EntityReindexer = reindexing.EntityReindexer + + snapshot = _make_snapshot(source_dataset_module, entity_graph_module) + selection = _make_selection(selection_module, [0, 0, 1]) + result = EntityReindexer().reindex(snapshot, selection) + + np.testing.assert_array_equal( + result.persons_per_clone, + np.asarray([2, 2, 3]), + ) + np.testing.assert_array_equal( + result.new_person_household_ids, + 
np.asarray([0, 0, 1, 1, 2, 2, 2]), + ) + np.testing.assert_array_equal( + result.new_person_entity_ids["tax_unit"], + np.asarray([0, 0, 1, 1, 2, 3, 3]), + ) + np.testing.assert_array_equal( + result.new_person_entity_ids["spm_unit"], + np.asarray([0, 0, 1, 1, 2, 2, 2]), + ) + + +def test_entity_reindexer_handles_multiple_entities_within_one_household(monkeypatch): + _, selection_module, entity_graph_module, source_dataset_module, reindexing = ( + _install_fake_package_hierarchy(monkeypatch) + ) + EntityReindexer = reindexing.EntityReindexer + + snapshot = _make_snapshot(source_dataset_module, entity_graph_module) + selection = _make_selection(selection_module, [1]) + result = EntityReindexer().reindex(snapshot, selection) + + np.testing.assert_array_equal( + result.entities_per_clone["tax_unit"], + np.asarray([2]), + ) + np.testing.assert_array_equal( + result.entity_source_indices["tax_unit"], + np.asarray([1, 2]), + ) + np.testing.assert_array_equal( + result.new_person_entity_ids["tax_unit"], + np.asarray([0, 1, 1]), + ) + + +def test_entity_reindexer_handles_empty_selection(monkeypatch): + _, selection_module, entity_graph_module, source_dataset_module, reindexing = ( + _install_fake_package_hierarchy(monkeypatch) + ) + EntityReindexer = reindexing.EntityReindexer + + snapshot = _make_snapshot(source_dataset_module, entity_graph_module) + selection = _make_selection(selection_module, []) + result = EntityReindexer().reindex(snapshot, selection) + + assert result.new_household_ids.size == 0 + assert result.new_person_ids.size == 0 + assert result.new_person_household_ids.size == 0 + assert result.entity_source_indices["tax_unit"].size == 0 + assert result.new_person_entity_ids["tax_unit"].size == 0 diff --git a/tests/unit/calibration/test_local_h5_selection.py b/tests/unit/calibration/test_local_h5_selection.py new file mode 100644 index 000000000..63669072b --- /dev/null +++ b/tests/unit/calibration/test_local_h5_selection.py @@ -0,0 +1,258 @@ +import 
importlib.util +from dataclasses import dataclass +from pathlib import Path +import sys +import types + +import numpy as np +import pytest + + +def _module_path(*parts: str) -> Path: + return Path(__file__).resolve().parents[3].joinpath(*parts) + + +def _load_module(module_name: str, path: Path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _install_fake_package_hierarchy(monkeypatch): + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + + contracts = _load_module( + "policyengine_us_data.calibration.local_h5.contracts", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "contracts.py", + ), + ) + weights = _load_module( + "policyengine_us_data.calibration.local_h5.weights", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "weights.py", + ), + ) + selection = _load_module( + "policyengine_us_data.calibration.local_h5.selection", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "selection.py", + ), + ) + return contracts, weights, selection + + +@dataclass(frozen=True) +class FakeGeography: + block_geoid: np.ndarray + cd_geoid: np.ndarray + county_fips: np.ndarray + state_fips: np.ndarray + n_records: int + n_clones: int + + +def 
_sample_geography() -> FakeGeography: + return FakeGeography( + block_geoid=np.asarray( + [ + "060010001001001", + "360610001001001", + "060130001001001", + "360810001001001", + "060010001001002", + "360610001001002", + "060130001001002", + "360810001001002", + ], + dtype=str, + ), + cd_geoid=np.asarray( + ["601", "1208", "605", "1214", "601", "1208", "605", "1214"], + dtype=str, + ), + county_fips=np.asarray( + ["06001", "36061", "06013", "36081", "06001", "36061", "06013", "36081"], + dtype=str, + ), + state_fips=np.asarray([6, 36, 6, 36, 6, 36, 6, 36], dtype=np.int64), + n_records=4, + n_clones=2, + ) + + +def test_clone_weight_matrix_validates_shape(monkeypatch): + _, weights_module, _ = _install_fake_package_hierarchy(monkeypatch) + CloneWeightMatrix = weights_module.CloneWeightMatrix + + matrix = CloneWeightMatrix.from_vector(np.arange(8, dtype=float), n_records=4) + + assert matrix.n_clones == 2 + assert matrix.shape == (2, 4) + np.testing.assert_array_equal( + matrix.as_matrix(), + np.asarray([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=float), + ) + + +def test_clone_weight_matrix_rejects_invalid_shapes(monkeypatch): + _, weights_module, _ = _install_fake_package_hierarchy(monkeypatch) + CloneWeightMatrix = weights_module.CloneWeightMatrix + + with pytest.raises(ValueError, match="n_records must be positive"): + CloneWeightMatrix.from_vector(np.arange(4, dtype=float), n_records=0) + + with pytest.raises(ValueError, match="not divisible"): + CloneWeightMatrix.from_vector(np.arange(7, dtype=float), n_records=4) + + +def test_infer_clone_count_from_weight_length(monkeypatch): + _, weights_module, _ = _install_fake_package_hierarchy(monkeypatch) + infer_clone_count_from_weight_length = ( + weights_module.infer_clone_count_from_weight_length + ) + + assert infer_clone_count_from_weight_length(12, 3) == 4 + + with pytest.raises(ValueError, match="n_records must be positive"): + infer_clone_count_from_weight_length(12, 0) + + with pytest.raises(ValueError, 
match="weight_length must be positive"): + infer_clone_count_from_weight_length(0, 3) + + with pytest.raises(ValueError, match="not divisible"): + infer_clone_count_from_weight_length(13, 3) + + +def test_area_selector_supports_national_state_district_and_city(monkeypatch): + contracts, weights_module, selection_module = _install_fake_package_hierarchy( + monkeypatch + ) + AreaFilter = contracts.AreaFilter + CloneWeightMatrix = weights_module.CloneWeightMatrix + AreaSelector = selection_module.AreaSelector + + weights = CloneWeightMatrix.from_vector( + np.asarray([1.0, 0.0, 2.0, 0.0, 0.5, 1.5, 0.0, 3.0]), + n_records=4, + ) + geography = _sample_geography() + selector = AreaSelector() + + national = selector.select(weights, geography) + state = selector.select( + weights, + geography, + filters=(AreaFilter("state_fips", "eq", 6),), + ) + district = selector.select( + weights, + geography, + filters=(AreaFilter("cd_geoid", "eq", "1208"),), + ) + city = selector.select( + weights, + geography, + filters=(AreaFilter("county_fips", "in", ("36061", "36081")),), + ) + + np.testing.assert_array_equal( + national.active_clone_indices, + np.asarray([0, 0, 1, 1, 1]), + ) + np.testing.assert_array_equal( + national.active_household_indices, + np.asarray([0, 2, 0, 1, 3]), + ) + np.testing.assert_array_equal(state.active_weights, np.asarray([1.0, 2.0, 0.5])) + np.testing.assert_array_equal(state.active_state_fips, np.asarray([6, 6, 6])) + np.testing.assert_array_equal(district.active_weights, np.asarray([1.5])) + np.testing.assert_array_equal(district.active_cd_geoids, np.asarray(["1208"])) + np.testing.assert_array_equal(city.active_weights, np.asarray([1.5, 3.0])) + np.testing.assert_array_equal( + city.active_county_fips, + np.asarray(["36061", "36081"]), + ) + + +def test_area_selector_returns_empty_selection(monkeypatch): + contracts, weights_module, selection_module = _install_fake_package_hierarchy( + monkeypatch + ) + AreaFilter = contracts.AreaFilter + 
CloneWeightMatrix = weights_module.CloneWeightMatrix + AreaSelector = selection_module.AreaSelector + + weights = CloneWeightMatrix.from_vector( + np.asarray([1.0, 0.0, 2.0, 0.0, 0.5, 1.5, 0.0, 3.0]), + n_records=4, + ) + selector = AreaSelector() + selection = selector.select( + weights, + _sample_geography(), + filters=(AreaFilter("county_fips", "eq", "99999"),), + ) + + assert selection.is_empty + assert selection.n_household_clones == 0 + assert selection.active_weights.size == 0 + + +def test_area_selector_is_deterministic(monkeypatch): + contracts, weights_module, selection_module = _install_fake_package_hierarchy( + monkeypatch + ) + AreaFilter = contracts.AreaFilter + CloneWeightMatrix = weights_module.CloneWeightMatrix + AreaSelector = selection_module.AreaSelector + + weights = CloneWeightMatrix.from_vector( + np.asarray([1.0, 0.0, 2.0, 0.0, 0.5, 1.5, 0.0, 3.0]), + n_records=4, + ) + geography = _sample_geography() + selector = AreaSelector() + filters = (AreaFilter("state_fips", "eq", 36),) + + first = selector.select(weights, geography, filters=filters) + second = selector.select(weights, geography, filters=filters) + + np.testing.assert_array_equal( + first.active_clone_indices, + second.active_clone_indices, + ) + np.testing.assert_array_equal( + first.active_household_indices, second.active_household_indices + ) + np.testing.assert_array_equal(first.active_weights, second.active_weights) diff --git a/tests/unit/calibration/test_local_h5_source_dataset.py b/tests/unit/calibration/test_local_h5_source_dataset.py new file mode 100644 index 000000000..201400966 --- /dev/null +++ b/tests/unit/calibration/test_local_h5_source_dataset.py @@ -0,0 +1,163 @@ +import importlib.util +from pathlib import Path +import sys +import types + +import numpy as np + + +def _module_path(*parts: str) -> Path: + return Path(__file__).resolve().parents[3].joinpath(*parts) + + +def _load_module(module_name: str, path: Path): + spec = 
importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _install_fake_package_hierarchy(monkeypatch): + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + + entity_graph = _load_module( + "policyengine_us_data.calibration.local_h5.entity_graph", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "entity_graph.py", + ), + ) + source_dataset = _load_module( + "policyengine_us_data.calibration.local_h5.source_dataset", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "source_dataset.py", + ), + ) + return entity_graph, source_dataset + + +def test_entity_graph_extractor_builds_household_relationships(monkeypatch): + entity_graph_module, _ = _install_fake_package_hierarchy(monkeypatch) + EntityGraphExtractor = entity_graph_module.EntityGraphExtractor + + extractor = EntityGraphExtractor(("tax_unit", "spm_unit")) + graph = extractor.extract_from_arrays( + household_ids=np.asarray([10, 20]), + person_household_ids=np.asarray([10, 10, 20, 20, 20]), + entity_id_arrays={ + "tax_unit": np.asarray([100, 200, 300]), + "spm_unit": np.asarray([900, 901]), + }, + person_entity_id_arrays={ + "tax_unit": np.asarray([100, 100, 200, 300, 300]), + "spm_unit": np.asarray([900, 900, 901, 901, 901]), + }, + ) + 
+ assert graph.hh_id_to_index == {10: 0, 20: 1} + assert graph.hh_to_persons == {0: (0, 1), 1: (2, 3, 4)} + assert graph.hh_to_entity["tax_unit"] == {0: (0,), 1: (1, 2)} + assert graph.hh_to_entity["spm_unit"] == {0: (0,), 1: (1,)} + + +def test_policy_engine_dataset_reader_builds_snapshot_without_eager_holder_access( + monkeypatch, tmp_path +): + _, source_dataset_module = _install_fake_package_hierarchy(monkeypatch) + PolicyEngineDatasetReader = source_dataset_module.PolicyEngineDatasetReader + + class FakeHolder: + def __init__(self, values): + self.values = np.asarray(values) + + def get_known_periods(self): + return (2024,) + + def get_array(self, period): + assert period == 2024 + return self.values + + class FakeVariableDef: + def __init__(self, entity_key): + self.entity = types.SimpleNamespace(key=entity_key) + + class FakeVariables(dict): + def keys(self): + return super().keys() + + class FakeMicrosimulation: + instances = [] + + def __init__(self, dataset): + self.dataset = dataset + self.default_calculation_period = 2024 + self.input_variables = {"household_id", "tax_unit_id"} + self.tax_benefit_system = types.SimpleNamespace( + variables=FakeVariables( + { + "sample_var": FakeVariableDef("household"), + } + ) + ) + self.get_holder_calls = 0 + FakeMicrosimulation.instances.append(self) + + def calculate(self, variable, map_to=None): + lookup = { + ("household_id", "household"): np.asarray([10, 20]), + ("household_id", "person"): np.asarray([10, 10, 20]), + ("tax_unit_id", "tax_unit"): np.asarray([100, 200]), + ("person_tax_unit_id", "person"): np.asarray([100, 100, 200]), + } + return types.SimpleNamespace(values=lookup[(variable, map_to)]) + + def get_holder(self, variable): + self.get_holder_calls += 1 + assert variable == "sample_var" + return FakeHolder([1, 2]) + + fake_policyengine_us = types.ModuleType("policyengine_us") + fake_policyengine_us.Microsimulation = FakeMicrosimulation + monkeypatch.setitem(sys.modules, "policyengine_us", 
fake_policyengine_us) + + reader = PolicyEngineDatasetReader(("tax_unit",)) + snapshot = reader.load(tmp_path / "source.h5") + + assert snapshot.dataset_path == tmp_path / "source.h5" + assert snapshot.time_period == 2024 + assert snapshot.n_households == 2 + assert snapshot.input_variables == frozenset({"household_id", "tax_unit_id"}) + assert snapshot.entity_graph.hh_to_persons == {0: (0, 1), 1: (2,)} + + fake_sim = FakeMicrosimulation.instances[-1] + assert fake_sim.get_holder_calls == 0 + + periods = snapshot.variable_provider.get_known_periods("sample_var") + + assert periods == (2024,) + assert fake_sim.get_holder_calls == 1 diff --git a/tests/unit/calibration/test_local_h5_unified_package_io.py b/tests/unit/calibration/test_local_h5_unified_package_io.py new file mode 100644 index 000000000..55a3a6d5a --- /dev/null +++ b/tests/unit/calibration/test_local_h5_unified_package_io.py @@ -0,0 +1,132 @@ +import importlib.util +from dataclasses import dataclass +from pathlib import Path +import sys +import types + +import numpy as np +import pandas as pd + + +def _module_path(*parts: str) -> Path: + return Path(__file__).resolve().parents[3].joinpath(*parts) + + +def _load_module(module_name: str, path: Path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _install_fake_package_hierarchy(monkeypatch): + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + 
calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + + clone_module = types.ModuleType("policyengine_us_data.calibration.clone_and_assign") + + @dataclass(frozen=True) + class FakeGeographyAssignment: + block_geoid: np.ndarray + cd_geoid: np.ndarray + county_fips: np.ndarray + state_fips: np.ndarray + n_records: int + n_clones: int + + clone_module.GeographyAssignment = FakeGeographyAssignment + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.clone_and_assign", + clone_module, + ) + + _load_module( + "policyengine_us_data.calibration.local_h5.package_geography", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "package_geography.py", + ), + ) + unified_calibration = _load_module( + "policyengine_us_data.calibration.unified_calibration", + _module_path( + "policyengine_us_data", + "calibration", + "unified_calibration.py", + ), + ) + return FakeGeographyAssignment, unified_calibration + + +def test_save_and_load_calibration_package_round_trips_serialized_geography( + monkeypatch, tmp_path +): + FakeGeographyAssignment, unified_calibration = _install_fake_package_hierarchy( + monkeypatch + ) + + geography = FakeGeographyAssignment( + block_geoid=np.asarray(["060010001001001", "360610001001001"], dtype=str), + cd_geoid=np.asarray(["601", "1208"], dtype=str), + county_fips=np.asarray(["06001", "36061"], dtype=str), + state_fips=np.asarray([6, 36], dtype=np.int64), + n_records=2, + n_clones=1, + ) + package_path = tmp_path / "calibration_package.pkl" + + unified_calibration.save_calibration_package( + path=str(package_path), + X_sparse=np.zeros((1, 1), dtype=np.float64), + targets_df=pd.DataFrame({"variable": ["household_count"], "value": [1.0]}), + target_names=["household_count"], + metadata={"created_at": "2026-04-10T00:00:00Z"}, + geography=geography, + initial_weights=np.asarray([1.0], dtype=np.float64), + ) + + loaded = 
unified_calibration.load_calibration_package(str(package_path)) + + assert loaded["geography"] is not None + np.testing.assert_array_equal( + loaded["geography"]["block_geoid"], + geography.block_geoid, + ) + np.testing.assert_array_equal( + loaded["geography"]["cd_geoid"], + geography.cd_geoid, + ) + np.testing.assert_array_equal( + loaded["geography"]["county_fips"], + geography.county_fips, + ) + np.testing.assert_array_equal( + loaded["geography"]["state_fips"], + geography.state_fips, + ) + assert loaded["geography"]["n_records"] == geography.n_records + assert loaded["geography"]["n_clones"] == geography.n_clones + np.testing.assert_array_equal(loaded["cd_geoid"], geography.cd_geoid) + np.testing.assert_array_equal(loaded["block_geoid"], geography.block_geoid) diff --git a/tests/unit/calibration/test_local_h5_us_augmentations.py b/tests/unit/calibration/test_local_h5_us_augmentations.py new file mode 100644 index 000000000..44308ee99 --- /dev/null +++ b/tests/unit/calibration/test_local_h5_us_augmentations.py @@ -0,0 +1,405 @@ +import importlib.util +from pathlib import Path +import sys +import types + +import numpy as np + + +def _module_path(*parts: str) -> Path: + return Path(__file__).resolve().parents[3].joinpath(*parts) + + +def _load_module(module_name: str, path: Path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _install_fake_package_hierarchy(monkeypatch): + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + utils_package = types.ModuleType("policyengine_us_data.utils") + 
utils_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + monkeypatch.setitem(sys.modules, "policyengine_us_data.utils", utils_package) + + block_assignment = types.ModuleType( + "policyengine_us_data.calibration.block_assignment" + ) + block_assignment.derive_geography_from_blocks = lambda blocks: {} + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.block_assignment", + block_assignment, + ) + + calibration_utils = types.ModuleType( + "policyengine_us_data.calibration.calibration_utils" + ) + calibration_utils.calculate_spm_thresholds_vectorized = ( + lambda **kwargs: np.asarray([], dtype=np.float64) + ) + calibration_utils.load_cd_geoadj_values = lambda cds: {} + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.calibration_utils", + calibration_utils, + ) + + takeup = types.ModuleType("policyengine_us_data.utils.takeup") + takeup.apply_block_takeup_to_arrays = lambda **kwargs: {} + takeup.reported_subsidized_marketplace_by_tax_unit = ( + lambda person_tax_unit_ids, tax_unit_ids, reported_mask: np.asarray( + [ + bool( + reported_mask[person_tax_unit_ids == tax_unit_id].any() + ) + for tax_unit_id in tax_unit_ids + ], + dtype=bool, + ) + ) + monkeypatch.setitem(sys.modules, "policyengine_us_data.utils.takeup", takeup) + + reindexing = types.ModuleType( + "policyengine_us_data.calibration.local_h5.reindexing" + ) + reindexing.ReindexedEntities = type("ReindexedEntities", (), {}) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.reindexing", + reindexing, + ) + + selection = types.ModuleType("policyengine_us_data.calibration.local_h5.selection") + selection.CloneSelection = type("CloneSelection", (), {}) + monkeypatch.setitem( + 
sys.modules, + "policyengine_us_data.calibration.local_h5.selection", + selection, + ) + + source_dataset = types.ModuleType( + "policyengine_us_data.calibration.local_h5.source_dataset" + ) + source_dataset.SourceDatasetSnapshot = type("SourceDatasetSnapshot", (), {}) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.source_dataset", + source_dataset, + ) + + return _load_module( + "policyengine_us_data.calibration.local_h5.us_augmentations", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "us_augmentations.py", + ), + ) + + +def test_us_augmentation_service_applies_geography_outputs(monkeypatch): + module = _install_fake_package_hierarchy(monkeypatch) + USAugmentationService = module.USAugmentationService + + def fake_geography_lookup(unique_blocks): + np.testing.assert_array_equal(unique_blocks, np.asarray(["b1", "b2"])) + return { + "state_fips": np.asarray([2, 1]), + "county_index": np.asarray([1, 0]), + "county_fips": np.asarray(["02001", "01001"]), + "tract_geoid": np.asarray(["t2", "t1"]), + } + + service = USAugmentationService( + geography_lookup=fake_geography_lookup, + county_name_lookup=lambda idx: np.asarray( + [f"COUNTY_{value}" for value in idx], dtype="S" + ), + ) + + data = {} + clone_geo = service.apply_geography( + data, + time_period=2024, + active_blocks=np.asarray(["b2", "b1", "b2"]), + active_clone_cds=np.asarray(["0101", "0201", "0101"]), + ) + + np.testing.assert_array_equal( + clone_geo["county_fips"], + np.asarray(["01001", "02001", "01001"]), + ) + np.testing.assert_array_equal( + data["state_fips"][2024], + np.asarray([1, 2, 1], dtype=np.int32), + ) + np.testing.assert_array_equal( + data["county"][2024], + np.asarray([b"COUNTY_0", b"COUNTY_1", b"COUNTY_0"]), + ) + np.testing.assert_array_equal( + data["tract_geoid"][2024], + np.asarray([b"t1", b"t2", b"t1"]), + ) + np.testing.assert_array_equal( + data["congressional_district_geoid"][2024], + np.asarray([101, 201, 101], 
dtype=np.int32), + ) + + +def test_us_augmentation_service_applies_la_zip_patch(monkeypatch): + module = _install_fake_package_hierarchy(monkeypatch) + service = module.USAugmentationService() + data = {} + + service.apply_zip_code_patch( + data, + time_period=2024, + county_fips=np.asarray(["06037", "06059", "06037"]), + ) + + np.testing.assert_array_equal( + data["zip_code"][2024], + np.asarray([b"90001", b"UNKNOWN", b"90001"]), + ) + + +def test_us_augmentation_service_recalculates_spm_thresholds(monkeypatch): + module = _install_fake_package_hierarchy(monkeypatch) + USAugmentationService = module.USAugmentationService + + captured = {} + + def fake_threshold_calculator(**kwargs): + captured.update(kwargs) + return np.asarray([11.0, 22.0]) + + service = USAugmentationService( + cd_geoadj_loader=lambda cds: {"0101": 1.1, "0201": 2.2}, + threshold_calculator=fake_threshold_calculator, + ) + + class FakeProvider: + def calculate(self, variable, *, map_to=None): + assert variable == "age" + assert map_to == "person" + return types.SimpleNamespace(values=np.asarray([30, 40, 50])) + + def get_known_periods(self, variable): + assert variable == "spm_unit_tenure_type" + return (2024,) + + def get_array(self, variable, period): + assert variable == "spm_unit_tenure_type" + assert period == 2024 + return np.asarray([b"OWNER", b"RENTER"]) + + source = types.SimpleNamespace(variable_provider=FakeProvider()) + reindexed = types.SimpleNamespace( + person_source_indices=np.asarray([2, 0, 1], dtype=np.int64), + entity_source_indices={"spm_unit": np.asarray([1, 0], dtype=np.int64)}, + entities_per_clone={"spm_unit": np.asarray([1, 1], dtype=np.int64)}, + new_person_entity_ids={"spm_unit": np.asarray([0, 1, 1], dtype=np.int32)}, + ) + data = {} + + service.apply_spm_thresholds( + data, + time_period=2024, + active_clone_cds=np.asarray(["0101", "0201"]), + source=source, + reindexed=reindexed, + ) + + np.testing.assert_array_equal( + data["spm_unit_spm_threshold"][2024], + 
np.asarray([11.0, 22.0]), + ) + np.testing.assert_array_equal( + captured["person_ages"], + np.asarray([50, 30, 40]), + ) + np.testing.assert_array_equal( + captured["person_spm_unit_ids"], + np.asarray([0, 1, 1], dtype=np.int32), + ) + np.testing.assert_array_equal( + captured["spm_unit_tenure_types"], + np.asarray([b"RENTER", b"OWNER"]), + ) + np.testing.assert_array_equal( + captured["spm_unit_geoadj"], + np.asarray([1.1, 2.2]), + ) + assert captured["year"] == 2024 + + +def test_us_augmentation_service_applies_takeup_with_clone_indices(monkeypatch): + module = _install_fake_package_hierarchy(monkeypatch) + USAugmentationService = module.USAugmentationService + + captured = {} + + def fake_takeup_fn(**kwargs): + captured.update(kwargs) + return {"snap": np.asarray([True, False, True])} + + service = USAugmentationService(takeup_fn=fake_takeup_fn) + selection = types.SimpleNamespace( + active_block_geoids=np.asarray(["b1", "b2"]), + active_clone_indices=np.asarray([3, 4], dtype=np.int64), + active_household_indices=np.asarray([1, 0], dtype=np.int64), + n_household_clones=2, + ) + source = types.SimpleNamespace( + household_ids=np.asarray([10, 20], dtype=np.int64), + ) + reindexed = types.SimpleNamespace( + persons_per_clone=np.asarray([2, 1], dtype=np.int64), + entities_per_clone={ + "tax_unit": np.asarray([1, 2], dtype=np.int64), + "spm_unit": np.asarray([1, 1], dtype=np.int64), + }, + person_source_indices=np.asarray([2, 3, 0], dtype=np.int64), + entity_source_indices={ + "tax_unit": np.asarray([1, 2, 0], dtype=np.int64), + "spm_unit": np.asarray([1, 0], dtype=np.int64), + }, + ) + data = {} + + service.apply_takeup( + data, + time_period=2024, + takeup_filter=("snap",), + selection=selection, + source=source, + reindexed=reindexed, + clone_geo={"state_fips": np.asarray([6, 36])}, + ) + + np.testing.assert_array_equal( + captured["hh_blocks"], + np.asarray(["b1", "b2"]), + ) + np.testing.assert_array_equal( + captured["hh_state_fips"], + np.asarray([6, 36], 
dtype=np.int32), + ) + np.testing.assert_array_equal( + captured["hh_ids"], + np.asarray([20, 10], dtype=np.int64), + ) + np.testing.assert_array_equal( + captured["hh_clone_indices"], + np.asarray([3, 4], dtype=np.int64), + ) + np.testing.assert_array_equal( + captured["entity_hh_indices"]["person"], + np.asarray([0, 0, 1], dtype=np.int64), + ) + np.testing.assert_array_equal( + captured["entity_hh_indices"]["tax_unit"], + np.asarray([0, 1, 1], dtype=np.int64), + ) + np.testing.assert_array_equal( + captured["entity_hh_indices"]["spm_unit"], + np.asarray([0, 1], dtype=np.int64), + ) + assert captured["entity_counts"] == { + "person": 3, + "tax_unit": 3, + "spm_unit": 2, + } + assert captured["time_period"] == 2024 + assert captured["takeup_filter"] == ("snap",) + np.testing.assert_array_equal( + data["snap"][2024], + np.asarray([True, False, True]), + ) + + +def test_us_augmentation_service_passes_reported_takeup_anchors(monkeypatch): + module = _install_fake_package_hierarchy(monkeypatch) + USAugmentationService = module.USAugmentationService + + captured = {} + + def fake_takeup_fn(**kwargs): + captured.update(kwargs) + return {} + + service = USAugmentationService(takeup_fn=fake_takeup_fn) + selection = types.SimpleNamespace( + active_block_geoids=np.asarray(["b1", "b2"]), + active_clone_indices=np.asarray([3, 4], dtype=np.int64), + active_household_indices=np.asarray([1, 0], dtype=np.int64), + n_household_clones=2, + ) + source = types.SimpleNamespace( + household_ids=np.asarray([10, 20], dtype=np.int64), + ) + reindexed = types.SimpleNamespace( + persons_per_clone=np.asarray([2, 1], dtype=np.int64), + entities_per_clone={ + "tax_unit": np.asarray([1, 2], dtype=np.int64), + "spm_unit": np.asarray([1, 1], dtype=np.int64), + }, + person_source_indices=np.asarray([2, 3, 0], dtype=np.int64), + entity_source_indices={ + "tax_unit": np.asarray([1, 2], dtype=np.int64), + "spm_unit": np.asarray([1, 0], dtype=np.int64), + }, + ) + data = { + "person_tax_unit_id": 
{2024: np.asarray([1, 1, 2], dtype=np.int64)}, + "tax_unit_id": {2024: np.asarray([1, 2], dtype=np.int64)}, + "reported_has_subsidized_marketplace_health_coverage_at_interview": { + 2024: np.asarray([True, False, False], dtype=bool) + }, + "has_medicaid_health_coverage_at_interview": { + 2024: np.asarray([False, True, False], dtype=bool) + }, + } + + service.apply_takeup( + data, + time_period=2024, + takeup_filter=("snap",), + selection=selection, + source=source, + reindexed=reindexed, + clone_geo={"state_fips": np.asarray([6, 36])}, + ) + + np.testing.assert_array_equal( + captured["reported_anchors"]["takes_up_aca_if_eligible"], + np.asarray([True, False], dtype=bool), + ) + np.testing.assert_array_equal( + captured["reported_anchors"]["takes_up_medicaid_if_eligible"], + np.asarray([False, True, False], dtype=bool), + ) diff --git a/tests/unit/calibration/test_local_h5_validation_helpers.py b/tests/unit/calibration/test_local_h5_validation_helpers.py new file mode 100644 index 000000000..93b1b4ca0 --- /dev/null +++ b/tests/unit/calibration/test_local_h5_validation_helpers.py @@ -0,0 +1,86 @@ +import importlib.util +from pathlib import Path +import sys + + +def _load_validation_module(): + module_path = ( + Path(__file__).resolve().parents[3] + / "policyengine_us_data" + / "calibration" + / "local_h5" + / "validation.py" + ) + spec = importlib.util.spec_from_file_location("local_h5_validation", module_path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +validation = _load_validation_module() +make_validation_error = validation.make_validation_error +summarize_validation_rows = validation.summarize_validation_rows +tag_validation_errors = validation.tag_validation_errors +validation_geo_level_for_area_type = validation.validation_geo_level_for_area_type + + +def 
test_validation_geo_level_maps_current_area_types_explicitly(): + assert validation_geo_level_for_area_type("states") == "state" + assert validation_geo_level_for_area_type("districts") == "district" + assert validation_geo_level_for_area_type("cities") == "district" + assert validation_geo_level_for_area_type("national") == "national" + + +def test_summarize_validation_rows_counts_failures_and_ignores_infinite_error(): + summary = summarize_validation_rows( + ( + {"sanity_check": "PASS", "rel_abs_error": 0.10}, + {"sanity_check": "FAIL", "rel_abs_error": 0.30}, + {"sanity_check": "FAIL", "rel_abs_error": float("inf")}, + ) + ) + + assert summary == { + "n_targets": 3, + "n_sanity_fail": 2, + "mean_rel_abs_error": 0.2, + } + + +def test_make_validation_error_returns_structured_payload(): + payload = make_validation_error( + item_key="district:CA-12", + error=RuntimeError("validator crashed"), + traceback_text="traceback lines", + ) + + assert payload == { + "item": "district:CA-12", + "error": "validator crashed", + "traceback": "traceback lines", + } + + +def test_tag_validation_errors_attaches_source_without_dropping_fields(): + tagged = tag_validation_errors( + ( + { + "item": "district:CA-12", + "error": "validator crashed", + "traceback": "traceback lines", + }, + ), + source="regional", + ) + + assert tagged == [ + { + "item": "district:CA-12", + "error": "validator crashed", + "traceback": "traceback lines", + "source": "regional", + } + ] diff --git a/tests/unit/calibration/test_local_h5_variables.py b/tests/unit/calibration/test_local_h5_variables.py new file mode 100644 index 000000000..ff3153e7e --- /dev/null +++ b/tests/unit/calibration/test_local_h5_variables.py @@ -0,0 +1,350 @@ +import importlib.util +from pathlib import Path +import sys +import types + +import numpy as np + + +def _module_path(*parts: str) -> Path: + return Path(__file__).resolve().parents[3].joinpath(*parts) + + +def _load_module(module_name: str, path: Path): + spec = 
importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _install_fake_package_hierarchy(monkeypatch): + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + + entity_graph = _load_module( + "policyengine_us_data.calibration.local_h5.entity_graph", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "entity_graph.py", + ), + ) + source_dataset = _load_module( + "policyengine_us_data.calibration.local_h5.source_dataset", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "source_dataset.py", + ), + ) + _load_module( + "policyengine_us_data.calibration.local_h5.contracts", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "contracts.py", + ), + ) + _load_module( + "policyengine_us_data.calibration.local_h5.weights", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "weights.py", + ), + ) + _load_module( + "policyengine_us_data.calibration.local_h5.selection", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "selection.py", + ), + ) + reindexing = _load_module( + "policyengine_us_data.calibration.local_h5.reindexing", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "reindexing.py", + ), + ) + 
variables = _load_module( + "policyengine_us_data.calibration.local_h5.variables", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "variables.py", + ), + ) + return entity_graph, source_dataset, reindexing, variables + + +_EnumLike = type("Enum", (), {}) + + +class FakeVariableDef: + def __init__(self, entity_key, value_type): + self.entity = types.SimpleNamespace(key=entity_key) + self.value_type = value_type + + +class FakeProvider: + def __init__(self): + self._definitions = { + "household_income": FakeVariableDef("household", float), + "person_age": FakeVariableDef("person", int), + "tax_unit_amount": FakeVariableDef("tax_unit", float), + "enum_status": FakeVariableDef("household", _EnumLike), + "county_fips": FakeVariableDef("household", str), + "two_period_var": FakeVariableDef("household", float), + "ignored_output_only": FakeVariableDef("output_only", float), + } + self._periods = { + "household_income": (2024,), + "person_age": (2024,), + "tax_unit_amount": (2024,), + "enum_status": (2024,), + "county_fips": (2024,), + "two_period_var": (2023, 2024), + "ignored_output_only": (2024,), + } + self._arrays = { + ("household_income", 2024): np.asarray([100.0, 200.0]), + ("person_age", 2024): np.asarray([34, 35, 50, 18, 17]), + ("tax_unit_amount", 2024): np.asarray([10.0, 20.0, 30.0]), + ("enum_status", 2024): np.asarray(["A", "B"], dtype=object), + ("county_fips", 2024): np.asarray(["06001", "36061"], dtype=object), + ("two_period_var", 2023): np.asarray([1.0, 2.0]), + ("two_period_var", 2024): np.asarray([3.0, 4.0]), + ("ignored_output_only", 2024): np.asarray([9.0, 9.0]), + } + + def list_variables(self): + return tuple(self._definitions.keys()) + + def get_known_periods(self, variable): + return self._periods[variable] + + def get_array(self, variable, period): + return self._arrays[(variable, period)] + + def get_variable_definition(self, variable): + return self._definitions.get(variable) + + def calculate(self, variable, *, 
map_to=None): + raise AssertionError("calculate should not be used in VariableCloner tests") + + +def _make_snapshot(source_dataset_module, entity_graph_module): + EntityGraph = entity_graph_module.EntityGraph + SourceDatasetSnapshot = source_dataset_module.SourceDatasetSnapshot + + graph = EntityGraph( + household_ids=np.asarray([10, 20]), + person_household_ids=np.asarray([10, 10, 20, 20, 20]), + hh_id_to_index={10: 0, 20: 1}, + hh_to_persons={0: (0, 1), 1: (2, 3, 4)}, + entity_id_arrays={ + "tax_unit": np.asarray([100, 200, 300]), + }, + person_entity_id_arrays={ + "tax_unit": np.asarray([100, 100, 200, 300, 300]), + }, + hh_to_entity={ + "tax_unit": {0: (0,), 1: (1, 2)}, + }, + ) + return SourceDatasetSnapshot( + dataset_path=Path("/tmp/source.h5"), + time_period=2024, + household_ids=np.asarray([10, 20]), + entity_graph=graph, + input_variables=frozenset( + { + "household_income", + "person_age", + "tax_unit_amount", + "enum_status", + "county_fips", + "two_period_var", + "ignored_output_only", + } + ), + variable_provider=FakeProvider(), + ) + + +def _make_reindexed(reindexing_module): + ReindexedEntities = reindexing_module.ReindexedEntities + return ReindexedEntities( + household_source_indices=np.asarray([1, 0], dtype=np.int64), + person_source_indices=np.asarray([2, 3, 4, 0, 1], dtype=np.int64), + entity_source_indices={ + "tax_unit": np.asarray([1, 2, 0], dtype=np.int64), + }, + persons_per_clone=np.asarray([3, 2], dtype=np.int64), + entities_per_clone={ + "tax_unit": np.asarray([2, 1], dtype=np.int64), + }, + new_household_ids=np.asarray([0, 1], dtype=np.int32), + new_person_ids=np.asarray([0, 1, 2, 3, 4], dtype=np.int32), + new_person_household_ids=np.asarray([0, 0, 0, 1, 1], dtype=np.int32), + new_entity_ids={ + "tax_unit": np.asarray([0, 1, 2], dtype=np.int32), + }, + new_person_entity_ids={ + "tax_unit": np.asarray([0, 1, 1, 2, 2], dtype=np.int32), + }, + ) + + +def test_variable_cloner_slices_household_person_and_subentity_arrays(monkeypatch): + 
entity_graph, source_dataset, reindexing, variables = _install_fake_package_hierarchy( + monkeypatch + ) + VariableCloner = variables.VariableCloner + VariableExportPolicy = variables.VariableExportPolicy + + snapshot = _make_snapshot(source_dataset, entity_graph) + reindexed = _make_reindexed(reindexing) + payload = VariableCloner().clone(snapshot, reindexed, VariableExportPolicy()) + + np.testing.assert_array_equal( + payload.variables["household_income"][2024], + np.asarray([200.0, 100.0]), + ) + np.testing.assert_array_equal( + payload.variables["person_age"][2024], + np.asarray([50, 18, 17, 34, 35]), + ) + np.testing.assert_array_equal( + payload.variables["tax_unit_amount"][2024], + np.asarray([20.0, 30.0, 10.0]), + ) + + +def test_variable_cloner_handles_multiple_periods(monkeypatch): + entity_graph, source_dataset, reindexing, variables = _install_fake_package_hierarchy( + monkeypatch + ) + VariableCloner = variables.VariableCloner + VariableExportPolicy = variables.VariableExportPolicy + + snapshot = _make_snapshot(source_dataset, entity_graph) + reindexed = _make_reindexed(reindexing) + payload = VariableCloner().clone(snapshot, reindexed, VariableExportPolicy()) + + np.testing.assert_array_equal( + payload.variables["two_period_var"][2023], + np.asarray([2.0, 1.0]), + ) + np.testing.assert_array_equal( + payload.variables["two_period_var"][2024], + np.asarray([4.0, 3.0]), + ) + + +def test_variable_cloner_encodes_enum_and_county_fips_values(monkeypatch): + entity_graph, source_dataset, reindexing, variables = _install_fake_package_hierarchy( + monkeypatch + ) + VariableCloner = variables.VariableCloner + VariableExportPolicy = variables.VariableExportPolicy + + snapshot = _make_snapshot(source_dataset, entity_graph) + reindexed = _make_reindexed(reindexing) + payload = VariableCloner().clone(snapshot, reindexed, VariableExportPolicy()) + + assert payload.variables["enum_status"][2024].dtype.kind == "S" + np.testing.assert_array_equal( + 
payload.variables["enum_status"][2024], + np.asarray([b"B", b"A"]), + ) + assert payload.variables["county_fips"][2024].dtype == np.int32 + np.testing.assert_array_equal( + payload.variables["county_fips"][2024], + np.asarray([36061, 6001], dtype=np.int32), + ) + + +def test_variable_cloner_respects_excluded_variables(monkeypatch): + entity_graph, source_dataset, reindexing, variables = _install_fake_package_hierarchy( + monkeypatch + ) + VariableCloner = variables.VariableCloner + VariableExportPolicy = variables.VariableExportPolicy + + snapshot = _make_snapshot(source_dataset, entity_graph) + reindexed = _make_reindexed(reindexing) + payload = VariableCloner().clone( + snapshot, + reindexed, + VariableExportPolicy(excluded_variables=frozenset({"household_income"})), + ) + + assert "household_income" not in payload.variables + assert "person_age" in payload.variables + + +def test_variable_cloner_respects_required_variables_when_input_variables_disabled( + monkeypatch, +): + entity_graph, source_dataset, reindexing, variables = _install_fake_package_hierarchy( + monkeypatch + ) + VariableCloner = variables.VariableCloner + VariableExportPolicy = variables.VariableExportPolicy + + snapshot = _make_snapshot(source_dataset, entity_graph) + reindexed = _make_reindexed(reindexing) + payload = VariableCloner().clone( + snapshot, + reindexed, + VariableExportPolicy( + include_input_variables=False, + required_variables=frozenset({"person_age"}), + ), + ) + + assert set(payload.variables) == {"person_age"} + + +def test_variable_cloner_skips_variables_for_uncloned_entities(monkeypatch): + entity_graph, source_dataset, reindexing, variables = _install_fake_package_hierarchy( + monkeypatch + ) + VariableCloner = variables.VariableCloner + VariableExportPolicy = variables.VariableExportPolicy + + snapshot = _make_snapshot(source_dataset, entity_graph) + reindexed = _make_reindexed(reindexing) + payload = VariableCloner().clone(snapshot, reindexed, VariableExportPolicy()) + 
+ assert "ignored_output_only" not in payload.variables diff --git a/tests/unit/calibration/test_local_h5_worker_service.py b/tests/unit/calibration/test_local_h5_worker_service.py new file mode 100644 index 000000000..b615fa98a --- /dev/null +++ b/tests/unit/calibration/test_local_h5_worker_service.py @@ -0,0 +1,501 @@ +import importlib.util +from dataclasses import dataclass +from pathlib import Path +import sys +import types + +import numpy as np + + +def _module_path(*parts: str) -> Path: + return Path(__file__).resolve().parents[3].joinpath(*parts) + + +def _load_module(module_name: str, path: Path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _install_fake_package_hierarchy(monkeypatch): + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + + contracts = _load_module( + "policyengine_us_data.calibration.local_h5.contracts", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "contracts.py", + ), + ) + _load_module( + "policyengine_us_data.calibration.local_h5.validation", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "validation.py", + ), + ) + _load_module( + "policyengine_us_data.calibration.local_h5.weights", + _module_path( + "policyengine_us_data", + 
"calibration", + "local_h5", + "weights.py", + ), + ) + + builder_module = types.ModuleType("policyengine_us_data.calibration.local_h5.builder") + builder_module.LocalAreaDatasetBuilder = type("LocalAreaDatasetBuilder", (), {}) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.builder", + builder_module, + ) + + package_geo_module = types.ModuleType( + "policyengine_us_data.calibration.local_h5.package_geography" + ) + package_geo_module.CalibrationPackageGeographyLoader = type( + "CalibrationPackageGeographyLoader", (), {} + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.package_geography", + package_geo_module, + ) + + source_dataset_module = types.ModuleType( + "policyengine_us_data.calibration.local_h5.source_dataset" + ) + + @dataclass(frozen=True) + class FakeSourceDatasetSnapshot: + dataset_path: Path + time_period: int + household_ids: np.ndarray + + @property + def n_households(self): + return int(len(self.household_ids)) + + source_dataset_module.SourceDatasetSnapshot = FakeSourceDatasetSnapshot + source_dataset_module.PolicyEngineDatasetReader = type( + "PolicyEngineDatasetReader", (), {} + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.source_dataset", + source_dataset_module, + ) + + writer_module = types.ModuleType("policyengine_us_data.calibration.local_h5.writer") + writer_module.H5Writer = type("H5Writer", (), {}) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.writer", + writer_module, + ) + + worker_service = _load_module( + "policyengine_us_data.calibration.local_h5.worker_service", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "worker_service.py", + ), + ) + return contracts, source_dataset_module, worker_service + + +def test_worker_session_loads_source_geography_and_weights_once(monkeypatch, tmp_path): + contracts, source_dataset_module, worker_service = 
_install_fake_package_hierarchy( + monkeypatch + ) + ValidationPolicy = contracts.ValidationPolicy + WorkerSession = worker_service.WorkerSession + ValidationContext = worker_service.ValidationContext + + weights_path = tmp_path / "weights.npy" + np.save(weights_path, np.asarray([1.0, 0.0, 2.0, 0.0], dtype=float)) + dataset_path = tmp_path / "source.h5" + output_dir = tmp_path / "output" + + snapshot = source_dataset_module.SourceDatasetSnapshot( + dataset_path=dataset_path, + time_period=2024, + household_ids=np.asarray([10, 20]), + ) + source_calls = [] + geo_calls = [] + + class FakeSourceReader: + def load(self, path): + source_calls.append(Path(path)) + return snapshot + + class FakeGeographyLoader: + def resolve_for_weights( + self, + *, + package_path, + weights_length, + n_records, + n_clones, + seed, + allow_seed_fallback, + ): + geo_calls.append( + { + "package_path": package_path, + "weights_length": weights_length, + "n_records": n_records, + "n_clones": n_clones, + "seed": seed, + "allow_seed_fallback": allow_seed_fallback, + } + ) + return types.SimpleNamespace( + geography=types.SimpleNamespace(n_records=2, n_clones=2), + source="package", + warnings=("exact geography",), + ) + + validation_context = ValidationContext( + validation_targets=(), + training_mask_full=np.asarray([], dtype=bool), + constraints_map={}, + db_path=tmp_path / "policy_data.db", + period=2024, + ) + + session = WorkerSession.load( + weights_path=weights_path, + dataset_path=dataset_path, + output_dir=output_dir, + calibration_package_path=tmp_path / "calibration_package.pkl", + requested_n_clones=430, + seed=99, + takeup_filter=("snap", "wic"), + validation_policy=ValidationPolicy(enabled=False), + validation_context=validation_context, + source_reader=FakeSourceReader(), + geography_loader=FakeGeographyLoader(), + ) + + assert source_calls == [dataset_path] + assert geo_calls == [ + { + "package_path": tmp_path / "calibration_package.pkl", + "weights_length": 4, + "n_records": 
2, + "n_clones": 2, + "seed": 99, + "allow_seed_fallback": True, + } + ] + np.testing.assert_array_equal(session.weights, np.asarray([1.0, 0.0, 2.0, 0.0])) + assert session.source_snapshot is snapshot + assert session.geography_source == "package" + assert session.geography_warnings == ("exact geography",) + assert session.n_clones == 2 + assert session.takeup_filter == ("snap", "wic") + + +def test_local_h5_worker_service_handles_mixed_chunk_results(monkeypatch): + contracts, source_dataset_module, worker_service = _install_fake_package_hierarchy( + monkeypatch + ) + AreaBuildRequest = contracts.AreaBuildRequest + AreaFilter = contracts.AreaFilter + ValidationResult = contracts.ValidationResult + LocalH5WorkerService = worker_service.LocalH5WorkerService + ValidationContext = worker_service.ValidationContext + WorkerSession = worker_service.WorkerSession + + snapshot = source_dataset_module.SourceDatasetSnapshot( + dataset_path=Path("/tmp/source.h5"), + time_period=2024, + household_ids=np.asarray([10, 20]), + ) + session = WorkerSession( + source_snapshot=snapshot, + weights=np.asarray([1.0, 0.0, 0.0, 2.0], dtype=float), + geography=types.SimpleNamespace(), + output_dir=Path("/tmp/output"), + takeup_filter=("snap",), + validation_policy=contracts.ValidationPolicy(enabled=True), + validation_context=ValidationContext( + validation_targets=(), + training_mask_full=np.asarray([], dtype=bool), + constraints_map={}, + db_path=Path("/tmp/policy_data.db"), + period=2024, + ), + ) + request_ok = AreaBuildRequest( + area_type="state", + area_id="CA", + display_name="CA", + output_relative_path="states/CA.h5", + ) + request_fail = AreaBuildRequest( + area_type="district", + area_id="CA-12", + display_name="CA-12", + output_relative_path="districts/CA-12.h5", + filters=( + AreaFilter( + geography_field="cd_geoid", + op="in", + value=("0612",), + ), + ), + ) + + builder_calls = [] + + class FakeBuilder: + def build(self, **kwargs): + builder_calls.append(kwargs) + if 
kwargs["filters"] == request_fail.filters: + raise ValueError("builder exploded") + return types.SimpleNamespace( + payload="payload", + selection=types.SimpleNamespace(), + reindexed=types.SimpleNamespace(), + time_period=2024, + ) + + writer_calls = [] + + class FakeWriter: + def write_payload(self, payload, output_path): + writer_calls.append(("write", payload, Path(output_path))) + return Path(output_path) + + def verify_output(self, output_path, *, time_period): + writer_calls.append(("verify", Path(output_path), time_period)) + return {"household_count": 1} + + validator_calls = [] + + def fake_validator(output_path, request, session_obj): + validator_calls.append((Path(output_path), request.area_id, session_obj)) + return ValidationResult( + status="passed", + rows=(), + summary={"n_targets": 0, "n_sanity_fail": 0, "mean_rel_abs_error": 0.0}, + ) + + service = LocalH5WorkerService( + builder=FakeBuilder(), + writer=FakeWriter(), + validator=fake_validator, + ) + + result = service.run(session, (request_ok, request_fail)) + + assert [item.request.area_id for item in result.completed] == ["CA"] + assert [item.request.area_id for item in result.failed] == ["CA-12"] + assert result.failed[0].build_error == "builder exploded" + assert builder_calls[0]["source"] is snapshot + assert builder_calls[1]["source"] is snapshot + assert writer_calls == [ + ("write", "payload", Path("/tmp/output/states/CA.h5")), + ("verify", Path("/tmp/output/states/CA.h5"), 2024), + ] + assert validator_calls[0][1] == "CA" + + +def test_local_h5_worker_service_records_validation_exception(monkeypatch): + contracts, source_dataset_module, worker_service = _install_fake_package_hierarchy( + monkeypatch + ) + AreaBuildRequest = contracts.AreaBuildRequest + LocalH5WorkerService = worker_service.LocalH5WorkerService + ValidationContext = worker_service.ValidationContext + WorkerSession = worker_service.WorkerSession + + snapshot = source_dataset_module.SourceDatasetSnapshot( + 
dataset_path=Path("/tmp/source.h5"), + time_period=2024, + household_ids=np.asarray([10]), + ) + session = WorkerSession( + source_snapshot=snapshot, + weights=np.asarray([1.0], dtype=float), + geography=types.SimpleNamespace(), + output_dir=Path("/tmp/output"), + validation_policy=contracts.ValidationPolicy(enabled=True), + validation_context=ValidationContext( + validation_targets=(), + training_mask_full=np.asarray([], dtype=bool), + constraints_map={}, + db_path=Path("/tmp/policy_data.db"), + period=2024, + ), + ) + request = AreaBuildRequest.national() + + class FakeBuilder: + def build(self, **_kwargs): + return types.SimpleNamespace( + payload="payload", + selection=types.SimpleNamespace(), + reindexed=types.SimpleNamespace(), + time_period=2024, + ) + + class FakeWriter: + def write_payload(self, payload, output_path): + return Path(output_path) + + def verify_output(self, output_path, *, time_period): + return {} + + def exploding_validator(*_args, **_kwargs): + raise RuntimeError("validator crashed") + + service = LocalH5WorkerService( + builder=FakeBuilder(), + writer=FakeWriter(), + validator=exploding_validator, + ) + result = service.run(session, (request,)) + + assert len(result.completed) == 1 + validation = result.completed[0].validation + assert validation.status == "error" + assert validation.issues[0].code == "validation_exception" + assert validation.issues[0].message == "validator crashed" + + +def test_worker_result_to_legacy_dict_flattens_structured_results(monkeypatch): + contracts, _, worker_service = _install_fake_package_hierarchy(monkeypatch) + AreaBuildRequest = contracts.AreaBuildRequest + AreaBuildResult = contracts.AreaBuildResult + ValidationIssue = contracts.ValidationIssue + ValidationResult = contracts.ValidationResult + WorkerResult = contracts.WorkerResult + + completed = AreaBuildResult( + request=AreaBuildRequest( + area_type="state", + area_id="CA", + display_name="CA", + output_relative_path="states/CA.h5", + ), + 
build_status="completed", + output_path=Path("/tmp/output/states/CA.h5"), + validation=ValidationResult( + status="failed", + rows=( + {"sanity_check": "FAIL", "rel_abs_error": 0.3}, + {"sanity_check": "PASS", "rel_abs_error": 0.1}, + ), + summary={"n_targets": 2, "n_sanity_fail": 1, "mean_rel_abs_error": 0.2}, + ), + ) + failed = AreaBuildResult( + request=AreaBuildRequest( + area_type="district", + area_id="CA-12", + display_name="CA-12", + output_relative_path="districts/CA-12.h5", + ), + build_status="failed", + build_error="build crashed", + ) + result = WorkerResult( + completed=(completed,), + failed=(failed,), + worker_issues=( + ValidationIssue( + code="session_warning", + message="stale cache", + severity="warning", + details={"path": "/tmp/cache"}, + ), + ), + ) + + payload = worker_service.worker_result_to_legacy_dict(result) + + assert payload["completed"] == ["state:CA"] + assert payload["failed"] == ["district:CA-12"] + assert payload["validation_rows"] == [ + {"sanity_check": "FAIL", "rel_abs_error": 0.3}, + {"sanity_check": "PASS", "rel_abs_error": 0.1}, + ] + assert payload["validation_summary"]["state:CA"] == { + "n_targets": 2, + "n_sanity_fail": 1, + "mean_rel_abs_error": 0.2, + } + assert payload["errors"][0] == { + "item": "district:CA-12", + "error": "build crashed", + } + assert payload["errors"][1] == { + "item": "worker", + "error": "stale cache", + "code": "session_warning", + "details": {"path": "/tmp/cache"}, + } + + +def test_build_requests_from_work_items_handles_city_and_invalid_items(monkeypatch): + contracts, _, worker_service = _install_fake_package_hierarchy(monkeypatch) + + geography = types.SimpleNamespace( + cd_geoid=np.asarray(["3607", "3610", "0101"], dtype=str), + county_fips=np.asarray(["36061", "36047", "01001"], dtype=str), + ) + + requests, failures = worker_service.build_requests_from_work_items( + ( + {"type": "city", "id": "NYC"}, + {"type": "unknown", "id": "mystery"}, + ), + geography=geography, + state_codes={36: 
"NY", 1: "AL"}, + at_large_districts={0, 98}, + nyc_county_fips={"36061", "36047", "36081"}, + ) + + assert len(requests) == 1 + city_request = requests[0] + assert city_request.area_type == "city" + assert city_request.output_relative_path == "cities/NYC.h5" + assert city_request.validation_geo_level == "district" + assert city_request.validation_geographic_ids == ("3607", "3610") + assert city_request.filters[0].geography_field == "county_fips" + assert len(failures) == 1 + assert failures[0].request.area_type == "custom" + assert failures[0].build_error == "Unknown item type: unknown" diff --git a/tests/unit/calibration/test_local_h5_writer.py b/tests/unit/calibration/test_local_h5_writer.py new file mode 100644 index 000000000..cf6b4b405 --- /dev/null +++ b/tests/unit/calibration/test_local_h5_writer.py @@ -0,0 +1,119 @@ +import importlib.util +from pathlib import Path +import sys +import types + +import numpy as np + + +def _module_path(*parts: str) -> Path: + return Path(__file__).resolve().parents[3].joinpath(*parts) + + +def _load_module(module_name: str, path: Path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _install_fake_package_hierarchy(monkeypatch): + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + 
+ _load_module( + "policyengine_us_data.calibration.local_h5.entity_graph", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "entity_graph.py", + ), + ) + _load_module( + "policyengine_us_data.calibration.local_h5.source_dataset", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "source_dataset.py", + ), + ) + _load_module( + "policyengine_us_data.calibration.local_h5.reindexing", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "reindexing.py", + ), + ) + variables = _load_module( + "policyengine_us_data.calibration.local_h5.variables", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "variables.py", + ), + ) + writer = _load_module( + "policyengine_us_data.calibration.local_h5.writer", + _module_path( + "policyengine_us_data", + "calibration", + "local_h5", + "writer.py", + ), + ) + return variables, writer + + +def test_h5_writer_writes_payload_and_verifies_output(monkeypatch, tmp_path): + variables, writer_module = _install_fake_package_hierarchy(monkeypatch) + H5Payload = variables.H5Payload + H5Writer = writer_module.H5Writer + + payload = H5Payload( + variables={ + "household_id": {2024: np.asarray([1, 2], dtype=np.int32)}, + "person_id": {2024: np.asarray([10, 11, 12], dtype=np.int32)}, + "household_weight": {2024: np.asarray([1.5, 2.0], dtype=np.float32)}, + "person_weight": { + 2024: np.asarray([1.0, 1.25, 1.25], dtype=np.float32) + }, + } + ) + output_path = tmp_path / "nested" / "local.h5" + + writer = H5Writer() + written_path = writer.write_payload(payload, output_path) + summary = writer.verify_output(written_path, time_period=2024) + + assert written_path == output_path + assert output_path.exists() + assert summary == { + "household_count": 2, + "person_count": 3, + "household_weight_sum": 3.5, + "person_weight_sum": 3.5, + } diff --git a/tests/unit/calibration/test_worker_script_adapter.py b/tests/unit/calibration/test_worker_script_adapter.py 
def _module_path(*parts: str) -> Path:
    """Resolve *parts* against the repository root (three levels up)."""
    return Path(__file__).resolve().parents[3].joinpath(*parts)


def _load_worker_script_module():
    """Load ``modal_app/worker_script.py`` fresh under a test-only name.

    Returns the executed module object after registering it in
    ``sys.modules`` under ``worker_script_under_test``.

    Raises:
        AssertionError: if importlib cannot locate or load the script.
    """
    module_path = _module_path("modal_app", "worker_script.py")
    spec = importlib.util.spec_from_file_location(
        "worker_script_under_test",
        module_path,
    )
    # Check the spec before handing it to module_from_spec(); asserting
    # afterwards (as the original did) surfaces an AttributeError instead
    # of the intended assertion failure when the file cannot be found.
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module
FakeAreaBuildRequest + contracts.ValidationPolicy = FakeValidationPolicy + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.contracts", + contracts, + ) + + package_geography = types.ModuleType( + "policyengine_us_data.calibration.local_h5.package_geography" + ) + package_geography.require_calibration_package_path = lambda path: Path(path) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.package_geography", + package_geography, + ) + + worker_service = types.ModuleType( + "policyengine_us_data.calibration.local_h5.worker_service" + ) + + class FakeWorkerResult: + def to_dict(self): + return { + "completed": [], + "failed": [], + "worker_issues": [], + } + + class FakeWorkerSession: + @classmethod + def load(cls, **kwargs): + calls["session_kwargs"] = kwargs + return types.SimpleNamespace( + requested_n_clones=430, + n_clones=430, + geography_source="package", + geography_warnings=(), + geography=types.SimpleNamespace(n_clones=430, n_records=2), + source_snapshot=types.SimpleNamespace(n_households=2), + ) + + class FakeLocalH5WorkerService: + def run(self, session, requests, initial_failures=()): + calls["service_run"] = { + "session": session, + "requests": requests, + "initial_failures": initial_failures, + } + return FakeWorkerResult() + + worker_service.WorkerSession = FakeWorkerSession + worker_service.LocalH5WorkerService = FakeLocalH5WorkerService + worker_service.load_validation_context = lambda **kwargs: None + worker_service.build_requests_from_work_items = ( + lambda *args, **kwargs: ((), ()) + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.worker_service", + worker_service, + ) + + publish_local_area = types.ModuleType( + "policyengine_us_data.calibration.publish_local_area" + ) + publish_local_area.AT_LARGE_DISTRICTS = {0, 98} + publish_local_area.NYC_COUNTY_FIPS = {"36061"} + publish_local_area.SUB_ENTITIES = ["tax_unit", "spm_unit"] + 
monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.publish_local_area", + publish_local_area, + ) + + calibration_utils = types.ModuleType( + "policyengine_us_data.calibration.calibration_utils" + ) + calibration_utils.STATE_CODES = {1: "AL"} + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.calibration_utils", + calibration_utils, + ) + + source_dataset = types.ModuleType( + "policyengine_us_data.calibration.local_h5.source_dataset" + ) + + class FakeReader: + def __init__(self, sub_entities): + calls["reader_sub_entities"] = tuple(sub_entities) + + source_dataset.PolicyEngineDatasetReader = FakeReader + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.source_dataset", + source_dataset, + ) + + requests_json = json.dumps( + [ + { + "area_type": "state", + "area_id": "CA", + "display_name": "CA", + "output_relative_path": "states/CA.h5", + } + ] + ) + monkeypatch.setattr( + sys, + "argv", + [ + "worker_script.py", + "--requests-json", + requests_json, + "--weights-path", + "/tmp/weights.npy", + "--dataset-path", + "/tmp/source.h5", + "--db-path", + "/tmp/policy_data.db", + "--output-dir", + "/tmp/output", + ], + ) + + worker_script.main() + captured = capsys.readouterr() + + assert calls["reader_sub_entities"] == ("tax_unit", "spm_unit") + assert calls["session_kwargs"]["weights_path"] == Path("/tmp/weights.npy") + assert calls["session_kwargs"]["dataset_path"] == Path("/tmp/source.h5") + assert calls["session_kwargs"]["allow_seed_fallback"] is False + assert calls["service_run"]["requests"] == ( + { + "parsed_request": { + "area_type": "state", + "area_id": "CA", + "display_name": "CA", + "output_relative_path": "states/CA.h5", + } + }, + ) + assert json.loads(captured.out) == { + "completed": [], + "failed": [], + "worker_issues": [], + } diff --git a/tests/unit/test_local_area_coordinator_contract.py b/tests/unit/test_local_area_coordinator_contract.py new file mode 100644 index 
000000000..167ffecc1 --- /dev/null +++ b/tests/unit/test_local_area_coordinator_contract.py @@ -0,0 +1,758 @@ +import importlib.util +import json +from dataclasses import dataclass +from pathlib import Path +import sys +import types + +import pytest + + +def _load_module(module_name: str, path: Path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _load_local_area_module(monkeypatch): + class FakeModalFunction: + def __init__(self, fn): + self.fn = fn + self.object_id = "fake-modal-fn" + + def __call__(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + def remote(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + def spawn(self, *args, **kwargs): + result = self.fn(*args, **kwargs) + return types.SimpleNamespace(object_id="fake-handle", get=lambda: result) + + class FakeApp: + def __init__(self, *_args, **_kwargs): + pass + + def function(self, **_kwargs): + def decorator(fn): + return FakeModalFunction(fn) + + return decorator + + def local_entrypoint(self, **_kwargs): + def decorator(fn): + return fn + + return decorator + + fake_modal = types.ModuleType("modal") + fake_modal.App = FakeApp + fake_modal.Secret = types.SimpleNamespace( + from_name=lambda *args, **kwargs: object() + ) + fake_modal.Volume = types.SimpleNamespace( + from_name=lambda *args, **kwargs: types.SimpleNamespace( + reload=lambda: None, + commit=lambda: None, + ) + ) + monkeypatch.setitem(sys.modules, "modal", fake_modal) + + images_module = types.ModuleType("modal_app.images") + images_module.cpu_image = object() + monkeypatch.setitem(sys.modules, "modal_app.images", images_module) + + resilience_module = types.ModuleType("modal_app.resilience") + resilience_module.reconcile_run_dir_fingerprint = lambda *_a, **_k: "initialized" + 
monkeypatch.setitem(sys.modules, "modal_app.resilience", resilience_module) + + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + + contracts_spec = importlib.util.spec_from_file_location( + "policyengine_us_data.calibration.local_h5.contracts", + Path(__file__).resolve().parents[2] + / "policyengine_us_data" + / "calibration" + / "local_h5" + / "contracts.py", + ) + contracts_module = importlib.util.module_from_spec(contracts_spec) + assert contracts_spec is not None + assert contracts_spec.loader is not None + sys.modules[contracts_spec.name] = contracts_module + contracts_spec.loader.exec_module(contracts_module) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.contracts", + contracts_module, + ) + + fingerprinting_module = types.ModuleType( + "policyengine_us_data.calibration.local_h5.fingerprinting" + ) + fingerprinting_module.FingerprintService = type("FingerprintService", (), {}) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.fingerprinting", + fingerprinting_module, + ) + + package_geo_module = types.ModuleType( + "policyengine_us_data.calibration.local_h5.package_geography" + ) + package_geo_module.CalibrationPackageGeographyLoader = type( + "CalibrationPackageGeographyLoader", (), {} + ) + package_geo_module.require_calibration_package_path = lambda path: Path(path) + monkeypatch.setitem( + sys.modules, + 
"policyengine_us_data.calibration.local_h5.package_geography", + package_geo_module, + ) + + partitioning_module = types.ModuleType( + "policyengine_us_data.calibration.local_h5.partitioning" + ) + partitioning_module.partition_weighted_work_items = ( + lambda work_items, _num_workers, _completed: [work_items] + ) + partitioning_module.work_item_key = ( + lambda item: f"{item['type']}:{item['id']}" + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.partitioning", + partitioning_module, + ) + + area_catalog_module = types.ModuleType( + "policyengine_us_data.calibration.local_h5.area_catalog" + ) + area_catalog_module.USAreaCatalog = type("USAreaCatalog", (), {}) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.area_catalog", + area_catalog_module, + ) + + module_path = Path(__file__).resolve().parents[2] / "modal_app" / "local_area.py" + return _load_module("local_area_under_test", module_path) + + +def _translated_path_factory(tmp_path): + real_path = Path + pipeline_root = tmp_path / "pipeline" + staging_root = tmp_path / "staging" + + def translate(value): + raw = str(value) + if raw == "/pipeline": + return pipeline_root + if raw.startswith("/pipeline/"): + return pipeline_root / raw.removeprefix("/pipeline/") + if raw == "/staging": + return staging_root + if raw.startswith("/staging/"): + return staging_root / raw.removeprefix("/staging/") + return real_path(value) + + return translate + + +def _prepare_artifacts(tmp_path, run_id: str): + artifacts = tmp_path / "pipeline" / "artifacts" / run_id + artifacts.mkdir(parents=True, exist_ok=True) + for name in ( + "calibration_weights.npy", + "national_calibration_weights.npy", + "source_imputed_stratified_extended_cps.h5", + "policy_data.db", + "calibration_package.pkl", + ): + (artifacts / name).write_bytes(b"x") + return artifacts + + +class _FakeFingerprintService: + def __init__(self, digest="actualfp"): + self.digest = digest + + def 
create_publish_fingerprint(self, **_kwargs): + return types.SimpleNamespace(digest=self.digest) + + +@dataclass(frozen=True) +class _FakeEntry: + request: object + weight: int + + @property + def key(self): + return f"{self.request.area_type}:{self.request.area_id}" + + def to_partition_item(self): + return { + "type": self.request.area_type, + "id": self.request.area_id, + "weight": self.weight, + } + + +class _FakeCatalog: + def __init__(self, entries=None, national_entry=None): + self._entries = entries or () + self._national_entry = national_entry + + def resolved_regional_entries(self, *_args, **_kwargs): + return self._entries + + def resolved_national_entry(self): + return self._national_entry + + +def test_coordinate_publish_requires_calibration_package(monkeypatch): + local_area = _load_local_area_module(monkeypatch) + monkeypatch.setattr(local_area, "setup_gcp_credentials", lambda: None) + monkeypatch.setattr(local_area, "setup_repo", lambda _branch: None) + monkeypatch.setattr(local_area, "get_version", lambda: "1.0.0") + monkeypatch.setattr( + local_area, + "require_calibration_package_path", + lambda _path: (_ for _ in ()).throw(FileNotFoundError("missing package")), + ) + + with pytest.raises(FileNotFoundError, match="missing package"): + local_area.coordinate_publish( + branch="main", + num_workers=1, + skip_upload=True, + validate=False, + run_id="run1", + ) + + +def test_coordinate_publish_rejects_pinned_fingerprint_mismatch( + monkeypatch, tmp_path +): + local_area = _load_local_area_module(monkeypatch) + monkeypatch.setattr(local_area, "Path", _translated_path_factory(tmp_path)) + _prepare_artifacts(tmp_path, "run1") + monkeypatch.setattr(local_area, "setup_gcp_credentials", lambda: None) + monkeypatch.setattr(local_area, "setup_repo", lambda _branch: None) + monkeypatch.setattr(local_area, "get_version", lambda: "1.0.0") + monkeypatch.setattr(local_area, "validate_artifacts", lambda *_a, **_k: None) + monkeypatch.setattr( + local_area, + 
"require_calibration_package_path", + lambda path: local_area.Path(path), + ) + monkeypatch.setattr(local_area, "_derive_canonical_n_clones", lambda **_k: 3) + monkeypatch.setattr( + local_area, + "FingerprintService", + lambda: _FakeFingerprintService(digest="actualfp"), + ) + + with pytest.raises(RuntimeError, match="Pinned fingerprint does not match"): + local_area.coordinate_publish( + branch="main", + num_workers=1, + skip_upload=True, + validate=False, + run_id="run1", + expected_fingerprint="expectedfp", + ) + + +def test_coordinate_publish_returns_validation_errors_in_skip_upload_mode( + monkeypatch, tmp_path +): + local_area = _load_local_area_module(monkeypatch) + request = local_area.AreaBuildRequest( + area_type="state", + area_id="CA", + display_name="CA", + output_relative_path="states/CA.h5", + ) + monkeypatch.setattr(local_area, "Path", _translated_path_factory(tmp_path)) + _prepare_artifacts(tmp_path, "run1") + monkeypatch.setattr(local_area, "setup_gcp_credentials", lambda: None) + monkeypatch.setattr(local_area, "setup_repo", lambda _branch: None) + monkeypatch.setattr(local_area, "get_version", lambda: "1.0.0") + monkeypatch.setattr(local_area, "validate_artifacts", lambda *_a, **_k: None) + monkeypatch.setattr( + local_area, + "require_calibration_package_path", + lambda path: local_area.Path(path), + ) + monkeypatch.setattr(local_area, "_derive_canonical_n_clones", lambda **_k: 3) + monkeypatch.setattr( + local_area, + "FingerprintService", + lambda: _FakeFingerprintService(digest="actualfp"), + ) + reconcile_calls = [] + monkeypatch.setattr( + local_area, + "reconcile_run_dir_fingerprint", + lambda *_a, **kwargs: ( + reconcile_calls.append(kwargs), + "initialized", + )[1], + ) + monkeypatch.setattr( + local_area, + "_load_catalog_geography", + lambda *_a, **_k: object(), + ) + monkeypatch.setattr( + local_area, + "USAreaCatalog", + lambda: _FakeCatalog(entries=(_FakeEntry(request=request, weight=1),)), + ) + monkeypatch.setattr( + local_area, 
+ "run_phase", + lambda *_a, **_k: ( + {"state:CA"}, + [], + [{"area_type": "state", "area_id": "CA"}], + [{"item": "state:CA", "error": "validator crashed"}], + ), + ) + + result = local_area.coordinate_publish( + branch="main", + num_workers=1, + skip_upload=True, + validate=False, + run_id="run1", + ) + + assert result["fingerprint"] == "actualfp" + assert reconcile_calls == [{"scope": "regional"}] + assert result["validation_rows"] == [{"area_type": "state", "area_id": "CA"}] + assert result["validation_errors"] == [ + {"item": "state:CA", "error": "validator crashed"} + ] + + +def test_coordinate_publish_raises_on_build_failures_even_if_files_exist( + monkeypatch, tmp_path +): + local_area = _load_local_area_module(monkeypatch) + request = local_area.AreaBuildRequest( + area_type="state", + area_id="CA", + display_name="CA", + output_relative_path="states/CA.h5", + ) + monkeypatch.setattr(local_area, "Path", _translated_path_factory(tmp_path)) + _prepare_artifacts(tmp_path, "run1") + monkeypatch.setattr(local_area, "setup_gcp_credentials", lambda: None) + monkeypatch.setattr(local_area, "setup_repo", lambda _branch: None) + monkeypatch.setattr(local_area, "get_version", lambda: "1.0.0") + monkeypatch.setattr(local_area, "validate_artifacts", lambda *_a, **_k: None) + monkeypatch.setattr( + local_area, + "require_calibration_package_path", + lambda path: local_area.Path(path), + ) + monkeypatch.setattr(local_area, "_derive_canonical_n_clones", lambda **_k: 3) + monkeypatch.setattr( + local_area, + "FingerprintService", + lambda: _FakeFingerprintService(digest="actualfp"), + ) + monkeypatch.setattr( + local_area, + "reconcile_run_dir_fingerprint", + lambda *_a, **_k: "initialized", + ) + monkeypatch.setattr( + local_area, + "_load_catalog_geography", + lambda *_a, **_k: object(), + ) + monkeypatch.setattr( + local_area, + "USAreaCatalog", + lambda: _FakeCatalog(entries=(_FakeEntry(request=request, weight=1),)), + ) + monkeypatch.setattr( + local_area, + 
"run_phase", + lambda *_a, **_k: ( + {"state:CA"}, + [{"type": "build_failure", "item": "state:CA", "error": "bad output"}], + [], + [], + ), + ) + + with pytest.raises(RuntimeError, match="build failure"): + local_area.coordinate_publish( + branch="main", + num_workers=1, + skip_upload=True, + validate=False, + run_id="run1", + ) + + +def test_coordinate_publish_raises_on_worker_issues_even_if_files_exist( + monkeypatch, tmp_path +): + local_area = _load_local_area_module(monkeypatch) + request = local_area.AreaBuildRequest( + area_type="state", + area_id="CA", + display_name="CA", + output_relative_path="states/CA.h5", + ) + monkeypatch.setattr(local_area, "Path", _translated_path_factory(tmp_path)) + _prepare_artifacts(tmp_path, "run1") + monkeypatch.setattr(local_area, "setup_gcp_credentials", lambda: None) + monkeypatch.setattr(local_area, "setup_repo", lambda _branch: None) + monkeypatch.setattr(local_area, "get_version", lambda: "1.0.0") + monkeypatch.setattr(local_area, "validate_artifacts", lambda *_a, **_k: None) + monkeypatch.setattr( + local_area, + "require_calibration_package_path", + lambda path: local_area.Path(path), + ) + monkeypatch.setattr(local_area, "_derive_canonical_n_clones", lambda **_k: 3) + monkeypatch.setattr( + local_area, + "FingerprintService", + lambda: _FakeFingerprintService(digest="actualfp"), + ) + monkeypatch.setattr( + local_area, + "reconcile_run_dir_fingerprint", + lambda *_a, **_k: "initialized", + ) + monkeypatch.setattr( + local_area, + "_load_catalog_geography", + lambda *_a, **_k: object(), + ) + monkeypatch.setattr( + local_area, + "USAreaCatalog", + lambda: _FakeCatalog(entries=(_FakeEntry(request=request, weight=1),)), + ) + monkeypatch.setattr( + local_area, + "run_phase", + lambda *_a, **_k: ( + {"state:CA"}, + [{"type": "worker_issue", "item": "worker", "error": "subprocess failed"}], + [], + [], + ), + ) + + with pytest.raises(RuntimeError, match="worker issue"): + local_area.coordinate_publish( + branch="main", 
+ num_workers=1, + skip_upload=True, + validate=False, + run_id="run1", + ) + + +def test_coordinate_national_publish_returns_worker_validation_errors( + monkeypatch, tmp_path +): + local_area = _load_local_area_module(monkeypatch) + national_request = local_area.AreaBuildRequest.national() + monkeypatch.setattr(local_area, "Path", _translated_path_factory(tmp_path)) + _prepare_artifacts(tmp_path, "run1") + monkeypatch.setattr(local_area, "setup_gcp_credentials", lambda: None) + monkeypatch.setattr(local_area, "setup_repo", lambda _branch: None) + monkeypatch.setattr(local_area, "get_version", lambda: "1.0.0") + monkeypatch.setattr(local_area, "validate_artifacts", lambda *_a, **_k: None) + monkeypatch.setattr( + local_area, + "require_calibration_package_path", + lambda path: local_area.Path(path), + ) + monkeypatch.setattr(local_area, "_derive_canonical_n_clones", lambda **_k: 3) + monkeypatch.setattr( + local_area, + "USAreaCatalog", + lambda: _FakeCatalog( + national_entry=_FakeEntry(request=national_request, weight=1) + ), + ) + monkeypatch.setattr( + local_area, + "FingerprintService", + lambda: _FakeFingerprintService(digest="natfp"), + ) + reconcile_calls = [] + monkeypatch.setattr( + local_area, + "reconcile_run_dir_fingerprint", + lambda *_a, **kwargs: ( + reconcile_calls.append(kwargs), + "initialized", + )[1], + ) + + def fake_remote(**_kwargs): + output_path = local_area.Path("/staging/run1/national/US.h5") + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_bytes(b"fake-h5") + return { + "completed": [ + { + "request": national_request.to_dict(), + "build_status": "completed", + "output_path": str(output_path), + "build_error": None, + "validation": { + "status": "error", + "rows": [], + "issues": [ + { + "code": "validation_exception", + "message": "validator crashed", + "severity": "error", + "details": {}, + } + ], + "summary": {}, + }, + } + ], + "failed": [], + "worker_issues": [], + } + + monkeypatch.setattr( + 
local_area, + "build_areas_worker", + types.SimpleNamespace(remote=fake_remote), + ) + monkeypatch.setattr( + local_area.subprocess, + "run", + lambda *_a, **_k: types.SimpleNamespace(returncode=0, stdout="Done", stderr=""), + ) + + result = local_area.coordinate_national_publish( + branch="main", + validate=False, + run_id="run1", + ) + + assert result["validation_errors"] == [ + { + "item": "national:US", + "error": "validator crashed", + "code": "validation_exception", + "details": {}, + } + ] + assert result["fingerprint"] == "natfp" + assert reconcile_calls == [{"scope": "national"}] + + +def test_coordinate_national_publish_rejects_pinned_fingerprint_mismatch( + monkeypatch, tmp_path +): + local_area = _load_local_area_module(monkeypatch) + national_request = local_area.AreaBuildRequest.national() + monkeypatch.setattr(local_area, "Path", _translated_path_factory(tmp_path)) + _prepare_artifacts(tmp_path, "run1") + monkeypatch.setattr(local_area, "setup_gcp_credentials", lambda: None) + monkeypatch.setattr(local_area, "setup_repo", lambda _branch: None) + monkeypatch.setattr(local_area, "get_version", lambda: "1.0.0") + monkeypatch.setattr(local_area, "validate_artifacts", lambda *_a, **_k: None) + monkeypatch.setattr( + local_area, + "require_calibration_package_path", + lambda path: local_area.Path(path), + ) + monkeypatch.setattr(local_area, "_derive_canonical_n_clones", lambda **_k: 3) + monkeypatch.setattr( + local_area, + "USAreaCatalog", + lambda: _FakeCatalog( + national_entry=_FakeEntry(request=national_request, weight=1) + ), + ) + monkeypatch.setattr( + local_area, + "FingerprintService", + lambda: _FakeFingerprintService(digest="actualfp"), + ) + + with pytest.raises(RuntimeError, match="Pinned fingerprint does not match"): + local_area.coordinate_national_publish( + branch="main", + validate=False, + run_id="run1", + expected_fingerprint="expectedfp", + ) + + +def test_coordinate_national_publish_resumes_without_rebuilding( + monkeypatch, tmp_path 
+): + local_area = _load_local_area_module(monkeypatch) + national_request = local_area.AreaBuildRequest.national() + monkeypatch.setattr(local_area, "Path", _translated_path_factory(tmp_path)) + _prepare_artifacts(tmp_path, "run1") + output_path = tmp_path / "staging" / "run1" / "national" / "US.h5" + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_bytes(b"existing-h5") + monkeypatch.setattr(local_area, "setup_gcp_credentials", lambda: None) + monkeypatch.setattr(local_area, "setup_repo", lambda _branch: None) + monkeypatch.setattr(local_area, "get_version", lambda: "1.0.0") + monkeypatch.setattr(local_area, "validate_artifacts", lambda *_a, **_k: None) + monkeypatch.setattr( + local_area, + "require_calibration_package_path", + lambda path: local_area.Path(path), + ) + monkeypatch.setattr(local_area, "_derive_canonical_n_clones", lambda **_k: 3) + monkeypatch.setattr( + local_area, + "USAreaCatalog", + lambda: _FakeCatalog( + national_entry=_FakeEntry(request=national_request, weight=1) + ), + ) + monkeypatch.setattr( + local_area, + "FingerprintService", + lambda: _FakeFingerprintService(digest="natfp"), + ) + monkeypatch.setattr( + local_area, + "reconcile_run_dir_fingerprint", + lambda *_a, **_k: "resume", + ) + monkeypatch.setattr( + local_area, + "build_areas_worker", + types.SimpleNamespace( + remote=lambda **_kwargs: (_ for _ in ()).throw( + AssertionError("worker should not be called on resume") + ) + ), + ) + monkeypatch.setattr( + local_area.subprocess, + "run", + lambda *_a, **_k: types.SimpleNamespace(returncode=0, stdout="Done", stderr=""), + ) + + result = local_area.coordinate_national_publish( + branch="main", + validate=False, + run_id="run1", + expected_fingerprint="natfp", + ) + + assert result["fingerprint"] == "natfp" + + +def test_run_phase_aggregates_structured_worker_results(monkeypatch): + local_area = _load_local_area_module(monkeypatch) + request = local_area.AreaBuildRequest( + area_type="state", + 
area_id="CA", + display_name="CA", + output_relative_path="states/CA.h5", + ) + entry = _FakeEntry(request=request, weight=2) + run_dir = Path("/staging/run1") + monkeypatch.setattr(local_area, "get_completed_from_volume", lambda _run_dir: {"state:CA"}) + monkeypatch.setattr(local_area.staging_volume, "reload", lambda: None) + + def fake_spawn(**kwargs): + assert kwargs["requests"] == [request.to_dict()] + payload = { + "completed": [ + { + "request": request.to_dict(), + "build_status": "completed", + "output_path": "/staging/run1/states/CA.h5", + "build_error": None, + "validation": { + "status": "failed", + "rows": [ + { + "target_name": "population", + "sanity_check": "FAIL", + "rel_abs_error": 0.2, + } + ], + "issues": [], + "summary": { + "n_targets": 1, + "n_sanity_fail": 1, + "mean_rel_abs_error": 0.2, + }, + }, + } + ], + "failed": [], + "worker_issues": [], + } + return types.SimpleNamespace(object_id="fake-handle", get=lambda: payload) + + monkeypatch.setattr( + local_area, + "build_areas_worker", + types.SimpleNamespace(spawn=fake_spawn), + ) + + completed, errors, validation_rows, validation_errors = local_area.run_phase( + "All areas", + entries=[entry], + num_workers=1, + completed=set(), + branch="main", + run_id="run1", + calibration_inputs={"weights": "w", "dataset": "d", "database": "db"}, + run_dir=run_dir, + validate=True, + ) + + assert completed == {"state:CA"} + assert errors == [] + assert validation_rows == [ + { + "target_name": "population", + "sanity_check": "FAIL", + "rel_abs_error": 0.2, + } + ] + assert validation_errors == [] diff --git a/tests/unit/test_modal_resilience.py b/tests/unit/test_modal_resilience.py index 34390be8d..d891a954f 100644 --- a/tests/unit/test_modal_resilience.py +++ b/tests/unit/test_modal_resilience.py @@ -1,24 +1,114 @@ +import importlib.util import json +from dataclasses import dataclass +from pathlib import Path +import sys +import types import pytest -from modal_app.resilience import ( - 
ensure_resume_sha_compatible, - reconcile_run_dir_fingerprint, -) +def _load_module(module_name: str, path: Path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _load_resilience_module(monkeypatch): + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + + @dataclass(frozen=True) + class FakeFingerprintRecord: + schema_version: str + algorithm: str + digest: str + components: object | None = None + inputs: dict = None + + def to_dict(self): + return { + "fingerprint": self.digest, + "digest": self.digest, + "schema_version": self.schema_version, + "algorithm": self.algorithm, + } + + class FakeFingerprintService: + ALGORITHM = "sha256-truncated-16" + + def legacy_record(self, digest: str): + return FakeFingerprintRecord( + schema_version="legacy", + algorithm=self.ALGORITHM, + digest=str(digest), + ) + + def write_record(self, path, record): + Path(path).write_text(json.dumps(record.to_dict())) + + def read_record(self, path): + payload = json.loads(Path(path).read_text()) + return FakeFingerprintRecord( + schema_version=payload.get("schema_version", "legacy"), + algorithm=payload.get("algorithm", self.ALGORITHM), + digest=payload.get("digest") or payload.get("fingerprint"), + ) + + def matches(self, stored, current): + 
return stored.digest == current.digest + + fingerprinting = types.ModuleType( + "policyengine_us_data.calibration.local_h5.fingerprinting" + ) + fingerprinting.FingerprintRecord = FakeFingerprintRecord + fingerprinting.FingerprintService = FakeFingerprintService + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.fingerprinting", + fingerprinting, + ) + + module_path = Path(__file__).resolve().parents[2] / "modal_app" / "resilience.py" + module = _load_module("resilience_under_test", module_path) + return module, FakeFingerprintRecord + + +def test_resume_requires_same_sha(monkeypatch): + resilience, _ = _load_resilience_module(monkeypatch) -def test_resume_requires_same_sha(): with pytest.raises(RuntimeError, match="Start a fresh run instead"): - ensure_resume_sha_compatible( + resilience.ensure_resume_sha_compatible( branch="fix/pipeline-resilience", run_sha="0123456789abcdef", current_sha="fedcba9876543210", ) -def test_resume_allows_same_sha(): - result = ensure_resume_sha_compatible( +def test_resume_allows_same_sha(monkeypatch): + resilience, _ = _load_resilience_module(monkeypatch) + + result = resilience.ensure_resume_sha_compatible( branch="fix/pipeline-resilience", run_sha="0123456789abcdef", current_sha="0123456789abcdef", @@ -26,8 +116,10 @@ def test_resume_allows_same_sha(): assert result is True -def test_resume_force_allows_mismatched_sha(): - result = ensure_resume_sha_compatible( +def test_resume_force_allows_mismatched_sha(monkeypatch): + resilience, _ = _load_resilience_module(monkeypatch) + + result = resilience.ensure_resume_sha_compatible( branch="fix/pipeline-resilience", run_sha="0123456789abcdef", current_sha="fedcba9876543210", @@ -36,8 +128,10 @@ def test_resume_force_allows_mismatched_sha(): assert result is False -def test_resume_force_with_matching_sha(): - result = ensure_resume_sha_compatible( +def test_resume_force_with_matching_sha(monkeypatch): + resilience, _ = _load_resilience_module(monkeypatch) + 
+ result = resilience.ensure_resume_sha_compatible( branch="fix/pipeline-resilience", run_sha="0123456789abcdef", current_sha="0123456789abcdef", @@ -46,14 +140,15 @@ def test_resume_force_with_matching_sha(): assert result is True -def test_reconcile_run_dir_resumes_matching_fingerprint(tmp_path): +def test_reconcile_run_dir_resumes_matching_legacy_fingerprint(monkeypatch, tmp_path): + resilience, _ = _load_resilience_module(monkeypatch) run_dir = tmp_path / "1.2.3_abc12345_20260407_120000" run_dir.mkdir() (run_dir / "states").mkdir() (run_dir / "states" / "CA.h5").write_text("h5") (run_dir / "fingerprint.json").write_text(json.dumps({"fingerprint": "abc123"})) - action = reconcile_run_dir_fingerprint(run_dir, "abc123") + action = resilience.reconcile_run_dir_fingerprint(run_dir, "abc123") assert action == "resume" assert (run_dir / "states" / "CA.h5").exists() @@ -62,7 +157,8 @@ def test_reconcile_run_dir_resumes_matching_fingerprint(tmp_path): } -def test_reconcile_run_dir_rejects_changed_fingerprint_with_h5s(tmp_path): +def test_reconcile_run_dir_rejects_changed_fingerprint_with_h5s(monkeypatch, tmp_path): + resilience, _ = _load_resilience_module(monkeypatch) run_dir = tmp_path / "1.2.3_abc12345_20260407_120000" run_dir.mkdir() (run_dir / "states").mkdir() @@ -70,7 +166,7 @@ def test_reconcile_run_dir_rejects_changed_fingerprint_with_h5s(tmp_path): (run_dir / "fingerprint.json").write_text(json.dumps({"fingerprint": "oldfp"})) with pytest.raises(RuntimeError, match="Fingerprint mismatch"): - reconcile_run_dir_fingerprint(run_dir, "newfp") + resilience.reconcile_run_dir_fingerprint(run_dir, "newfp") assert (run_dir / "states" / "CA.h5").exists() assert json.loads((run_dir / "fingerprint.json").read_text()) == { @@ -78,29 +174,128 @@ def test_reconcile_run_dir_rejects_changed_fingerprint_with_h5s(tmp_path): } -def test_reconcile_run_dir_rejects_missing_fingerprint_with_h5s(tmp_path): +def test_reconcile_run_dir_rejects_missing_fingerprint_with_h5s(monkeypatch, 
tmp_path): + resilience, _ = _load_resilience_module(monkeypatch) run_dir = tmp_path / "1.2.3_abc12345_20260407_120000" run_dir.mkdir() (run_dir / "states").mkdir() (run_dir / "states" / "CA.h5").write_text("stale") with pytest.raises(RuntimeError, match="Missing fingerprint metadata"): - reconcile_run_dir_fingerprint(run_dir, "newfp") + resilience.reconcile_run_dir_fingerprint(run_dir, "newfp") assert (run_dir / "states" / "CA.h5").exists() assert not (run_dir / "fingerprint.json").exists() -def test_reconcile_run_dir_clears_empty_stale_directory(tmp_path): +def test_reconcile_run_dir_clears_empty_stale_directory(monkeypatch, tmp_path): + resilience, _ = _load_resilience_module(monkeypatch) run_dir = tmp_path / "1.2.3_abc12345_20260407_120000" run_dir.mkdir() (run_dir / "scratch.txt").write_text("stale") (run_dir / "fingerprint.json").write_text(json.dumps({"fingerprint": "oldfp"})) - action = reconcile_run_dir_fingerprint(run_dir, "newfp") + action = resilience.reconcile_run_dir_fingerprint(run_dir, "newfp") assert action == "initialized" assert not (run_dir / "scratch.txt").exists() - assert json.loads((run_dir / "fingerprint.json").read_text()) == { - "fingerprint": "newfp" - } + stored = json.loads((run_dir / "fingerprint.json").read_text()) + assert stored["fingerprint"] == "newfp" + assert stored["digest"] == "newfp" + assert stored["schema_version"] == "legacy" + assert stored["algorithm"] == "sha256-truncated-16" + + +def test_reconcile_run_dir_accepts_rich_record_object(monkeypatch, tmp_path): + resilience, FakeFingerprintRecord = _load_resilience_module(monkeypatch) + run_dir = tmp_path / "1.2.3_abc12345_20260407_120000" + + action = resilience.reconcile_run_dir_fingerprint( + run_dir, + FakeFingerprintRecord( + schema_version="local_h5_publish_v1", + algorithm="sha256-truncated-16", + digest="rich1234", + ), + ) + + assert action == "initialized" + stored = json.loads((run_dir / "fingerprint.json").read_text()) + assert stored["fingerprint"] == 
"rich1234" + assert stored["digest"] == "rich1234" + assert stored["schema_version"] == "local_h5_publish_v1" + + +def test_reconcile_scope_initializes_without_touching_sibling_outputs( + monkeypatch, tmp_path +): + resilience, _ = _load_resilience_module(monkeypatch) + run_dir = tmp_path / "1.2.3_abc12345_20260407_120000" + run_dir.mkdir() + (run_dir / "national").mkdir() + (run_dir / "national" / "US.h5").write_text("national") + + action = resilience.reconcile_run_dir_fingerprint( + run_dir, + "regionalfp", + scope="regional", + ) + + assert action == "initialized" + assert (run_dir / "national" / "US.h5").exists() + stored = json.loads( + (run_dir / ".publish_scopes" / "regional" / "fingerprint.json").read_text() + ) + assert stored["fingerprint"] == "regionalfp" + + +def test_reconcile_scope_clears_only_owned_dirs_when_no_owned_h5s( + monkeypatch, tmp_path +): + resilience, _ = _load_resilience_module(monkeypatch) + run_dir = tmp_path / "1.2.3_abc12345_20260407_120000" + run_dir.mkdir() + (run_dir / "states").mkdir() + (run_dir / "states" / "scratch.txt").write_text("stale-regional") + (run_dir / "national").mkdir() + (run_dir / "national" / "US.h5").write_text("national") + scope_dir = run_dir / ".publish_scopes" / "regional" + scope_dir.mkdir(parents=True) + (scope_dir / "fingerprint.json").write_text(json.dumps({"fingerprint": "oldfp"})) + + action = resilience.reconcile_run_dir_fingerprint( + run_dir, + "newfp", + scope="regional", + ) + + assert action == "initialized" + assert not (run_dir / "states").exists() + assert (run_dir / "national" / "US.h5").exists() + stored = json.loads((scope_dir / "fingerprint.json").read_text()) + assert stored["fingerprint"] == "newfp" + + +def test_reconcile_scope_rejects_changed_fingerprint_with_owned_h5s( + monkeypatch, tmp_path +): + resilience, _ = _load_resilience_module(monkeypatch) + run_dir = tmp_path / "1.2.3_abc12345_20260407_120000" + run_dir.mkdir() + (run_dir / "states").mkdir() + (run_dir / "states" / 
"CA.h5").write_text("regional") + (run_dir / "national").mkdir() + (run_dir / "national" / "US.h5").write_text("national") + scope_dir = run_dir / ".publish_scopes" / "regional" + scope_dir.mkdir(parents=True) + (scope_dir / "fingerprint.json").write_text(json.dumps({"fingerprint": "oldfp"})) + + with pytest.raises(RuntimeError, match="staged regional H5 files"): + resilience.reconcile_run_dir_fingerprint( + run_dir, + "newfp", + scope="regional", + ) + + assert (run_dir / "states" / "CA.h5").exists() + assert (run_dir / "national" / "US.h5").exists() diff --git a/tests/unit/test_pipeline.py b/tests/unit/test_pipeline.py index 5aaca8a47..f7c1ec34c 100644 --- a/tests/unit/test_pipeline.py +++ b/tests/unit/test_pipeline.py @@ -63,6 +63,22 @@ def test_from_dict(self): assert meta.status == "completed" assert meta.step_timings["build_datasets"]["status"] == "completed" + def test_from_dict_maps_legacy_fingerprint_to_regional_fingerprint(self): + meta = RunMetadata.from_dict( + { + "run_id": "1.72.3_abc12345_20260319_120000", + "branch": "main", + "sha": "abc12345deadbeef", + "version": "1.72.3", + "start_time": "2026-03-19T12:00:00Z", + "status": "completed", + "fingerprint": "legacyfp", + } + ) + + assert meta.fingerprint == "legacyfp" + assert meta.regional_fingerprint == "legacyfp" + def test_roundtrip(self): meta = RunMetadata( run_id="1.72.3_abc12345_20260319_120000", diff --git a/tests/unit/test_pipeline_validation_diagnostics.py b/tests/unit/test_pipeline_validation_diagnostics.py new file mode 100644 index 000000000..67d1d522a --- /dev/null +++ b/tests/unit/test_pipeline_validation_diagnostics.py @@ -0,0 +1,277 @@ +import importlib.util +import json +from pathlib import Path +import sys +import types +from unittest.mock import MagicMock + + +def _module_path(*parts: str) -> Path: + return Path(__file__).resolve().parents[2].joinpath(*parts) + + +def _install_fake_modal(monkeypatch): + modal = types.ModuleType("modal") + + class FakeApp: + def __init__(self, 
*_args, **_kwargs): + pass + + def include(self, _other): + return None + + def function(self, *args, **kwargs): + def decorator(fn): + return fn + + return decorator + + def local_entrypoint(self, *args, **kwargs): + def decorator(fn): + return fn + + return decorator + + class FakeSecret: + @staticmethod + def from_name(_name): + return object() + + class FakeVolume: + @staticmethod + def from_name(_name, create_if_missing=False): + return MagicMock() + + modal.App = FakeApp + modal.Secret = FakeSecret + modal.Volume = FakeVolume + monkeypatch.setitem(sys.modules, "modal", modal) + + +def _load_pipeline_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[2] + + modal_app_package = types.ModuleType("modal_app") + modal_app_package.__path__ = [str(repo_root / "modal_app")] + monkeypatch.setitem(sys.modules, "modal_app", modal_app_package) + + images_module = types.ModuleType("modal_app.images") + images_module.cpu_image = object() + monkeypatch.setitem(sys.modules, "modal_app.images", images_module) + + data_build_module = types.ModuleType("modal_app.data_build") + data_build_module.app = object() + data_build_module.build_datasets = object() + monkeypatch.setitem(sys.modules, "modal_app.data_build", data_build_module) + + calibration_module = types.ModuleType("modal_app.remote_calibration_runner") + calibration_module.app = object() + calibration_module.build_package_remote = object() + calibration_module.PACKAGE_GPU_FUNCTIONS = {} + monkeypatch.setitem( + sys.modules, + "modal_app.remote_calibration_runner", + calibration_module, + ) + + local_area_module = types.ModuleType("modal_app.local_area") + local_area_module.app = object() + local_area_module.coordinate_publish = object() + local_area_module.coordinate_national_publish = object() + local_area_module.promote_publish = object() + local_area_module.promote_national_publish = object() + monkeypatch.setitem(sys.modules, "modal_app.local_area", local_area_module) + + policyengine_root = 
types.ModuleType("policyengine_us_data") + policyengine_root.__path__ = [str(repo_root / "policyengine_us_data")] + monkeypatch.setitem(sys.modules, "policyengine_us_data", policyengine_root) + + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [ + str(repo_root / "policyengine_us_data" / "calibration") + ] + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + + local_h5_package = types.ModuleType("policyengine_us_data.calibration.local_h5") + local_h5_package.__path__ = [ + str(repo_root / "policyengine_us_data" / "calibration" / "local_h5") + ] + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5", + local_h5_package, + ) + + validation_module = types.ModuleType( + "policyengine_us_data.calibration.local_h5.validation" + ) + + def tag_validation_errors(errors, source): + return [{**error, "source": source} for error in errors] + + validation_module.tag_validation_errors = tag_validation_errors + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.local_h5.validation", + validation_module, + ) + + _install_fake_modal(monkeypatch) + + module_path = _module_path("modal_app", "pipeline.py") + spec = importlib.util.spec_from_file_location( + "modal_app.pipeline_under_test", + module_path, + ) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def _make_meta(pipeline_module): + return pipeline_module.RunMetadata( + run_id="test_run", + branch="main", + sha="abc123", + version="1.0.0", + start_time="2026-03-19T12:00:00Z", + status="running", + ) + + +def test_write_validation_diagnostics_writes_outputs_and_meta(monkeypatch, tmp_path): + pipeline = _load_pipeline_module(monkeypatch) + runs_dir = tmp_path / "runs" + meta = _make_meta(pipeline) + mock_vol = MagicMock() + + 
regional_result = { + "validation_rows": [ + { + "area_type": "state", + "area_id": "CA", + "district": "", + "variable": "household_count", + "target_name": "household_count", + "period": 2024, + "target_value": 100.0, + "sim_value": 110.0, + "error": 10.0, + "rel_error": 0.1, + "abs_error": 10.0, + "rel_abs_error": 0.1, + "sanity_check": "FAIL", + "sanity_reason": "too_high", + "in_training": True, + } + ], + "validation_errors": [ + { + "item": "state:CA", + "error": "regional validator crashed", + "code": "validation_exception", + "details": {"traceback": "tb-regional"}, + } + ], + } + national_result = { + "national_validation": "national validation output", + "validation_errors": [ + { + "item": "national:US", + "error": "national validator crashed", + "code": "validation_exception", + "details": {"traceback": "tb-national"}, + } + ], + } + + monkeypatch.setattr(pipeline, "RUNS_DIR", str(runs_dir)) + + pipeline._write_validation_diagnostics( + run_id="test_run", + regional_result=regional_result, + national_result=national_result, + meta=meta, + vol=mock_vol, + ) + + diag_dir = runs_dir / "test_run" / "diagnostics" + csv_path = diag_dir / "validation_results.csv" + errors_path = diag_dir / "validation_errors.json" + national_path = diag_dir / "national_validation.txt" + meta_path = runs_dir / "test_run" / "meta.json" + + assert csv_path.exists() + assert errors_path.exists() + assert national_path.exists() + assert meta_path.exists() + + csv_lines = csv_path.read_text().strip().splitlines() + assert len(csv_lines) == 2 + assert "area_type,area_id" in csv_lines[0] + assert "state,CA" in csv_lines[1] + + assert json.loads(errors_path.read_text()) == [ + { + "item": "state:CA", + "error": "regional validator crashed", + "code": "validation_exception", + "details": {"traceback": "tb-regional"}, + "source": "regional", + }, + { + "item": "national:US", + "error": "national validator crashed", + "code": "validation_exception", + "details": {"traceback": 
"tb-national"}, + "source": "national", + }, + ] + assert national_path.read_text() == "national validation output" + + assert meta.step_timings["validation"] == { + "total_targets": 1, + "sanity_failures": 1, + "mean_rel_abs_error": 0.1, + "validation_errors": 2, + "worst_areas": [ + { + "area": "state:CA", + "mean_rae": 0.1, + "sanity_fails": 1, + } + ], + } + + saved_meta = json.loads(meta_path.read_text()) + assert saved_meta["step_timings"]["validation"]["validation_errors"] == 2 + assert mock_vol.commit.call_count == 2 + + +def test_write_validation_diagnostics_skips_when_no_data(monkeypatch, tmp_path): + pipeline = _load_pipeline_module(monkeypatch) + runs_dir = tmp_path / "runs" + meta = _make_meta(pipeline) + mock_vol = MagicMock() + + monkeypatch.setattr(pipeline, "RUNS_DIR", str(runs_dir)) + + pipeline._write_validation_diagnostics( + run_id="test_run", + regional_result={}, + national_result={}, + meta=meta, + vol=mock_vol, + ) + + assert not (runs_dir / "test_run").exists() + assert meta.step_timings == {} + mock_vol.commit.assert_not_called()