diff --git a/changelog.d/1041.changed b/changelog.d/1041.changed new file mode 100644 index 000000000..08363c26e --- /dev/null +++ b/changelog.d/1041.changed @@ -0,0 +1 @@ +Stage 2 calibration package manifests now track the explicit target config identity and contract artifact path. diff --git a/changelog.d/1065.changed b/changelog.d/1065.changed new file mode 100644 index 000000000..a8e81f607 --- /dev/null +++ b/changelog.d/1065.changed @@ -0,0 +1 @@ +Stage 2 calibration package construction now resolves its inputs and outputs through run-scoped artifact bundles. diff --git a/changelog.d/1073.changed b/changelog.d/1073.changed new file mode 100644 index 000000000..2e3a100ab --- /dev/null +++ b/changelog.d/1073.changed @@ -0,0 +1 @@ +Add typed Stage 2 calibration package payload reader and writer helpers. diff --git a/changelog.d/1083.changed b/changelog.d/1083.changed new file mode 100644 index 000000000..509a9ae81 --- /dev/null +++ b/changelog.d/1083.changed @@ -0,0 +1 @@ +Add Stage 2 target catalog selection artifacts for calibration packages. diff --git a/docs/pipeline_map.yaml b/docs/pipeline_map.yaml index 7bcfdda2c..223638262 100644 --- a/docs/pipeline_map.yaml +++ b/docs/pipeline_map.yaml @@ -805,7 +805,15 @@ stages: label: run_calibration() description: 'Build phase: resolve targets and constraints, assemble clone values, and package the sparse calibration matrix' node_ids: + - stage2_input_bundle + - stage2_build_context + - stage2_target_config_identity + - stage2_target_catalog_load + - stage2_target_catalog_reader + - stage2_target_selection_policy + - stage2_target_selection_result - target_resolve + - stage2_target_config_apply - target_uprate - geo_build - constraint_resolve @@ -813,8 +821,25 @@ stages: - clone_assembly - takeup_rerand - sparse_build + - build_matrix + - build_matrix_chunked + - stage2_artifact_specs + - stage2_payload_boundary + - stage2_payload_writer + - stage2_payload_reader + - stage2_calibration_package_writer - out_pkg + - out_metadata + - out_targets + - out_target_facets + - stage2_calibration_package_contract_writer + - out_contract + - stage2_calibration_package_contract_validator extra_nodes: + - id: in_stage1_contract_s2 + label: dataset_build_output.json + node_type: artifact + description: Stage 1 handoff contract preferred for Stage 2 input resolution - id: in_cps_s5 label: source_imputed_stratified_extended_cps.h5 node_type: artifact @@ -859,6 +884,22 @@ stages: label: calibration_package.pkl node_type: artifact description: X_sparse CSR matrix, targets_df, initial_weights, metadata + - id: out_metadata + label: calibration_package_meta.json + node_type: artifact + description: Metadata sidecar generated from the typed package payload and Stage 2 contract + - id: out_targets + label: calibration_targets.jsonl + node_type: artifact + description: Row-level selected target metadata with stable target_id and target_index join keys + - id: out_target_facets + label: calibration_target_facets.json + node_type: artifact + description: Compact target counts by variable, geography level, target name, period, and constraint key + - id: out_contract + label: calibration_package_contract.json + node_type: artifact + description: Stage 2 package handoff contract written next to calibration_package.pkl - id: util_sql label: sqlalchemy node_type: utility @@ -876,20 +917,82 @@ stages: node_type: utility description: CSR/COO matrix construction edges: + - source: in_stage1_contract_s2 + target: stage2_input_bundle + edge_type: data_flow + label: preferred input contract - source: in_cps_s5 + target: stage2_input_bundle + edge_type: data_flow + label: compatibility fallback + - source: in_db_s5 + target: stage2_input_bundle + edge_type: external_source + label: compatibility fallback + - source: stage2_input_bundle + target: stage2_build_context + edge_type: data_flow + label: validated inputs + - source: stage2_artifact_specs + target: stage2_build_context + edge_type: uses_utility + label: output bundle paths + - source: stage2_build_context target: target_resolve edge_type: data_flow + label: dataset and database paths + - source: stage2_build_context + target: stage2_calibration_package_writer + edge_type: uses_utility + label: package output bundle - source: in_db_s5 target: target_resolve edge_type: external_source label: SQL targets - source: in_config_s5 - target: target_resolve + target: stage2_target_config_identity + edge_type: data_flow + label: config file + - source: stage2_target_config_identity + target: stage2_target_catalog_load + edge_type: data_flow + label: resolved path and checksum + - source: in_db_s5 + target: stage2_target_catalog_reader + edge_type: external_source + label: active and disabled targets + - source: stage2_target_catalog_reader + target: stage2_target_selection_policy + edge_type: data_flow + label: target catalog + - source: stage2_target_catalog_load + target: stage2_target_selection_policy + edge_type: data_flow + label: include/exclude rules + - source: stage2_target_selection_policy + target: stage2_target_selection_result + edge_type: data_flow + label: selected targets + - source: stage2_target_selection_result + target: build_matrix edge_type: data_flow - label: include list + label: matrix target order + - source: stage2_target_selection_result + target: build_matrix_chunked + edge_type: data_flow + label: matrix target order + - source: stage2_target_catalog_load + target: stage2_target_config_apply + edge_type: data_flow + label: include/exclude rules - source: target_resolve + target: stage2_target_config_apply + edge_type: data_flow + label: candidate targets + - source: stage2_target_config_apply target: target_uprate edge_type: data_flow + label: selected targets - source: target_uprate target: geo_build edge_type: data_flow @@ -917,8 +1020,89 @@ stages: target: sparse_build edge_type: data_flow - source: sparse_build + target: build_matrix + edge_type: uses_library + label: non-chunked path + - source: sparse_build + target: build_matrix_chunked + edge_type: uses_library + label: chunked path + - source: build_matrix + target: stage2_payload_boundary + edge_type: data_flow + - source: build_matrix_chunked + target: stage2_payload_boundary + edge_type: data_flow + - source: stage2_payload_boundary + target: stage2_payload_writer + edge_type: data_flow + label: typed pickle payload + - source: stage2_payload_writer + target: stage2_calibration_package_writer + edge_type: uses_library + label: pickle write + - source: stage2_artifact_specs + target: stage2_calibration_package_writer + edge_type: uses_utility + label: package path + - source: stage2_calibration_package_writer target: out_pkg edge_type: produces_artifact + - source: stage2_payload_writer + target: out_metadata + edge_type: produces_artifact + label: sidecar metadata + - source: stage2_target_selection_result + target: out_targets + edge_type: produces_artifact + label: row-level selected targets + - source: stage2_target_selection_result + target: out_target_facets + edge_type: produces_artifact + label: derived facets + - source: out_pkg + target: stage2_payload_reader + edge_type: data_flow + - source: out_pkg + target: stage2_calibration_package_contract_writer + edge_type: data_flow + - source: stage2_payload_reader + target: stage2_calibration_package_contract_writer + edge_type: uses_library + label: summary and checksum + - source: out_targets + target: stage2_calibration_package_contract_writer + edge_type: data_flow + label: target metadata artifact + - source: out_target_facets + target: stage2_calibration_package_contract_writer + edge_type: data_flow + label: target facet artifact + - source: stage2_artifact_specs + target: stage2_calibration_package_contract_writer + edge_type: uses_utility + label: contract path + - source: stage2_calibration_package_contract_writer + target: out_contract + edge_type: produces_artifact + - source: stage2_calibration_package_contract_writer + target: out_targets + edge_type: validates + - source: stage2_calibration_package_contract_writer + target: out_target_facets + edge_type: validates + - source: out_pkg + target: stage2_calibration_package_contract_validator + edge_type: validates + - source: out_contract + target: stage2_calibration_package_contract_validator + edge_type: validates + - source: in_cps_s5 + target: stage2_calibration_package_contract_validator + edge_type: validates + - source: in_db_s5 + target: stage2_calibration_package_contract_validator + edge_type: validates - source: util_sql target: target_resolve edge_type: uses_utility diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py index 39a437808..56ef5879c 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -92,6 +92,11 @@ write_run_meta, ) from policyengine_us_data.utils.run_context import RunContext, resolve_run_id # noqa: E402 +from policyengine_us_data.calibration_package.specs import ( # noqa: E402 + Stage2InputBundleError, + resolve_target_config_identity, + stage2_build_context_for_run, +) from policyengine_us_data.utils.error_redaction import ( # noqa: E402 redacted_bounded_error_text, redact_error_text, @@ -162,6 +167,7 @@ def _calibration_package_parameters( workers: int, n_clones: int, target_config: str | None, + all_active_targets: bool = False, skip_county: bool, chunked_matrix: bool, chunk_size: int, @@ -169,11 +175,17 @@ def _calibration_package_parameters( num_matrix_workers: int, ) -> dict: """Return manifest parameters that affect package construction.""" + target_config_identity = resolve_target_config_identity( + target_config, + all_active_targets=all_active_targets, + ) effective_parallel = bool(chunked_matrix and parallel_matrix) params = { "workers": workers if not chunked_matrix else None, "n_clones": n_clones, - "target_config": target_config, + "target_config": target_config_identity.path, + "target_config_sha256": target_config_identity.sha256, + "target_config_mode": target_config_identity.mode, "skip_county": skip_county, "chunked_matrix": bool(chunked_matrix), "chunk_size": chunk_size if chunked_matrix else None, @@ -547,6 +559,7 @@ def verify_runtime_seams() -> dict: "modal_app/step_manifests/errors.py", "modal_app/step_manifests/status.py", "modal_app/fixtures/h5_cases.py", + "policyengine_us_data/calibration_package/specs.py", "tests/integration/test_fixture_50hh.h5", "policyengine_us_data/calibration/target_config.yaml", "policyengine_us_data/calibration/target_config_full.yaml", @@ -1231,13 +1244,13 @@ def run_pipeline( print(f" Completed in {completed_build_manifest.duration_s}s") # ── Step 2: Build calibration package ── + package_context = stage2_build_context_for_run(PIPELINE_MOUNT, run_id) + package_input_validation = package_context.input_bundle.validation_report() package_inputs = _artifact_identities( - { - "dataset": _artifacts_dir(run_id) - / "source_imputed_stratified_extended_cps.h5", - "database": _artifacts_dir(run_id) / "policy_data.db", - } + package_context.input_bundle.manifest_inputs ) + package_inputs["input_validation"] = package_input_validation.to_dict() + package_artifacts = package_context.output_bundle package_parameters = _calibration_package_parameters( workers=num_workers, n_clones=n_clones, @@ -1248,6 +1261,18 @@ def run_pipeline( parallel_matrix=parallel_matrix, num_matrix_workers=num_matrix_workers, ) + if package_input_validation.status != "pass": + active_step_manifest = _start_step_manifest( + meta, + BUILD_CALIBRATION_PACKAGE, + parameters=package_parameters, + input_identities=package_inputs, + vol=pipeline_volume, + ) + raise Stage2InputBundleError( + package_context.input_bundle, + package_input_validation, + ) package_reuse = _step_reusable( meta, BUILD_CALIBRATION_PACKAGE, @@ -1302,8 +1327,7 @@ def run_pipeline( completed_package_manifest = _complete_step_manifest( active_step_manifest, outputs=collect_artifacts( - [_artifacts_dir(run_id) / "calibration_package.pkl"], - missing_ok=True, + package_artifacts.manifest_outputs, ), vol=pipeline_volume, ) diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index 7a198a0ed..a81764fba 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -12,6 +12,10 @@ sys.path.insert(0, _p) from modal_app.images import gpu_image as image # noqa: E402 +from policyengine_us_data.calibration_package.specs import ( # noqa: E402 + calibration_package_artifact_paths, + stage2_build_context_for_run, +) app = modal.App( os.environ.get("US_DATA_FIT_WEIGHTS_APP_NAME") or "policyengine-us-data-fit-weights" @@ -317,23 +321,30 @@ def _print_provenance_from_meta(meta: dict, current_branch: str = None) -> None: def _write_package_sidecar(pkg_path: str) -> bool: - """Extract metadata from a pickle package and write a JSON sidecar. + """Write package metadata from the typed payload and contract sidecar. Returns: True if sidecar was written successfully, False otherwise. """ - import json import logging - import pickle - sidecar_path = pkg_path.replace(".pkl", "_meta.json") try: - with open(pkg_path, "rb") as f: - package = pickle.load(f) - meta = package.get("metadata", {}) - del package - with open(sidecar_path, "w") as f: - json.dump(meta, f, indent=2) + from policyengine_us_data.calibration_package.payload import ( + CalibrationPackageReader, + CalibrationPackageWriter, + ) + from policyengine_us_data.calibration_package.specs import ( + CALIBRATION_PACKAGE_CONTRACT_FILENAME, + ) + from policyengine_us_data.stage_contracts.io import read_contract + + package_path = Path(pkg_path) + payload = CalibrationPackageReader(package_path=package_path).read() + contract_path = package_path.with_name(CALIBRATION_PACKAGE_CONTRACT_FILENAME) + contract = read_contract(contract_path) if contract_path.exists() else None + sidecar_path = CalibrationPackageWriter( + package_path=package_path, + ).write_metadata_sidecar(payload, contract=contract) print( f"Sidecar metadata written to {sidecar_path}", flush=True, @@ -368,18 +379,14 @@ def _build_package_impl( _ensure_geography_prerequisites() pipeline_vol.reload() - artifacts = f"{PIPELINE_MOUNT}/artifacts" - if run_id: - artifacts = f"{artifacts}/{run_id}" - db_path = f"{artifacts}/policy_data.db" - dataset_path = f"{artifacts}/source_imputed_stratified_extended_cps.h5" - for label, p in [("database", db_path), ("dataset", dataset_path)]: - if not os.path.exists(p): - raise RuntimeError( - f"Missing {label} on pipeline volume: {p}. Run data_build first." - ) - - pkg_path = f"{artifacts}/calibration_package.pkl" + build_context = stage2_build_context_for_run( + PIPELINE_MOUNT, run_id + ).require_inputs() + input_bundle = build_context.input_bundle + package_artifacts = build_context.output_bundle + db_path = str(input_bundle.target_database) + dataset_path = str(input_bundle.source_dataset) + pkg_path = str(package_artifacts.package) cmd = [ *_python_cmd("-m", "policyengine_us_data.calibration.unified_calibration"), "--device", @@ -404,7 +411,7 @@ def _build_package_impl( if chunked_matrix: cmd.extend(["--chunked-matrix", "--chunk-size", str(chunk_size)]) if parallel_matrix: - chunk_dir = f"{artifacts}/matrix_build" + chunk_dir = str(package_artifacts.matrix_build_dir) cmd.extend( [ "--parallel", @@ -439,14 +446,12 @@ def _build_package_impl( raise RuntimeError(f"Package build failed with code {build_rc}") from policyengine_us_data.stage_contracts.calibration_package import ( - CALIBRATION_PACKAGE_CONTRACT_FILENAME, validate_persisted_calibration_package_contract, ) - contract_path = f"{artifacts}/{CALIBRATION_PACKAGE_CONTRACT_FILENAME}" validate_persisted_calibration_package_contract( - package_path=Path(pkg_path), - contract_path=Path(contract_path), + package_path=package_artifacts.package, + contract_path=package_artifacts.contract, dataset_path=Path(dataset_path), db_path=Path(db_path), ) @@ -525,8 +530,9 @@ def check_volume_package(artifacts_dir: str = "") -> dict: import json base = artifacts_dir if artifacts_dir else f"{PIPELINE_MOUNT}/artifacts" - pkg_path = f"{base}/calibration_package.pkl" - sidecar_path = f"{base}/calibration_package_meta.json" + package_artifacts = calibration_package_artifact_paths(base) + pkg_path = str(package_artifacts.package) + sidecar_path = str(package_artifacts.metadata) if not os.path.exists(pkg_path): return {"exists": False} diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index 9418bd4c9..467be0ade 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -44,6 +44,23 @@ from policyengine_us_data.calibration.calibration_utils import ( create_target_groups, ) +from policyengine_us_data.calibration_package.payload import ( + CalibrationPackagePayload, + CalibrationPackageReader, + CalibrationPackageWriter, +) +from policyengine_us_data.calibration_package.specs import ( + DEFAULT_TARGET_CONFIG_PATH as DEFAULT_TARGET_CONFIG_RELATIVE_PATH, + CALIBRATION_TARGET_FACETS_FILENAME, + CALIBRATION_TARGETS_FILENAME, + TargetConfigIdentity, + resolve_target_config_identity, +) +from policyengine_us_data.calibration_package.targets import ( + TargetCatalog, + TargetCatalogReader, + TargetSelectionPolicy, +) from policyengine_us_data.pipeline_metadata import pipeline_node from policyengine_us_data.stage_contracts.calibration_package import ( CalibrationPackageParameters, @@ -72,7 +89,9 @@ LEARNING_RATE = 0.15 DEFAULT_EPOCHS = 100 DEFAULT_N_CLONES = 430 -DEFAULT_TARGET_CONFIG_PATH = Path(__file__).resolve().parent / "target_config.yaml" +DEFAULT_TARGET_CONFIG_PATH = ( + Path(__file__).resolve().parents[2] / DEFAULT_TARGET_CONFIG_RELATIVE_PATH +) def _utc_now_isoformat() -> str: @@ -86,6 +105,8 @@ def _calibration_package_contract_parameters( workers: int, n_clones: int, target_config_path: str | None, + target_config_sha256: str | None, + target_config_mode: str | None, skip_county: bool, skip_source_impute: bool, skip_takeup_rerandomize: bool, @@ -100,6 +121,8 @@ def _calibration_package_contract_parameters( workers=workers, n_clones=n_clones, target_config_path=target_config_path, + target_config_sha256=target_config_sha256, + target_config_mode=target_config_mode, skip_county=skip_county, skip_source_impute=skip_source_impute, skip_takeup_rerandomize=skip_takeup_rerandomize, @@ -110,6 +133,47 @@ def _calibration_package_contract_parameters( ) +def _target_config_identity_for_metadata( + *, + target_config: dict | None, + target_config_path: str | None, + target_config_identity: TargetConfigIdentity | None, +) -> TargetConfigIdentity | None: + """Return a resolved identity consistent with the parsed target config.""" + + if target_config_identity is not None: + if ( + target_config is None + and target_config_identity.mode != "all_active_targets" + ): + raise ValueError( + "target_config_identity requires a parsed target_config unless " + "all_active_targets is selected" + ) + if ( + target_config is not None + and target_config_identity.mode == "all_active_targets" + ): + raise ValueError( + "all_active_targets identity cannot be paired with a target_config" + ) + return target_config_identity + if target_config is None: + if target_config_path is not None: + raise ValueError( + "target_config_path cannot be recorded unless target_config is parsed" + ) + return TargetConfigIdentity( + path=None, + sha256=None, + mode="all_active_targets", + resolved_path=None, + ) + if target_config_path is None: + return None + return resolve_target_config_identity(target_config_path) + + def get_git_provenance() -> dict: """Capture git state and package version for provenance tracking.""" import subprocess as _sp @@ -451,14 +515,30 @@ def parse_args(argv=None): return parser.parse_args(argv) +@pipeline_node( + PipelineNode( + id="stage2_target_catalog_load", + label="Load Stage 2 Target Config", + node_type="library", + description="Load the include/exclude target-selection catalog used by Stage 2 package construction.", + source_file="policyengine_us_data/calibration/unified_calibration.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_in=[DEFAULT_TARGET_CONFIG_RELATIVE_PATH], + validation_commands=[ + "uv run pytest tests/unit/calibration/test_target_config.py" + ], + ) +) def load_target_config(path: str) -> dict: - """Load target exclusion config from YAML. + """Load target include/exclude config from YAML. Args: path: Path to YAML config file. Returns: - Parsed config dict with 'exclude' list. + Parsed config dict with include and exclude lists. """ import yaml @@ -509,24 +589,18 @@ def apply_target_config( Returns: (filtered_targets_df, filtered_X_sparse, filtered_names) """ - include_rules = config.get("include", []) - exclude_rules = config.get("exclude", []) - - if not include_rules and not exclude_rules: + policy = TargetSelectionPolicy.from_config(config) + if not policy.include_rules and not policy.exclude_rules: return targets_df, X_sparse, target_names n_before = len(targets_df) - - if include_rules: - keep_mask = _match_rules(targets_df, include_rules) - else: - keep_mask = np.ones(n_before, dtype=bool) - - if exclude_rules: - drop_mask = _match_rules(targets_df, exclude_rules) - keep_mask &= ~drop_mask - - n_dropped = n_before - keep_mask.sum() + working_targets = targets_df.copy() + if "target_id" not in working_targets.columns: + working_targets["target_id"] = list(range(len(working_targets))) + selection = policy.select(TargetCatalog.from_targets(working_targets)) + selected_ids = set(selection.target_ids) + keep_mask = working_targets["target_id"].isin(selected_ids) + n_dropped = n_before - int(keep_mask.sum()) logger.info( "Target config: kept %d / %d targets (dropped %d)", keep_mask.sum(), @@ -534,7 +608,7 @@ def apply_target_config( n_dropped, ) - idx = np.where(keep_mask)[0] + idx = np.where(keep_mask.to_numpy())[0] filtered_df = targets_df.iloc[idx].reset_index(drop=True) filtered_X = None if X_sparse is None else X_sparse[idx, :] filtered_names = [target_names[i] for i in idx] @@ -542,6 +616,21 @@ def apply_target_config( return filtered_df, filtered_X, filtered_names +@pipeline_node( + PipelineNode( + id="stage2_target_config_apply", + label="Apply Stage 2 Target Config", + node_type="library", + description="Apply Stage 2 target include/exclude rules before matrix construction.", + source_file="policyengine_us_data/calibration/unified_calibration.py", + status="current", + stability="moving", + pathways=["calibration_package"], + validation_commands=[ + "uv run pytest tests/unit/calibration/test_target_config.py" + ], + ) +) def apply_target_config_to_targets( targets_df: "pd.DataFrame", config: dict, @@ -556,6 +645,22 @@ def apply_target_config_to_targets( return filtered_df +@pipeline_node( + PipelineNode( + id="stage2_calibration_package_writer", + label="Stage 2 Package Writer", + node_type="library", + description="Persist the Stage 2 sparse matrix, target rows, target names, geography arrays, and provenance metadata.", + source_file="policyengine_us_data/calibration/unified_calibration.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_out=["calibration_package.pkl"], + validation_commands=[ + "uv run pytest tests/unit/calibration/test_unified_calibration.py" + ], + ) +) def save_calibration_package( path: str, X_sparse, @@ -578,20 +683,16 @@ def save_calibration_package( cd_geoid: CD GEOID array from geography assignment. block_geoid: Block GEOID array from geography assignment. """ - import pickle - - package = { - "X_sparse": X_sparse, - "targets_df": targets_df, - "target_names": target_names, - "metadata": metadata, - "initial_weights": initial_weights, - "cd_geoid": cd_geoid, - "block_geoid": block_geoid, - } - Path(path).parent.mkdir(parents=True, exist_ok=True) - with open(path, "wb") as f: - pickle.dump(package, f, protocol=pickle.HIGHEST_PROTOCOL) + payload = CalibrationPackagePayload( + X_sparse=X_sparse, + targets_df=targets_df, + target_names=target_names, + metadata=metadata, + initial_weights=initial_weights, + cd_geoid=cd_geoid, + block_geoid=block_geoid, + ) + CalibrationPackageWriter(package_path=Path(path)).write(payload) logger.info("Calibration package saved to %s", path) @@ -604,16 +705,14 @@ def load_calibration_package(path: str) -> dict: Returns: Dict with X_sparse, targets_df, target_names, metadata. """ - import pickle - - with open(path, "rb") as f: - package = pickle.load(f) + payload = CalibrationPackageReader(package_path=Path(path)).read() + package = payload.to_mapping() logger.info( "Loaded package: %d targets, %d records", - package["X_sparse"].shape[0], - package["X_sparse"].shape[1], + payload.X_sparse.shape[0], + payload.X_sparse.shape[1], ) - meta = package.get("metadata", {}) + meta = payload.metadata print_package_provenance(meta) check_package_staleness(meta) return package @@ -1285,6 +1384,7 @@ def run_calibration( skip_county: bool = True, target_config: dict = None, target_config_path: str = None, + target_config_identity: TargetConfigIdentity | None = None, build_only: bool = False, package_path: str = None, package_output_path: str = None, @@ -1322,6 +1422,7 @@ def run_calibration( skip_source_impute: Skip ACS/SIPP/SCF imputations. target_config: Parsed target config dict. target_config_path: Path to target config, for provenance. + target_config_identity: Resolved target config path/checksum identity. build_only: If True, save package and skip fitting. package_path: Load pre-built package (skip build). package_output_path: Where to save calibration package. @@ -1432,6 +1533,11 @@ def run_calibration( db_uri=db_uri, time_period=time_period, ) + resolved_target_identity = _target_config_identity_for_metadata( + target_config=target_config, + target_config_path=target_config_path, + target_config_identity=target_config_identity, + ) # Compute base household AGI for conditional geographic assignment base_agi = sim.calculate("adjusted_gross_income", map_to="household").values.astype( @@ -1522,20 +1628,23 @@ def run_calibration( target_filter = {} if domain_variables: target_filter["domain_variables"] = domain_variables - if target_config: - candidate_targets = builder._query_targets(target_filter) - filtered_targets = apply_target_config_to_targets( - candidate_targets, - target_config, - ) - if len(filtered_targets) == 0: - raise ValueError("Target config excluded all targets") - target_filter["target_ids"] = filtered_targets["target_id"].tolist() - logger.info( - "Build target config: selected %d / %d targets before matrix build", - len(filtered_targets), - len(candidate_targets), - ) + target_catalog = TargetCatalogReader( + engine=builder.engine, + time_period=time_period, + ).load(target_filter) + target_selection = TargetSelectionPolicy.from_config(target_config).select( + target_catalog, + target_config_identity=resolved_target_identity, + valid_variables=sim.tax_benefit_system.variables, + ) + if len(target_selection.targets_df) == 0: + raise ValueError("Target config excluded all targets") + target_filter["target_selection"] = target_selection + logger.info( + "Build target selection: selected %d / %d active targets", + len(target_selection.targets_df), + len(target_catalog.targets), + ) # Step 6: Build sparse calibration matrix do_rerandomize = not skip_takeup_rerandomize @@ -1587,6 +1696,7 @@ def run_calibration( X_sparse.shape, X_sparse.nnz, ) + target_selection = target_selection.with_matrix_order(targets_df, target_names) # Step 6b: Save the calibration package. By default this is the # minimal package selected by target_config.yaml; use @@ -1599,33 +1709,43 @@ def run_calibration( "base_n_records": n_records, "seed": seed, "created_at": _utc_now_isoformat(), - "target_config_path": target_config_path, + "target_config_path": ( + resolved_target_identity.path if resolved_target_identity else None + ), + "target_config_sha256": ( + resolved_target_identity.sha256 if resolved_target_identity else None + ), + "target_config_mode": ( + resolved_target_identity.mode if resolved_target_identity else "explicit" + ), "package_scope": "minimal" if target_config else "all_active_targets", "matrix_builder": "chunked" if chunked_matrix else "precompute", "chunk_size": chunk_size if chunked_matrix else None, "chunk_dir": chunk_dir if chunked_matrix else None, + "target_selection_sha256": target_selection.checksum, + "target_selection_n_targets": target_selection.n_selected_targets, } metadata.update(get_git_provenance()) from policyengine_us_data.utils.manifest import compute_file_checksum metadata["dataset_sha256"] = compute_file_checksum(Path(dataset_path)) metadata["db_sha256"] = compute_file_checksum(Path(db_path)) - if target_config_path: - metadata["target_config_sha256"] = compute_file_checksum( - Path(target_config_path) - ) initial_weights = compute_initial_weights(X_sparse, targets_df) if package_output_path: - package_payload = { - "X_sparse": X_sparse, - "targets_df": targets_df, - "target_names": target_names, - "metadata": metadata, - "initial_weights": initial_weights, - "cd_geoid": geography.cd_geoid, - "block_geoid": geography.block_geoid, - } + package_path = Path(package_output_path) + targets_path = package_path.with_name(CALIBRATION_TARGETS_FILENAME) + target_facets_path = package_path.with_name(CALIBRATION_TARGET_FACETS_FILENAME) + target_selection.write_artifacts(targets_path, target_facets_path) + package_payload = CalibrationPackagePayload( + X_sparse=X_sparse, + targets_df=targets_df, + target_names=target_names, + metadata=metadata, + initial_weights=initial_weights, + cd_geoid=geography.cd_geoid, + block_geoid=geography.block_geoid, + ) save_calibration_package( package_output_path, X_sparse, @@ -1643,14 +1763,16 @@ def run_calibration( completed_at = _utc_now_isoformat() write_calibration_package_contract( - package_path=Path(package_output_path), + package_path=package_path, dataset_path=Path(dataset_path), db_path=Path(db_path), package=package_payload, parameters=_calibration_package_contract_parameters( workers=workers, n_clones=n_clones, - target_config_path=target_config_path, + target_config_path=metadata["target_config_path"], + target_config_sha256=metadata["target_config_sha256"], + target_config_mode=metadata["target_config_mode"], skip_county=skip_county, skip_source_impute=skip_source_impute, skip_takeup_rerandomize=skip_takeup_rerandomize, @@ -1665,9 +1787,12 @@ def run_calibration( duration_s=round(time.time() - t0, 1), code_sha=metadata.get("git_commit"), package_version=metadata.get("package_version"), + target_metadata_path=targets_path, + target_facets_path=target_facets_path, + target_selection_summary=target_selection.summary(), ) validate_calibration_package_contract( - package_path=Path(package_output_path), + package_path=package_path, package=package_payload, dataset_path=Path(dataset_path), db_path=Path(db_path), @@ -1812,9 +1937,15 @@ def main(argv=None): target_config = None target_config_path = None + target_config_identity = resolve_target_config_identity( + args.target_config, + all_active_targets=args.all_active_targets, + ) if not args.all_active_targets: - target_config_path = args.target_config or str(DEFAULT_TARGET_CONFIG_PATH) - target_config = load_target_config(target_config_path) + target_config_path = target_config_identity.path + target_config = load_target_config( + target_config_identity.resolved_path or target_config_path + ) package_output_path = args.package_output if args.build_only and not package_output_path: @@ -1850,6 +1981,7 @@ def main(argv=None): skip_county=not args.county_level, target_config=target_config, target_config_path=target_config_path, + target_config_identity=target_config_identity, build_only=args.build_only, package_path=args.package_path, package_output_path=package_output_path, diff --git a/policyengine_us_data/calibration/unified_matrix_builder.py b/policyengine_us_data/calibration/unified_matrix_builder.py index 63551b54f..0e8c523a9 100644 --- a/policyengine_us_data/calibration/unified_matrix_builder.py +++ b/policyengine_us_data/calibration/unified_matrix_builder.py @@ -34,6 +34,10 @@ apply_op, get_geo_level, ) +from policyengine_us_data.calibration_package.targets import ( + TargetCatalogReader, + TargetSelectionResult, +) from policyengine_us_data.pipeline_metadata import pipeline_node from policyengine_us_data.pipeline_schema import PipelineNode from policyengine_us_data.utils.target_variables import ( @@ -2106,98 +2110,19 @@ def _get_target_overview_columns(self) -> set: def _query_targets(self, target_filter: dict) -> pd.DataFrame: """Query targets via target_overview view with best-period selection.""" - and_conditions = [] - - if "domain_variables" in target_filter: - dvs = target_filter["domain_variables"] - ph = ",".join(f"'{dv}'" for dv in dvs) - and_conditions.append(f"tv.domain_variable IN ({ph})") - - if "variables" in target_filter: - vs = ",".join(f"'{v}'" for v in target_filter["variables"]) - and_conditions.append(f"tv.variable IN ({vs})") - - if "target_ids" in target_filter: - ids = ",".join(map(str, target_filter["target_ids"])) - and_conditions.append(f"tv.target_id IN ({ids})") - - if "stratum_ids" in target_filter: - ids = ",".join(map(str, target_filter["stratum_ids"])) - and_conditions.append(f"tv.stratum_id IN ({ids})") - - if not and_conditions: - where_clause = "1=1" - else: - where_clause = " AND ".join(f"({c})" for c in and_conditions) - - if "reform_id" in self._get_target_overview_columns(): - query = f""" - WITH filtered_targets AS ( - SELECT tv.target_id, tv.stratum_id, tv.variable, tv.reform_id, - tv.value, tv.period, tv.geo_level, - tv.geographic_id, tv.domain_variable - FROM target_overview tv - WHERE tv.active = 1 - AND ({where_clause}) - ), - best_periods AS ( - SELECT stratum_id, variable, reform_id, - CASE - WHEN MAX(CASE WHEN period <= :time_period - THEN period END) IS NOT NULL - THEN MAX(CASE WHEN period <= :time_period - THEN period END) - ELSE MIN(period) - END as best_period - FROM filtered_targets - GROUP BY stratum_id, variable, reform_id - ) - SELECT ft.* - FROM filtered_targets ft - JOIN best_periods bp - ON ft.stratum_id = bp.stratum_id - AND ft.variable = bp.variable - AND ft.reform_id = bp.reform_id - AND ft.period = bp.best_period - ORDER BY ft.target_id - """ - else: - query = f""" - WITH filtered_targets AS ( - SELECT tv.target_id, tv.stratum_id, tv.variable, - 0 AS reform_id, tv.value, tv.period, tv.geo_level, - tv.geographic_id, tv.domain_variable - FROM target_overview tv - WHERE tv.active = 1 - AND ({where_clause}) - ), - best_periods AS ( - SELECT stratum_id, variable, - CASE - WHEN MAX(CASE WHEN period <= :time_period - THEN period END) IS NOT NULL - THEN MAX(CASE WHEN period <= :time_period - THEN period END) - ELSE MIN(period) - END as best_period - FROM filtered_targets - GROUP BY stratum_id, variable - ) - SELECT ft.* - FROM filtered_targets ft - JOIN best_periods bp - ON ft.stratum_id = bp.stratum_id - AND ft.variable = bp.variable - AND ft.period = bp.best_period - ORDER BY ft.target_id - """ - - with self.engine.connect() as conn: - return pd.read_sql( - query, - conn, - params={"time_period": self.time_period}, + target_selection = target_filter.get("target_selection") + if target_selection is not None: + if not isinstance(target_selection, TargetSelectionResult): + raise ValueError("target_selection must be a TargetSelectionResult") + return target_selection.targets_df.copy() + return ( + TargetCatalogReader( + engine=self.engine, + time_period=self.time_period, ) + .load(target_filter) + .targets + ) def get_district_agi_targets(self) -> Dict[str, float]: """Return current-law district AGI targets for geography assignment.""" diff --git a/policyengine_us_data/calibration_package/__init__.py b/policyengine_us_data/calibration_package/__init__.py new file mode 100644 index 000000000..438a694b8 --- /dev/null +++ b/policyengine_us_data/calibration_package/__init__.py @@ -0,0 +1,83 @@ +"""Stage 2 calibration-package specifications.""" + +from .specs import ( + CALIBRATION_PACKAGE_CONTRACT_FILENAME, + CALIBRATION_PACKAGE_FILENAME, + CALIBRATION_PACKAGE_METADATA_FILENAME, + CALIBRATION_PACKAGE_SUBSTAGE_ID, + CALIBRATION_TARGET_FACETS_FILENAME, + CALIBRATION_TARGETS_FILENAME, + CALIBRATION_REPORTS_DIRNAME, + DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, + DEFAULT_TARGET_CONFIG_PATH, + MATRIX_BUILD_DIRNAME, + SOURCE_DATASET_FILENAME, + TARGET_CONFIG_IDENTITY_MODES, + TARGET_DATABASE_FILENAME, + CalibrationPackageArtifactPaths, + CalibrationPackageOutputBundle, + Stage2BuildContext, + Stage2InputBundle, + Stage2InputBundleError, + Stage2InputSource, + TargetConfigIdentity, + calibration_package_artifact_paths, + resolve_target_config_identity, + stage2_build_context_for_run, + stage2_input_bundle_from_artifacts_dir, + stage2_input_bundle_from_stage1_contract, + stage2_input_bundle_from_stage1_contract_path, +) +from .payload import ( + LEGACY_MISSING_GEOGRAPHY_WARNING, + REQUIRED_PACKAGE_KEYS, + CalibrationPackagePayload, + CalibrationPackageReader, + CalibrationPackageWriter, +) +from .targets import ( + TargetCatalog, + TargetCatalogReader, + TargetSelectionPolicy, + TargetSelectionResult, + target_facets_from_rows, +) + +__all__ = [ + "CALIBRATION_PACKAGE_CONTRACT_FILENAME", + "CALIBRATION_PACKAGE_FILENAME", + "CALIBRATION_PACKAGE_METADATA_FILENAME", + "CALIBRATION_PACKAGE_SUBSTAGE_ID", + "CALIBRATION_TARGET_FACETS_FILENAME", + "CALIBRATION_TARGETS_FILENAME", + "CALIBRATION_REPORTS_DIRNAME", + "DATASET_BUILD_OUTPUT_CONTRACT_FILENAME", + "DEFAULT_TARGET_CONFIG_PATH", + "MATRIX_BUILD_DIRNAME", + "SOURCE_DATASET_FILENAME", + "TARGET_CONFIG_IDENTITY_MODES", + "TARGET_DATABASE_FILENAME", + "CalibrationPackageArtifactPaths", + "CalibrationPackageOutputBundle", + "CalibrationPackagePayload", + "CalibrationPackageReader", + "CalibrationPackageWriter", + "LEGACY_MISSING_GEOGRAPHY_WARNING", + "Stage2BuildContext", + "Stage2InputBundle", + "Stage2InputBundleError", + "Stage2InputSource", + "REQUIRED_PACKAGE_KEYS", + "TargetConfigIdentity", + "TargetCatalog", + "TargetCatalogReader", + "TargetSelectionPolicy", + "TargetSelectionResult", + "calibration_package_artifact_paths", + "resolve_target_config_identity", + "stage2_build_context_for_run", + "stage2_input_bundle_from_artifacts_dir", + "stage2_input_bundle_from_stage1_contract", + "stage2_input_bundle_from_stage1_contract_path", + "target_facets_from_rows", +] diff --git a/policyengine_us_data/calibration_package/payload.py b/policyengine_us_data/calibration_package/payload.py new file mode 100644 index 000000000..5108e0b9e --- /dev/null +++ b/policyengine_us_data/calibration_package/payload.py @@ -0,0 +1,387 @@ +"""Typed reader and writer boundary for Stage 2 package payloads.""" + +from __future__ import annotations + +import json +import pickle +from collections.abc import Mapping +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from policyengine_us_data.pipeline_metadata import pipeline_node +from policyengine_us_data.pipeline_schema import PipelineNode +from policyengine_us_data.utils.geography_checksum import ( + canonical_geography_checksum, + hash_string_array, +) +from policyengine_us_data.utils.step_manifest import sha256_file + +from .specs import CALIBRATION_PACKAGE_FILENAME, CALIBRATION_PACKAGE_METADATA_FILENAME + +if TYPE_CHECKING: + from policyengine_us_data.stage_contracts.calibration_package_schema import ( + CalibrationPackageSummary, + GeographyAssignmentSummary, + ) + +REQUIRED_PACKAGE_KEYS: frozenset[str] = frozenset( + {"X_sparse", "targets_df", "target_names", "metadata"} +) +LEGACY_MISSING_GEOGRAPHY_WARNING = ( + "legacy packages without block_geoid/cd_geoid cannot prove geography assignment" +) + + +@pipeline_node( + PipelineNode( + id="stage2_payload_boundary", + label="Stage 2 Package Payload", + node_type="library", + description="Typed access to the calibration_package.pkl matrix, targets, metadata, geography arrays, and compatibility warnings.", + source_file="policyengine_us_data/calibration_package/payload.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_in=[CALIBRATION_PACKAGE_FILENAME], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_payload.py" + ], + ) +) +@dataclass(frozen=True, kw_only=True) +class CalibrationPackagePayload: + """Typed access to the dictionary persisted in `calibration_package.pkl`.""" + + X_sparse: Any + targets_df: Any + target_names: Any + metadata: Mapping[str, Any] + initial_weights: Any | None = None + cd_geoid: Any | None = None + block_geoid: Any | None = None + compatibility_warnings: tuple[str, ...] = () + + @classmethod + def from_mapping( + cls, + package: Mapping[str, Any], + *, + require_required_keys: bool = True, + ) -> "CalibrationPackagePayload": + """Validate and wrap a legacy package mapping.""" + + if not isinstance(package, Mapping): + raise ValueError("Calibration package pickle must contain a mapping") + missing = sorted(REQUIRED_PACKAGE_KEYS - set(package)) + if missing and require_required_keys: + raise ValueError(f"Calibration package missing required key: {missing[0]}") + metadata = package.get("metadata", {}) + if metadata is None: + metadata = {} + if not isinstance(metadata, Mapping): + raise ValueError("Calibration package metadata must be a mapping") + cd_geoid = package.get("cd_geoid") + block_geoid = package.get("block_geoid") + warnings: list[str] = [] + if cd_geoid is None and block_geoid is None: + warnings.append(LEGACY_MISSING_GEOGRAPHY_WARNING) + return cls( + X_sparse=package.get("X_sparse"), + targets_df=package.get("targets_df"), + target_names=package.get("target_names"), + metadata=dict(metadata), + initial_weights=package.get("initial_weights"), + cd_geoid=cd_geoid, + block_geoid=block_geoid, + compatibility_warnings=tuple(warnings), + ) + + def to_mapping(self) -> dict[str, Any]: + """Return the pickle-compatible package dictionary.""" + + return { + "X_sparse": self.X_sparse, + "targets_df": self.targets_df, + "target_names": self.target_names, + "metadata": dict(self.metadata), + "initial_weights": self.initial_weights, + "cd_geoid": self.cd_geoid, + "block_geoid": self.block_geoid, + } + + def summary(self) -> CalibrationPackageSummary: + """Return the contract-safe package summary.""" + + from policyengine_us_data.stage_contracts.calibration_package_schema import ( + CalibrationPackageSummary, + ) + + try: + n_targets, n_columns = self.X_sparse.shape + except (AttributeError, ValueError) as exc: + raise ValueError("X_sparse must expose a two-dimensional shape") from exc + if not hasattr(self.X_sparse, "nnz"): + raise ValueError("X_sparse must expose nnz") + + n_targets = int(n_targets) + n_columns = int(n_columns) + nnz = int(self.X_sparse.nnz) + density = nnz / (n_targets * n_columns) if n_targets * n_columns else 0.0 + + return CalibrationPackageSummary( + matrix_shape=(n_targets, n_columns), + matrix_nnz=nnz, + matrix_density=float(density), + n_targets=int(len(self.targets_df)), + n_columns=n_columns, + target_name_count=int(len(self.target_names)), + dataset_sha256=self.metadata_string("dataset_sha256"), + db_sha256=self.metadata_string("db_sha256"), + target_config_path=self.metadata_string("target_config_path"), + target_config_sha256=self.metadata_string("target_config_sha256"), + n_clones=self.metadata_int("n_clones"), + seed=self.metadata_int("seed"), + base_n_records=self.metadata_int("base_n_records"), + package_scope=self.metadata_string("package_scope"), + matrix_builder=self.metadata_string("matrix_builder"), + chunk_size=self.metadata_int("chunk_size"), + chunk_dir=self.metadata_string("chunk_dir"), + has_initial_weights=self.initial_weights is not None, + has_cd_geoid=self.cd_geoid is not None, + has_block_geoid=self.block_geoid is not None, + cd_geoid_length=_optional_len(self.cd_geoid), + block_geoid_length=_optional_len(self.block_geoid), + ) + + def geography_summary(self) -> GeographyAssignmentSummary: + """Return the contract-safe geography assignment summary.""" + + from policyengine_us_data.stage_contracts.calibration_package_schema import ( + GeographyAssignmentSummary, + ) + + n_records = self.metadata_int("base_n_records") + n_clones = self.metadata_int("n_clones") + has_blocks = self.block_geoid is not None + has_cds = self.cd_geoid is not None + + if not has_blocks and not has_cds: + return GeographyAssignmentSummary( + source_kind="unavailable", + n_records=n_records, + n_clones=n_clones, + n_rows=None, + has_block_geoid=False, + has_cd_geoid=False, + block_geoid_length=None, + cd_geoid_length=None, + block_geoid_sha256=None, + cd_geoid_sha256=None, + canonical_geography_sha256=None, + ) + if not has_blocks or not has_cds: + raise ValueError( + "Calibration package geography requires both block_geoid and cd_geoid" + ) + if n_records is None or n_clones is None: + raise ValueError( + "Calibration package geography requires metadata base_n_records and n_clones" + ) + if n_records <= 0 or n_clones <= 0: + raise ValueError( + "Calibration package geography requires positive base_n_records and n_clones" + ) + + block_geoids = _one_dimensional_string_array(self.block_geoid, "block_geoid") + cd_geoids = _one_dimensional_string_array(self.cd_geoid, "cd_geoid") + n_rows = int(len(block_geoids)) + if n_rows == 0: + raise ValueError("Calibration package geography arrays must be non-empty") + if len(cd_geoids) != n_rows: + raise ValueError( + "Calibration package geography has mismatched block_geoid and cd_geoid " + f"lengths: {n_rows} != {len(cd_geoids)}" + ) + if n_records * n_clones != n_rows: + raise ValueError( + "Calibration package geography length does not match metadata: " + f"{n_rows} rows for {n_records} records x {n_clones} clones" + ) + + return GeographyAssignmentSummary( + source_kind="calibration_package", + n_records=n_records, + n_clones=n_clones, + n_rows=n_rows, + has_block_geoid=True, + has_cd_geoid=True, + block_geoid_length=n_rows, + cd_geoid_length=int(len(cd_geoids)), + block_geoid_sha256=hash_string_array(block_geoids), + cd_geoid_sha256=hash_string_array(cd_geoids), + canonical_geography_sha256=canonical_geography_checksum( + block_geoid=block_geoids, + cd_geoid=cd_geoids, + n_records=n_records, + n_clones=n_clones, + ), + ) + + def metadata_string(self, key: str) -> str | None: + """Return a metadata value coerced to a string, preserving nulls.""" + + value = self.metadata.get(key) + if value is None: + return None + return str(value) + + def metadata_int(self, key: str) -> int | None: + """Return a metadata value coerced to an integer, preserving nulls.""" + + value = self.metadata.get(key) + if value is None: + return None + if isinstance(value, bool): + raise ValueError(f"Calibration package metadata {key!r} must be an integer") + return int(value) + + +@pipeline_node( + PipelineNode( + id="stage2_payload_reader", + label="Stage 2 Payload Reader", + node_type="library", + description="Load calibration_package.pkl through the typed Stage 2 payload boundary and expose checksum/summary material.", + source_file="policyengine_us_data/calibration_package/payload.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_in=[CALIBRATION_PACKAGE_FILENAME], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_payload.py" + ], + ) +) +@dataclass(frozen=True, kw_only=True) +class CalibrationPackageReader: + """Read typed Stage 2 package payloads from disk.""" + + package_path: Path + + def read(self) -> CalibrationPackagePayload: + """Load and validate the persisted package payload.""" + + with Path(self.package_path).open("rb") as handle: + package = pickle.load(handle) + return CalibrationPackagePayload.from_mapping(package) + + def checksum(self) -> str: + """Return the package file checksum used for reuse comparisons.""" + + return f"sha256:{sha256_file(Path(self.package_path))}" + + def summary(self) -> CalibrationPackageSummary: + """Read the package and return its summary.""" + + return self.read().summary() + + +@pipeline_node( + PipelineNode( + id="stage2_payload_writer", + label="Stage 2 Payload Writer", + node_type="library", + description="Persist calibration_package.pkl and derive calibration_package_meta.json from typed payload and contract material.", + source_file="policyengine_us_data/calibration_package/payload.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_out=[ + CALIBRATION_PACKAGE_FILENAME, + CALIBRATION_PACKAGE_METADATA_FILENAME, + ], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_payload.py" + ], + ) +) +@dataclass(frozen=True, kw_only=True) +class CalibrationPackageWriter: + """Write typed Stage 2 package payloads and metadata sidecars.""" + + package_path: Path + + def write(self, payload: CalibrationPackagePayload) -> Path: + """Persist a package payload using the legacy pickle format.""" + + output_path = Path(self.package_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("wb") as handle: + pickle.dump(payload.to_mapping(), handle, protocol=pickle.HIGHEST_PROTOCOL) + return output_path + + def write_metadata_sidecar( + self, + payload: CalibrationPackagePayload, + *, + contract: Any | None = None, + sidecar_path: str | Path | None = None, + ) -> Path: + """Write `calibration_package_meta.json` from typed payload material.""" + + output_path = ( + Path(sidecar_path) + if sidecar_path is not None + else Path(self.package_path).with_name( + CALIBRATION_PACKAGE_METADATA_FILENAME + ) + ) + sidecar_payload = { + **dict(payload.metadata), + "package_sha256": f"sha256:{sha256_file(Path(self.package_path))}", + "package_summary": payload.summary().to_dict(), + "geography_assignment": payload.geography_summary().to_dict(), + "compatibility_warnings": list(payload.compatibility_warnings), + } + if contract is not None: + sidecar_payload["contract"] = { + "stage_id": getattr(contract, "stage_id", None), + "contract_type": getattr(contract, "contract_type", None), + "fingerprint": ( + contract.fingerprint.to_dict() + if getattr(contract, "fingerprint", None) is not None + else None + ), + } + output_path.write_text( + json.dumps(sidecar_payload, indent=2, sort_keys=True), + encoding="utf-8", + ) + return output_path + + +def _optional_len(value: Any) -> int | None: + if value is None: + return None + return int(len(value)) + + +def _one_dimensional_string_array(value: Any, key: str) -> Any: + import numpy as np + + array = np.asarray(value, dtype=str) + if array.ndim != 1: + raise ValueError(f"Calibration package geography {key} must be one-dimensional") + if np.any(array == ""): + raise ValueError(f"Calibration package geography {key} contains empty values") + return array + + +__all__ = [ + "LEGACY_MISSING_GEOGRAPHY_WARNING", + "REQUIRED_PACKAGE_KEYS", + "CalibrationPackagePayload", + "CalibrationPackageReader", + "CalibrationPackageWriter", +] diff --git a/policyengine_us_data/calibration_package/specs.py b/policyengine_us_data/calibration_package/specs.py new file mode 100644 index 000000000..4f0c2bcb5 --- /dev/null +++ b/policyengine_us_data/calibration_package/specs.py @@ -0,0 +1,528 @@ +"""Shared Stage 2 calibration-package identity and artifact specifications.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal +from urllib.parse import unquote, urlparse + +from policyengine_us_data.pipeline_metadata import pipeline_node +from policyengine_us_data.pipeline_schema import PipelineNode +from policyengine_us_data.utils.manifest import compute_file_checksum + +if TYPE_CHECKING: + from policyengine_us_data.stage_contracts import StageContract, ValidationReport + +DEFAULT_TARGET_CONFIG_PATH = "policyengine_us_data/calibration/target_config.yaml" +SOURCE_DATASET_FILENAME = "source_imputed_stratified_extended_cps.h5" +TARGET_DATABASE_FILENAME = "policy_data.db" +DATASET_BUILD_OUTPUT_CONTRACT_FILENAME = "dataset_build_output.json" +CALIBRATION_PACKAGE_FILENAME = "calibration_package.pkl" +CALIBRATION_PACKAGE_METADATA_FILENAME = "calibration_package_meta.json" +CALIBRATION_PACKAGE_CONTRACT_FILENAME = "calibration_package_contract.json" +CALIBRATION_TARGETS_FILENAME = "calibration_targets.jsonl" +CALIBRATION_TARGET_FACETS_FILENAME = "calibration_target_facets.json" +CALIBRATION_REPORTS_DIRNAME = "calibration_reports" +MATRIX_BUILD_DIRNAME = "matrix_build" +CALIBRATION_PACKAGE_SUBSTAGE_ID = "2a_matrix_build_calibration_target_construction" + +TargetConfigMode = Literal["default", "explicit", "all_active_targets"] +Stage2InputSource = Literal["stage1_contract", "artifacts_dir_fallback"] +TARGET_CONFIG_IDENTITY_MODES: frozenset[str] = frozenset( + {"default", "explicit", "all_active_targets"} +) +_SOURCE_DATASET_LOGICAL_NAMES = ( + "source_imputed_stratified_extended_cps", + "source_imputed_stratified_extended_cps_2024", +) +_TARGET_DATABASE_LOGICAL_NAMES = ("policy_data_db",) + + +@dataclass(frozen=True, kw_only=True) +class TargetConfigIdentity: + """Checksum-backed identity for the Stage 2 target selection config.""" + + path: str | None + sha256: str | None + mode: TargetConfigMode + resolved_path: str | None = None + + def __post_init__(self) -> None: + if self.mode not in TARGET_CONFIG_IDENTITY_MODES: + raise ValueError(f"Unknown target config identity mode: {self.mode!r}") + if self.mode == "all_active_targets": + if self.path is not None or self.sha256 is not None: + raise ValueError( + "all_active_targets target config identity cannot include " + "a path or checksum" + ) + return + if not self.path: + raise ValueError(f"{self.mode} target config identity requires a path") + if not self.sha256: + raise ValueError(f"{self.mode} target config identity requires a checksum") + + def to_parameters(self) -> dict[str, str | None]: + """Return the identity fields used in Stage 2 reuse parameters.""" + + return { + "target_config": self.path, + "target_config_sha256": self.sha256, + "target_config_mode": self.mode, + } + + +@dataclass(frozen=True, kw_only=True) +class Stage2InputBundle: + """Canonical Stage 2 input artifacts resolved for one run.""" + + artifacts_dir: Path + source_dataset: Path + target_database: Path + source: Stage2InputSource + stage1_contract_path: Path | None = None + stage1_contract_run_id: str | None = None + + @property + def manifest_inputs(self) -> dict[str, Path]: + """Return input paths recorded in Stage 2 step manifests.""" + + return { + "dataset": self.source_dataset, + "database": self.target_database, + } + + @property + def compatibility_only(self) -> bool: + """Return whether the bundle came from legacy filename discovery.""" + + return self.source == "artifacts_dir_fallback" + + def missing_required_artifacts(self) -> tuple[tuple[str, Path], ...]: + """Return missing required Stage 2 input labels and paths.""" + + missing: list[tuple[str, Path]] = [] + for label, path in self.manifest_inputs.items(): + if not path.exists(): + missing.append((label, path)) + return tuple(missing) + + def validation_report(self) -> "ValidationReport": + """Return a canonical validation report for Stage 2 input readiness.""" + + from policyengine_us_data.stage_contracts.validation import ( + ValidationFinding, + ValidationReport, + ) + + missing = self.missing_required_artifacts() + if missing: + findings = tuple( + ValidationFinding( + check_id=f"stage2_input_exists:{label}", + status="fail", + message=f"Missing Stage 2 {label} artifact: {path}", + metadata={ + "artifact_label": label, + "path": str(path), + "source": self.source, + }, + ) + for label, path in missing + ) + return ValidationReport( + status="fail", + findings=findings, + metadata=self._validation_metadata(), + ) + return ValidationReport( + status="pass", + findings=(), + metadata=self._validation_metadata(), + ) + + def require_existing(self) -> "Stage2InputBundle": + """Raise a structured error when required Stage 2 inputs are missing.""" + + report = self.validation_report() + if report.status != "pass": + raise Stage2InputBundleError(self, report) + return self + + def _validation_metadata(self) -> dict[str, Any]: + metadata: dict[str, Any] = { + "source": self.source, + "artifacts_dir": str(self.artifacts_dir), + "compatibility_only": self.compatibility_only, + } + if self.stage1_contract_path is not None: + metadata["stage1_contract_path"] = str(self.stage1_contract_path) + if self.stage1_contract_run_id is not None: + metadata["stage1_contract_run_id"] = self.stage1_contract_run_id + return metadata + + +class Stage2InputBundleError(FileNotFoundError): + """Input validation failure raised before Stage 2 package work starts.""" + + def __init__( + self, + bundle: Stage2InputBundle, + validation_report: "ValidationReport", + ) -> None: + missing = ", ".join( + f"{label}: {path}" for label, path in bundle.missing_required_artifacts() + ) + super().__init__(f"Missing Stage 2 input artifact(s): {missing}") + self.bundle = bundle + self.validation_report = validation_report + + +@dataclass(frozen=True, kw_only=True) +class CalibrationPackageOutputBundle: + """Canonical run-scoped Stage 2 output artifact paths.""" + + artifacts_dir: Path + package: Path + metadata: Path + contract: Path + targets: Path + target_facets: Path + reports_dir: Path + matrix_build_dir: Path + + @property + def manifest_outputs(self) -> tuple[Path, Path, Path, Path]: + """Return the durable Stage 2 outputs recorded in step manifests.""" + + return (self.package, self.contract, self.targets, self.target_facets) + + +CalibrationPackageArtifactPaths = CalibrationPackageOutputBundle + + +@dataclass(frozen=True, kw_only=True) +class Stage2BuildContext: + """Run-scoped Stage 2 input and output bundles.""" + + artifacts_dir: Path + input_bundle: Stage2InputBundle + output_bundle: CalibrationPackageOutputBundle + run_id: str | None = None + + def require_inputs(self) -> "Stage2BuildContext": + """Validate inputs and return this context when Stage 2 may start.""" + + self.input_bundle.require_existing() + return self + + +@pipeline_node( + PipelineNode( + id="stage2_input_bundle", + label="Stage 2 Input Bundle", + node_type="library", + description="Resolve the source-imputed dataset and policy target database from a Stage 1 contract or compatibility filename fallback.", + source_file="policyengine_us_data/calibration_package/specs.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_in=[ + DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, + SOURCE_DATASET_FILENAME, + TARGET_DATABASE_FILENAME, + ], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_specs.py" + ], + ) +) +def stage2_input_bundle_from_artifacts_dir( + artifacts_dir: str | Path, +) -> Stage2InputBundle: + """Return a compatibility Stage 2 input bundle from canonical filenames.""" + + root = Path(artifacts_dir) + return Stage2InputBundle( + artifacts_dir=root, + source_dataset=root / SOURCE_DATASET_FILENAME, + target_database=root / TARGET_DATABASE_FILENAME, + source="artifacts_dir_fallback", + ) + + +def stage2_input_bundle_from_stage1_contract( + contract: "StageContract", + *, + artifacts_dir: str | Path | None = None, + contract_path: str | Path | None = None, +) -> Stage2InputBundle: + """Return a Stage 2 input bundle from a Stage 1 handoff contract.""" + + if getattr(contract, "stage_id", None) != "1_build_datasets": + raise ValueError("Stage 2 inputs require a Stage 1 dataset-build contract") + source_dataset = _contract_artifact_path( + contract, + logical_names=_SOURCE_DATASET_LOGICAL_NAMES, + label="source dataset", + ) + target_database = _contract_artifact_path( + contract, + logical_names=_TARGET_DATABASE_LOGICAL_NAMES, + label="target database", + ) + root = Path(artifacts_dir) if artifacts_dir is not None else source_dataset.parent + return Stage2InputBundle( + artifacts_dir=root, + source_dataset=source_dataset, + target_database=target_database, + source="stage1_contract", + stage1_contract_path=Path(contract_path) if contract_path is not None else None, + stage1_contract_run_id=getattr(contract, "run_id", None), + ) + + +def stage2_input_bundle_from_stage1_contract_path( + contract_path: str | Path, + *, + artifacts_dir: str | Path | None = None, +) -> Stage2InputBundle: + """Read a Stage 1 handoff contract and return the Stage 2 input bundle.""" + + from policyengine_us_data.stage_contracts.io import read_contract + + contract_file = Path(contract_path) + return stage2_input_bundle_from_stage1_contract( + read_contract(contract_file), + artifacts_dir=artifacts_dir, + contract_path=contract_file, + ) + + +@pipeline_node( + PipelineNode( + id="stage2_build_context", + label="Stage 2 Build Context", + node_type="library", + description="Bind one run_id to canonical Stage 2 input and output bundles before remote package construction starts.", + source_file="policyengine_us_data/calibration_package/specs.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_in=[ + DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, + SOURCE_DATASET_FILENAME, + TARGET_DATABASE_FILENAME, + ], + artifacts_out=[ + CALIBRATION_PACKAGE_FILENAME, + CALIBRATION_PACKAGE_CONTRACT_FILENAME, + ], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_specs.py" + ], + ) +) +def stage2_build_context_for_run( + pipeline_mount: str | Path, + run_id: str | None = "", + *, + stage1_contract_path: str | Path | None = None, +) -> Stage2BuildContext: + """Return Stage 2 run context, preferring the Stage 1 handoff contract.""" + + artifacts_dir = Path(pipeline_mount) / "artifacts" + if run_id: + artifacts_dir = artifacts_dir / run_id + contract_path = ( + Path(stage1_contract_path) + if stage1_contract_path is not None + else artifacts_dir / DATASET_BUILD_OUTPUT_CONTRACT_FILENAME + ) + if contract_path.exists(): + input_bundle = stage2_input_bundle_from_stage1_contract_path( + contract_path, + artifacts_dir=artifacts_dir, + ) + else: + input_bundle = stage2_input_bundle_from_artifacts_dir(artifacts_dir) + return Stage2BuildContext( + artifacts_dir=artifacts_dir, + input_bundle=input_bundle, + output_bundle=calibration_package_artifact_paths(artifacts_dir), + run_id=run_id or None, + ) + + +@pipeline_node( + PipelineNode( + id="stage2_artifact_specs", + label="Stage 2 Artifact Specs", + node_type="library", + description="Centralize Stage 2 input, package, contract, metadata, report, and matrix-build artifact names.", + source_file="policyengine_us_data/calibration_package/specs.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_in=[ + SOURCE_DATASET_FILENAME, + TARGET_DATABASE_FILENAME, + ], + artifacts_out=[ + CALIBRATION_PACKAGE_FILENAME, + CALIBRATION_PACKAGE_METADATA_FILENAME, + CALIBRATION_PACKAGE_CONTRACT_FILENAME, + CALIBRATION_TARGETS_FILENAME, + CALIBRATION_TARGET_FACETS_FILENAME, + ], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_specs.py" + ], + ) +) +def calibration_package_artifact_paths( + artifacts_dir: str | Path, +) -> CalibrationPackageOutputBundle: + """Return canonical Stage 2 paths rooted in an artifacts directory.""" + + root = Path(artifacts_dir) + return CalibrationPackageOutputBundle( + artifacts_dir=root, + package=root / CALIBRATION_PACKAGE_FILENAME, + metadata=root / CALIBRATION_PACKAGE_METADATA_FILENAME, + contract=root / CALIBRATION_PACKAGE_CONTRACT_FILENAME, + targets=root / CALIBRATION_TARGETS_FILENAME, + target_facets=root / CALIBRATION_TARGET_FACETS_FILENAME, + reports_dir=root / CALIBRATION_REPORTS_DIRNAME, + matrix_build_dir=root / MATRIX_BUILD_DIRNAME, + ) + + +@pipeline_node( + PipelineNode( + id="stage2_target_config_identity", + label="Stage 2 Target Config Identity", + node_type="library", + description="Resolve the effective Stage 2 target config path and checksum before package reuse or rebuild.", + source_file="policyengine_us_data/calibration_package/specs.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_in=[DEFAULT_TARGET_CONFIG_PATH], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_specs.py" + ], + ) +) +def resolve_target_config_identity( + target_config_path: str | Path | None = None, + *, + all_active_targets: bool = False, + repo_root: str | Path | None = None, +) -> TargetConfigIdentity: + """Resolve the target config identity used by Stage 2 package construction.""" + + if all_active_targets: + if target_config_path is not None: + raise ValueError( + "--all-active-targets cannot be combined with a target config path" + ) + return TargetConfigIdentity( + path=None, + sha256=None, + mode="all_active_targets", + resolved_path=None, + ) + + root = Path(repo_root).resolve() if repo_root is not None else _repo_root() + mode: TargetConfigMode = "explicit" if target_config_path is not None else "default" + identity_path = Path(target_config_path or DEFAULT_TARGET_CONFIG_PATH) + resolved_path = _resolve_existing_config_path(identity_path, root) + logical_path = ( + DEFAULT_TARGET_CONFIG_PATH + if mode == "default" + else _logical_identity_path(identity_path, resolved_path, root) + ) + return TargetConfigIdentity( + path=logical_path, + sha256=compute_file_checksum(resolved_path), + mode=mode, + resolved_path=str(resolved_path), + ) + + +def _repo_root() -> Path: + return Path(__file__).resolve().parents[2] + + +def _resolve_existing_config_path(path: Path, repo_root: Path) -> Path: + candidates = [path] if path.is_absolute() else [repo_root / path, Path.cwd() / path] + for candidate in candidates: + resolved = candidate.resolve() + if resolved.exists() and resolved.is_file(): + return resolved + raise FileNotFoundError(f"Target config not found: {path}") + + +def _logical_identity_path(path: Path, resolved_path: Path, repo_root: Path) -> str: + try: + return resolved_path.relative_to(repo_root).as_posix() + except ValueError: + return resolved_path.as_posix() if path.is_absolute() else path.as_posix() + + +def _contract_artifact_path( + contract: "StageContract", + *, + logical_names: tuple[str, ...], + label: str, +) -> Path: + for logical_name in logical_names: + for artifact in getattr(contract, "outputs", ()): + if getattr(artifact, "logical_name", None) == logical_name: + return _artifact_uri_to_path(getattr(artifact, "uri")) + raise ValueError( + f"Stage 1 contract is missing required Stage 2 {label}: " + + " or ".join(logical_names) + ) + + +def _artifact_uri_to_path(uri: str) -> Path: + parsed = urlparse(uri) + if parsed.scheme == "file": + return Path(unquote(parsed.path)) + if not parsed.scheme: + return Path(uri) + raise ValueError(f"Unsupported artifact URI scheme for Stage 2 input: {uri}") + + +__all__ = [ + "CALIBRATION_PACKAGE_CONTRACT_FILENAME", + "CALIBRATION_PACKAGE_FILENAME", + "CALIBRATION_PACKAGE_METADATA_FILENAME", + "CALIBRATION_PACKAGE_SUBSTAGE_ID", + "CALIBRATION_TARGET_FACETS_FILENAME", + "CALIBRATION_TARGETS_FILENAME", + "CALIBRATION_REPORTS_DIRNAME", + "DATASET_BUILD_OUTPUT_CONTRACT_FILENAME", + "DEFAULT_TARGET_CONFIG_PATH", + "MATRIX_BUILD_DIRNAME", + "SOURCE_DATASET_FILENAME", + "TARGET_CONFIG_IDENTITY_MODES", + "TARGET_DATABASE_FILENAME", + "CalibrationPackageArtifactPaths", + "CalibrationPackageOutputBundle", + "Stage2BuildContext", + "Stage2InputBundle", + "Stage2InputBundleError", + "Stage2InputSource", + "TargetConfigIdentity", + "TargetConfigMode", + "calibration_package_artifact_paths", + "resolve_target_config_identity", + "stage2_build_context_for_run", + "stage2_input_bundle_from_artifacts_dir", + "stage2_input_bundle_from_stage1_contract", + "stage2_input_bundle_from_stage1_contract_path", +] diff --git a/policyengine_us_data/calibration_package/targets.py b/policyengine_us_data/calibration_package/targets.py new file mode 100644 index 000000000..3abbb1ef1 --- /dev/null +++ b/policyengine_us_data/calibration_package/targets.py @@ -0,0 +1,698 @@ +"""Target catalog and selection artifacts for Stage 2 package builds.""" + +from __future__ import annotations + +import hashlib +import json +from collections import Counter +from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import pandas as pd +from sqlalchemy import create_engine, text + +from policyengine_us_data.pipeline_metadata import pipeline_node +from policyengine_us_data.pipeline_schema import PipelineNode +from policyengine_us_data.utils.target_variables import target_variable_components + +from .specs import ( + CALIBRATION_TARGET_FACETS_FILENAME, + CALIBRATION_TARGETS_FILENAME, + TargetConfigIdentity, +) + +TARGET_CATALOG_COLUMNS: tuple[str, ...] = ( + "target_id", + "stratum_id", + "variable", + "reform_id", + "value", + "period", + "geo_level", + "geographic_id", + "domain_variable", + "source", + "notes", +) +GEO_CONSTRAINT_VARIABLES = frozenset( + { + "state_fips", + "congressional_district_geoid", + "ucgid_str", + } +) +TARGET_OVERVIEW_VIEW = """\ +CREATE VIEW IF NOT EXISTS target_overview AS +SELECT + t.target_id, + t.stratum_id, + t.variable, + t.reform_id, + t.value, + t.period, + t.active, + CASE + WHEN MAX(CASE + WHEN sc.constraint_variable = 'congressional_district_geoid' + THEN 1 + WHEN sc.constraint_variable = 'ucgid_str' + AND length(sc.value) = 13 THEN 1 + ELSE 0 END) = 1 THEN 'district' + WHEN MAX(CASE + WHEN sc.constraint_variable = 'state_fips' THEN 1 + WHEN sc.constraint_variable = 'ucgid_str' + AND length(sc.value) = 11 THEN 1 + ELSE 0 END) = 1 THEN 'state' + ELSE 'national' + END AS geo_level, + COALESCE( + MAX(CASE + WHEN sc.constraint_variable + = 'congressional_district_geoid' + THEN sc.value END), + MAX(CASE + WHEN sc.constraint_variable = 'state_fips' + THEN sc.value END), + MAX(CASE + WHEN sc.constraint_variable = 'ucgid_str' + THEN sc.value END), + 'US' + ) AS geographic_id, + ( + SELECT GROUP_CONCAT(cv, ',') + FROM ( + SELECT DISTINCT sc2.constraint_variable AS cv + FROM stratum_constraints sc2 + WHERE sc2.stratum_id = t.stratum_id + AND sc2.constraint_variable NOT IN ( + 'state_fips', 'congressional_district_geoid', + 'tax_unit_is_filer', 'ucgid_str' + ) + ORDER BY sc2.constraint_variable + ) + ) AS domain_variable +FROM targets t +LEFT JOIN stratum_constraints sc ON t.stratum_id = sc.stratum_id +GROUP BY t.target_id, t.stratum_id, t.variable, + t.reform_id, t.value, t.period, t.active; +""" + + +@pipeline_node( + PipelineNode( + id="stage2_target_catalog_reader", + label="Stage 2 Target Catalog Reader", + node_type="library", + description="Read active and disabled calibration targets plus stratum constraints from policy_data.db.", + source_file="policyengine_us_data/calibration_package/targets.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_in=["policy_data.db"], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_targets.py" + ], + ) +) +@dataclass(frozen=True, kw_only=True) +class TargetCatalogReader: + """Read the Stage 2 target catalog from the calibration target database.""" + + time_period: int + db_uri: str | None = None + engine: Any | None = None + + def load(self, target_filter: Mapping[str, Any] | None = None) -> "TargetCatalog": + """Load selected active targets and disabled target rows.""" + + if self.engine is None and self.db_uri is None: + raise ValueError("TargetCatalogReader requires db_uri or engine") + engine = self.engine if self.engine is not None else create_engine(self.db_uri) + owns_engine = self.engine is None + try: + _ensure_target_overview(engine) + target_columns = _table_columns(engine, "targets") + view_columns = _table_columns(engine, "target_overview") + targets = _query_targets( + engine, + time_period=self.time_period, + target_filter=target_filter or {}, + active_only=True, + target_columns=target_columns, + view_columns=view_columns, + ) + disabled_targets = _query_targets( + engine, + time_period=self.time_period, + target_filter=target_filter or {}, + active_only=False, + target_columns=target_columns, + view_columns=view_columns, + ) + constraints_by_stratum = _load_constraints_by_stratum(engine) + return TargetCatalog( + targets=targets, + disabled_targets=disabled_targets, + constraints_by_stratum=constraints_by_stratum, + ) + finally: + if owns_engine: + engine.dispose() + + +@dataclass(frozen=True, kw_only=True) +class TargetCatalog: + """Targets and stratum constraints available to Stage 2 selection.""" + + targets: pd.DataFrame + disabled_targets: pd.DataFrame = field(default_factory=pd.DataFrame) + constraints_by_stratum: Mapping[int, tuple[Mapping[str, Any], ...]] = field( + default_factory=dict + ) + + @classmethod + def from_targets( + cls, + targets: pd.DataFrame, + *, + disabled_targets: pd.DataFrame | None = None, + constraints_by_stratum: Mapping[int, Iterable[Mapping[str, Any]]] | None = None, + ) -> "TargetCatalog": + """Create a catalog from in-memory target rows.""" + + return cls( + targets=_normalize_target_frame(targets), + disabled_targets=_normalize_target_frame( + disabled_targets if disabled_targets is not None else pd.DataFrame() + ), + constraints_by_stratum={ + int(key): tuple(dict(item) for item in value) + for key, value in (constraints_by_stratum or {}).items() + }, + ) + + def constraints_for(self, stratum_id: int) -> tuple[Mapping[str, Any], ...]: + """Return deterministic constraints for a stratum.""" + + return self.constraints_by_stratum.get(int(stratum_id), ()) + + +@pipeline_node( + PipelineNode( + id="stage2_target_selection_policy", + label="Stage 2 Target Selection Policy", + node_type="library", + description="Apply target config include/exclude rules and validate additive target expressions before matrix construction.", + source_file="policyengine_us_data/calibration_package/targets.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_in=["policy_data.db"], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_targets.py" + ], + ) +) +@dataclass(frozen=True, kw_only=True) +class TargetSelectionPolicy: + """Target include/exclude rules applied before matrix materialization.""" + + include_rules: tuple[Mapping[str, Any], ...] = () + exclude_rules: tuple[Mapping[str, Any], ...] = () + + @classmethod + def from_config(cls, config: Mapping[str, Any] | None) -> "TargetSelectionPolicy": + """Create a policy from the Stage 2 target config mapping.""" + + config = config or {} + return cls( + include_rules=tuple(dict(rule) for rule in config.get("include", ())), + exclude_rules=tuple(dict(rule) for rule in config.get("exclude", ())), + ) + + def select( + self, + catalog: TargetCatalog, + *, + target_config_identity: TargetConfigIdentity | None = None, + valid_variables: Iterable[str] | Mapping[str, Any] | None = None, + ) -> "TargetSelectionResult": + """Apply this policy to a target catalog.""" + + targets = _normalize_target_frame(catalog.targets) + _validate_target_expressions(targets, valid_variables) + keep_mask = pd.Series(True, index=targets.index) + if self.include_rules: + keep_mask = _match_rules(targets, self.include_rules) + if self.exclude_rules: + keep_mask &= ~_match_rules(targets, self.exclude_rules) + selected = targets.loc[keep_mask].reset_index(drop=True) + disabled = _normalize_target_frame(catalog.disabled_targets) + return TargetSelectionResult( + targets_df=selected, + disabled_targets_df=disabled, + constraints_by_stratum=catalog.constraints_by_stratum, + target_config_path=( + target_config_identity.path + if target_config_identity is not None + else None + ), + target_config_sha256=( + target_config_identity.sha256 + if target_config_identity is not None + else None + ), + target_config_mode=( + target_config_identity.mode + if target_config_identity is not None + else None + ), + ) + + +@pipeline_node( + PipelineNode( + id="stage2_target_selection_result", + label="Stage 2 Target Selection Result", + node_type="library", + description="Stable selected target metadata, checksum, JSONL rows, and facet counts consumed by Stage 2 matrix building and diagnostics.", + source_file="policyengine_us_data/calibration_package/targets.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_out=[ + CALIBRATION_TARGETS_FILENAME, + CALIBRATION_TARGET_FACETS_FILENAME, + ], + validation_commands=[ + "uv run pytest tests/unit/calibration_package/test_targets.py" + ], + ) +) +@dataclass(frozen=True, kw_only=True) +class TargetSelectionResult: + """Selected target metadata in the order consumed by Stage 2.""" + + targets_df: pd.DataFrame + disabled_targets_df: pd.DataFrame + constraints_by_stratum: Mapping[int, tuple[Mapping[str, Any], ...]] + target_config_path: str | None + target_config_sha256: str | None + target_config_mode: str | None + target_names: tuple[str, ...] = () + + @property + def target_ids(self) -> list[int]: + """Return selected target IDs in stable order.""" + + return [int(value) for value in self.targets_df["target_id"].tolist()] + + @property + def n_selected_targets(self) -> int: + """Return the number of selected package targets.""" + + return int(len(self.targets_df)) + + @property + def checksum(self) -> str: + """Return the stable target selection checksum.""" + + digest = hashlib.sha256() + for row in self.to_rows(): + digest.update( + json.dumps(row, sort_keys=True, separators=(",", ":")).encode() + ) + digest.update(b"\n") + return f"sha256:{digest.hexdigest()}" + + def with_matrix_order( + self, + targets_df: pd.DataFrame, + target_names: Iterable[str], + ) -> "TargetSelectionResult": + """Return this result in the matrix/package target order.""" + + ordered = _normalize_target_frame(targets_df) + names = tuple(str(name) for name in target_names) + if len(ordered) != len(names): + raise ValueError("Target metadata row count must match target_names") + return TargetSelectionResult( + targets_df=ordered.reset_index(drop=True), + disabled_targets_df=self.disabled_targets_df, + constraints_by_stratum=self.constraints_by_stratum, + target_config_path=self.target_config_path, + target_config_sha256=self.target_config_sha256, + target_config_mode=self.target_config_mode, + target_names=names, + ) + + def to_rows(self) -> list[dict[str, Any]]: + """Return JSONL-ready selected target rows.""" + + rows: list[dict[str, Any]] = [] + target_names = self.target_names or tuple( + _fallback_target_name(row) + for _, row in self.targets_df.reset_index(drop=True).iterrows() + ) + for target_index, (_, row) in enumerate( + self.targets_df.reset_index(drop=True).iterrows() + ): + constraints = [ + _jsonable_constraint(constraint) + for constraint in self.constraints_by_stratum.get( + int(row["stratum_id"]), + (), + ) + ] + components = target_variable_components(str(row["variable"])) + target_expression = str(row["variable"]) if len(components) > 1 else None + rows.append( + { + "target_id": int(row["target_id"]), + "target_index": int(target_index), + "target_name": str(target_names[target_index]), + "variable": str(row["variable"]), + "target_expression": target_expression, + "target_components": components, + "target_value": _optional_float(row.get("value")), + "period": _optional_int(row.get("period")), + "geography_level": _optional_string(row.get("geo_level")), + "geography_id": _optional_string(row.get("geographic_id")), + "domain_variable": _optional_string(row.get("domain_variable")), + "source_table": "targets", + "source": _optional_string(row.get("source")), + "target_config_path": self.target_config_path, + "target_config_sha256": self.target_config_sha256, + "target_config_mode": self.target_config_mode, + "included_in_package": True, + "notes": _optional_string(row.get("notes")), + "constraint_key": _constraint_key(constraints), + "target_constraints": constraints, + } + ) + return rows + + def disabled_rows(self) -> list[dict[str, Any]]: + """Return disabled target rows for reporting.""" + + rows: list[dict[str, Any]] = [] + for _, row in self.disabled_targets_df.reset_index(drop=True).iterrows(): + rows.append( + { + "target_id": int(row["target_id"]), + "variable": str(row["variable"]), + "period": _optional_int(row.get("period")), + "geography_level": _optional_string(row.get("geo_level")), + "geography_id": _optional_string(row.get("geographic_id")), + "domain_variable": _optional_string(row.get("domain_variable")), + "included_in_package": False, + "notes": _optional_string(row.get("notes")), + } + ) + return rows + + def facets(self) -> dict[str, Any]: + """Return compact counts derived from selected row-level metadata.""" + + return target_facets_from_rows(self.to_rows()) + + def summary(self) -> dict[str, Any]: + """Return a compact selection summary for package metadata.""" + + return { + "target_count": self.n_selected_targets, + "disabled_target_count": int(len(self.disabled_targets_df)), + "target_selection_sha256": self.checksum, + "target_config_path": self.target_config_path, + "target_config_sha256": self.target_config_sha256, + "target_config_mode": self.target_config_mode, + } + + def write_artifacts( + self, + targets_path: str | Path, + facets_path: str | Path, + ) -> tuple[Path, Path]: + """Write row-level target metadata and facet summary artifacts.""" + + rows = self.to_rows() + target_file = Path(targets_path) + facet_file = Path(facets_path) + target_file.parent.mkdir(parents=True, exist_ok=True) + with target_file.open("w", encoding="utf-8") as handle: + for row in rows: + handle.write(json.dumps(row, sort_keys=True)) + handle.write("\n") + facet_file.write_text( + json.dumps(self.facets(), indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + return target_file, facet_file + + +def target_facets_from_rows(rows: Iterable[Mapping[str, Any]]) -> dict[str, Any]: + """Derive target facet counts from row-level target metadata.""" + + material = [dict(row) for row in rows] + return { + "target_count": len(material), + "by_variable": _counts(material, "variable"), + "by_geography_level": _counts(material, "geography_level"), + "by_target_name": _counts(material, "target_name"), + "by_period": _counts(material, "period"), + "by_constraint_key": _counts(material, "constraint_key"), + } + + +def _query_targets( + engine: Any, + *, + time_period: int, + target_filter: Mapping[str, Any], + active_only: bool, + target_columns: set[str], + view_columns: set[str], +) -> pd.DataFrame: + where_clause, params = _target_filter_sql(target_filter, active_only=active_only) + reform_expr = "tv.reform_id" if "reform_id" in view_columns else "0" + reform_group = "reform_id" if "reform_id" in view_columns else "0" + source_expr = "t.source" if "source" in target_columns else "NULL" + notes_expr = "t.notes" if "notes" in target_columns else "NULL" + active_condition = "tv.active = 1" if active_only else "tv.active != 1" + if active_only: + query = f""" + WITH filtered_targets AS ( + SELECT tv.target_id, tv.stratum_id, tv.variable, + {reform_expr} AS reform_id, tv.value, tv.period, + tv.geo_level, tv.geographic_id, tv.domain_variable + FROM target_overview tv + WHERE {active_condition} + AND ({where_clause}) + ), + best_periods AS ( + SELECT stratum_id, variable, {reform_group} AS reform_id, + CASE + WHEN MAX(CASE WHEN period <= :time_period THEN period END) + IS NOT NULL + THEN MAX(CASE WHEN period <= :time_period THEN period END) + ELSE MIN(period) + END AS best_period + FROM filtered_targets + GROUP BY stratum_id, variable, reform_id + ) + SELECT ft.*, {source_expr} AS source, {notes_expr} AS notes + FROM filtered_targets ft + JOIN best_periods bp + ON ft.stratum_id = bp.stratum_id + AND ft.variable = bp.variable + AND ft.reform_id = bp.reform_id + AND ft.period = bp.best_period + LEFT JOIN targets t ON t.target_id = ft.target_id + ORDER BY ft.target_id + """ + params["time_period"] = int(time_period) + else: + query = f""" + SELECT tv.target_id, tv.stratum_id, tv.variable, + {reform_expr} AS reform_id, tv.value, tv.period, + tv.geo_level, tv.geographic_id, tv.domain_variable, + {source_expr} AS source, {notes_expr} AS notes + FROM target_overview tv + LEFT JOIN targets t ON t.target_id = tv.target_id + WHERE {active_condition} + AND ({where_clause}) + ORDER BY tv.target_id + """ + with engine.connect() as conn: + return _normalize_target_frame(pd.read_sql(text(query), conn, params=params)) + + +def _ensure_target_overview(engine: Any) -> None: + with engine.connect() as conn: + conn.execute(text(TARGET_OVERVIEW_VIEW)) + conn.commit() + + +def _target_filter_sql( + target_filter: Mapping[str, Any], + *, + active_only: bool, +) -> tuple[str, dict[str, Any]]: + conditions: list[str] = [] + params: dict[str, Any] = {} + filter_columns = { + "domain_variables": "tv.domain_variable", + "variables": "tv.variable", + "target_ids": "tv.target_id", + "stratum_ids": "tv.stratum_id", + } + for key, column in filter_columns.items(): + if key not in target_filter: + continue + values = list(target_filter[key]) + if not values: + conditions.append("0 = 1") + continue + placeholders = [] + for index, value in enumerate(values): + param = f"{key}_{index}_{'active' if active_only else 'disabled'}" + placeholders.append(f":{param}") + params[param] = value + conditions.append(f"{column} IN ({', '.join(placeholders)})") + return (" AND ".join(f"({condition})" for condition in conditions) or "1=1", params) + + +def _load_constraints_by_stratum( + engine: Any, +) -> dict[int, tuple[Mapping[str, Any], ...]]: + query = """ + SELECT stratum_id, constraint_variable AS variable, operation, value + FROM stratum_constraints + ORDER BY stratum_id, constraint_id + """ + with engine.connect() as conn: + frame = pd.read_sql(text(query), conn) + grouped: dict[int, list[Mapping[str, Any]]] = {} + for _, row in frame.iterrows(): + grouped.setdefault(int(row["stratum_id"]), []).append( + { + "variable": str(row["variable"]), + "operation": str(row["operation"]), + "value": str(row["value"]), + } + ) + return {key: tuple(value) for key, value in grouped.items()} + + +def _table_columns(engine: Any, table: str) -> set[str]: + with engine.connect() as conn: + rows = conn.execute(text(f"PRAGMA table_info({table})")).fetchall() + return {str(row[1]) for row in rows} + + +def _normalize_target_frame(frame: pd.DataFrame) -> pd.DataFrame: + normalized = frame.copy() + for column in TARGET_CATALOG_COLUMNS: + if column not in normalized.columns: + normalized[column] = None + if len(normalized): + if normalized["target_id"].isna().all(): + normalized["target_id"] = list(range(len(normalized))) + if normalized["stratum_id"].isna().all(): + normalized["stratum_id"] = list(range(len(normalized))) + normalized = normalized.loc[:, list(TARGET_CATALOG_COLUMNS)] + return normalized.reset_index(drop=True) + + +def _match_rules( + targets_df: pd.DataFrame, + rules: Iterable[Mapping[str, Any]], +) -> pd.Series: + mask = pd.Series(False, index=targets_df.index) + for rule in rules: + if "variable" not in rule: + raise ValueError("Target selection rules require a variable") + rule_mask = targets_df["variable"].astype(str) == str(rule["variable"]) + if "geo_level" in rule: + rule_mask &= targets_df["geo_level"].astype(str) == str(rule["geo_level"]) + if "domain_variable" in rule: + domain_values = targets_df["domain_variable"].fillna("").astype(str) + rule_mask &= domain_values == str(rule["domain_variable"]) + mask |= rule_mask + return mask + + +def _validate_target_expressions( + targets_df: pd.DataFrame, + valid_variables: Iterable[str] | Mapping[str, Any] | None, +) -> None: + if valid_variables is None: + return + valid = set(valid_variables) + for variable in targets_df["variable"].astype(str): + components = target_variable_components(variable) + missing = [component for component in components if component not in valid] + if missing: + raise ValueError( + "Target variable expression contains unknown component(s): " + + ", ".join(missing) + ) + + +def _jsonable_constraint(constraint: Mapping[str, Any]) -> dict[str, str]: + return { + "variable": str(constraint.get("variable")), + "operation": str(constraint.get("operation")), + "value": str(constraint.get("value")), + } + + +def _constraint_key(constraints: Iterable[Mapping[str, Any]]) -> str: + material = [ + f"{item['variable']}{item['operation']}{item['value']}" + for item in constraints + if item.get("variable") not in GEO_CONSTRAINT_VARIABLES + ] + return "|".join(material) if material else "none" + + +def _fallback_target_name(row: pd.Series) -> str: + geo = str(row.get("geographic_id") or "US") + return f"{geo}/{row.get('variable')}" + + +def _counts(rows: list[Mapping[str, Any]], key: str) -> dict[str, int]: + counter = Counter(str(row.get(key)) for row in rows) + return dict(sorted(counter.items())) + + +def _optional_string(value: Any) -> str | None: + if value is None or pd.isna(value): + return None + return str(value) + + +def _optional_int(value: Any) -> int | None: + if value is None or pd.isna(value): + return None + return int(value) + + +def _optional_float(value: Any) -> float | None: + if value is None or pd.isna(value): + return None + return float(value) + + +__all__ = [ + "GEO_CONSTRAINT_VARIABLES", + "TARGET_CATALOG_COLUMNS", + "TARGET_OVERVIEW_VIEW", + "TargetCatalog", + "TargetCatalogReader", + "TargetSelectionPolicy", + "TargetSelectionResult", + "target_facets_from_rows", +] diff --git a/policyengine_us_data/stage_contracts/calibration_package.py b/policyengine_us_data/stage_contracts/calibration_package.py index dc0385321..a235068c1 100644 --- a/policyengine_us_data/stage_contracts/calibration_package.py +++ b/policyengine_us_data/stage_contracts/calibration_package.py @@ -2,16 +2,23 @@ from __future__ import annotations -import pickle from collections.abc import Mapping from pathlib import Path from typing import Any -from policyengine_us_data.utils.step_manifest import sha256_file -from policyengine_us_data.utils.geography_checksum import ( - canonical_geography_checksum, - hash_string_array, +from policyengine_us_data.calibration_package.payload import ( + CalibrationPackagePayload, + CalibrationPackageReader, +) +from policyengine_us_data.calibration_package.specs import ( + CALIBRATION_PACKAGE_CONTRACT_FILENAME, + CALIBRATION_PACKAGE_SUBSTAGE_ID, + CALIBRATION_TARGET_FACETS_FILENAME, + CALIBRATION_TARGETS_FILENAME, ) +from policyengine_us_data.pipeline_metadata import pipeline_node +from policyengine_us_data.pipeline_schema import PipelineNode +from policyengine_us_data.utils.step_manifest import sha256_file from .artifacts import ArtifactRef from .calibration_package_schema import ( @@ -26,154 +33,25 @@ from .stages import STAGE_2_BUILD_CALIBRATION_PACKAGE, contract_type_for_stage from .substages import SubstageRecord -CALIBRATION_PACKAGE_CONTRACT_FILENAME = "calibration_package_contract.json" CALIBRATION_PACKAGE_CONTRACT_TYPE = contract_type_for_stage( STAGE_2_BUILD_CALIBRATION_PACKAGE ) -CALIBRATION_PACKAGE_SUBSTAGE_ID = "2a_matrix_build_calibration_target_construction" def summarize_geography_assignment( - package: Mapping[str, Any], + package: CalibrationPackagePayload | Mapping[str, Any], ) -> GeographyAssignmentSummary: """Return a contract-safe summary of package-backed geography assignment.""" - metadata = _package_metadata(package) - n_records = _optional_metadata_int(metadata, "base_n_records") - n_clones = _optional_metadata_int(metadata, "n_clones") - raw_blocks = package.get("block_geoid") - raw_cds = package.get("cd_geoid") - has_blocks = raw_blocks is not None - has_cds = raw_cds is not None - - if not has_blocks and not has_cds: - return _unavailable_geography_assignment_summary( - n_records=n_records, - n_clones=n_clones, - ) - if not has_blocks or not has_cds: - raise ValueError( - "Calibration package geography requires both block_geoid and cd_geoid" - ) - if n_records is None or n_clones is None: - raise ValueError( - "Calibration package geography requires metadata base_n_records and n_clones" - ) - if n_records <= 0 or n_clones <= 0: - raise ValueError( - "Calibration package geography requires positive base_n_records and n_clones" - ) - - block_geoids = _one_dimensional_string_array(raw_blocks, "block_geoid") - cd_geoids = _one_dimensional_string_array(raw_cds, "cd_geoid") - n_rows = int(len(block_geoids)) - if n_rows == 0: - raise ValueError("Calibration package geography arrays must be non-empty") - if len(cd_geoids) != n_rows: - raise ValueError( - "Calibration package geography has mismatched block_geoid and cd_geoid " - f"lengths: {n_rows} != {len(cd_geoids)}" - ) - if n_records * n_clones != n_rows: - raise ValueError( - "Calibration package geography length does not match metadata: " - f"{n_rows} rows for {n_records} records x {n_clones} clones" - ) - - return GeographyAssignmentSummary( - source_kind="calibration_package", - n_records=n_records, - n_clones=n_clones, - n_rows=n_rows, - has_block_geoid=True, - has_cd_geoid=True, - block_geoid_length=n_rows, - cd_geoid_length=int(len(cd_geoids)), - block_geoid_sha256=hash_string_array(block_geoids), - cd_geoid_sha256=hash_string_array(cd_geoids), - canonical_geography_sha256=canonical_geography_checksum( - block_geoid=block_geoids, - cd_geoid=cd_geoids, - n_records=n_records, - n_clones=n_clones, - ), - ) - - -def _unavailable_geography_assignment_summary( - *, - n_records: int | None, - n_clones: int | None, -) -> GeographyAssignmentSummary: - """Create a summary for legacy packages without geography assignment arrays.""" - - return GeographyAssignmentSummary( - source_kind="unavailable", - n_records=n_records, - n_clones=n_clones, - n_rows=None, - has_block_geoid=False, - has_cd_geoid=False, - block_geoid_length=None, - cd_geoid_length=None, - block_geoid_sha256=None, - cd_geoid_sha256=None, - canonical_geography_sha256=None, - ) + return _calibration_package_payload(package, require_core=False).geography_summary() def summarize_calibration_package( - package: Mapping[str, Any], + package: CalibrationPackagePayload | Mapping[str, Any], ) -> CalibrationPackageSummary: """Return a contract-safe summary of a calibration package pickle payload.""" - matrix = _required_package_value(package, "X_sparse") - targets_df = _required_package_value(package, "targets_df") - target_names = _required_package_value(package, "target_names") - metadata = _package_metadata(package) - - try: - n_targets, n_columns = matrix.shape - except (AttributeError, ValueError) as exc: - raise ValueError("X_sparse must expose a two-dimensional shape") from exc - if not hasattr(matrix, "nnz"): - raise ValueError("X_sparse must expose nnz") - - n_targets = int(n_targets) - n_columns = int(n_columns) - nnz = int(matrix.nnz) - density = nnz / (n_targets * n_columns) if n_targets * n_columns else 0.0 - - return CalibrationPackageSummary( - matrix_shape=(n_targets, n_columns), - matrix_nnz=nnz, - matrix_density=float(density), - n_targets=int(len(targets_df)), - n_columns=n_columns, - target_name_count=int(len(target_names)), - dataset_sha256=_optional_metadata_string(metadata, "dataset_sha256"), - db_sha256=_optional_metadata_string(metadata, "db_sha256"), - target_config_path=_optional_metadata_string( - metadata, - "target_config_path", - ), - target_config_sha256=_optional_metadata_string( - metadata, - "target_config_sha256", - ), - n_clones=_optional_metadata_int(metadata, "n_clones"), - seed=_optional_metadata_int(metadata, "seed"), - base_n_records=_optional_metadata_int(metadata, "base_n_records"), - package_scope=_optional_metadata_string(metadata, "package_scope"), - matrix_builder=_optional_metadata_string(metadata, "matrix_builder"), - chunk_size=_optional_metadata_int(metadata, "chunk_size"), - chunk_dir=_optional_metadata_string(metadata, "chunk_dir"), - has_initial_weights=package.get("initial_weights") is not None, - has_cd_geoid=package.get("cd_geoid") is not None, - has_block_geoid=package.get("block_geoid") is not None, - cd_geoid_length=_optional_len(package.get("cd_geoid")), - block_geoid_length=_optional_len(package.get("block_geoid")), - ) + return _calibration_package_payload(package).summary() def build_calibration_package_contract( @@ -181,7 +59,7 @@ def build_calibration_package_contract( package_path: Path, dataset_path: Path, db_path: Path, - package: Mapping[str, Any], + package: CalibrationPackagePayload | Mapping[str, Any], parameters: CalibrationPackageParameters | Mapping[str, Any], run_id: str | None, completed_at: str, @@ -189,6 +67,9 @@ def build_calibration_package_contract( duration_s: float | None = None, code_sha: str | None = None, package_version: str | None = None, + target_metadata_path: Path | None = None, + target_facets_path: Path | None = None, + target_selection_summary: Mapping[str, Any] | None = None, ) -> StageContract: """Build the Stage 2 handoff contract from a calibration package.""" @@ -200,10 +81,14 @@ def build_calibration_package_contract( _require_existing_file(db_path, "target database") parameter_schema = _calibration_package_parameters(parameters) - parameter_payload = parameter_schema.to_dict() - metadata = _package_metadata(package) - package_summary = summarize_calibration_package(package).to_dict() - geography_summary = summarize_geography_assignment(package).to_dict() + payload = _calibration_package_payload(package) + metadata = payload.metadata + parameter_payload = _parameters_with_package_identity( + parameter_schema.to_dict(), + metadata, + ) + package_summary = payload.summary().to_dict() + geography_summary = payload.geography_summary().to_dict() inputs = ( _artifact_ref_from_path( logical_name="source_imputed_stratified_extended_cps", @@ -224,7 +109,7 @@ def build_calibration_package_contract( }, ), ) - outputs = ( + outputs = [ _artifact_ref_from_path( logical_name="calibration_package", path=package_path, @@ -233,7 +118,36 @@ def build_calibration_package_contract( "substage_id": CALIBRATION_PACKAGE_SUBSTAGE_ID, }, ), - ) + ] + if target_metadata_path is not None: + _require_existing_file(target_metadata_path, "calibration targets metadata") + outputs.append( + _artifact_ref_from_path( + logical_name="calibration_targets", + path=Path(target_metadata_path), + media_type="application/x-ndjson", + metadata={ + "artifact_family": "target_metadata", + "substage_id": CALIBRATION_PACKAGE_SUBSTAGE_ID, + "stable_join_keys": ("target_id", "target_index"), + }, + ) + ) + if target_facets_path is not None: + _require_existing_file(target_facets_path, "calibration target facets") + outputs.append( + _artifact_ref_from_path( + logical_name="calibration_target_facets", + path=Path(target_facets_path), + media_type="application/json", + metadata={ + "artifact_family": "target_metadata", + "substage_id": CALIBRATION_PACKAGE_SUBSTAGE_ID, + "derived_from": CALIBRATION_TARGETS_FILENAME, + }, + ) + ) + outputs = tuple(outputs) code_sha = code_sha or _optional_metadata_string(metadata, "git_commit") package_version = package_version or _optional_metadata_string( metadata, @@ -246,9 +160,9 @@ def build_calibration_package_contract( duration_s=duration_s, reuse_decision="not_applicable", reuse_summary=ReuseSummary( - expected_outputs=1, + expected_outputs=len(outputs), valid_reused_outputs=0, - recomputed_outputs=1, + recomputed_outputs=len(outputs), invalid_outputs=0, ), ) @@ -261,6 +175,7 @@ def build_calibration_package_contract( "parameters": parameter_payload, "package_summary": package_summary, "geography_assignment": geography_summary, + "target_selection": target_selection_summary or {}, } ) return StageContract( @@ -291,16 +206,38 @@ def build_calibration_package_contract( "contract_file": CALIBRATION_PACKAGE_CONTRACT_FILENAME, "geography_assignment": geography_summary, "package_summary": package_summary, + "target_selection": dict(target_selection_summary or {}), }, ) +@pipeline_node( + PipelineNode( + id="stage2_calibration_package_contract_writer", + label="Stage 2 Contract Writer", + node_type="library", + description="Write the Stage 2 calibration-package handoff contract next to the package artifact.", + source_file="policyengine_us_data/stage_contracts/calibration_package.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_in=["calibration_package.pkl"], + artifacts_out=[ + CALIBRATION_PACKAGE_CONTRACT_FILENAME, + CALIBRATION_TARGETS_FILENAME, + CALIBRATION_TARGET_FACETS_FILENAME, + ], + validation_commands=[ + "uv run pytest tests/unit/test_calibration_package_stage_contract.py" + ], + ) +) def write_calibration_package_contract( *, package_path: Path, dataset_path: Path, db_path: Path, - package: Mapping[str, Any], + package: CalibrationPackagePayload | Mapping[str, Any], parameters: CalibrationPackageParameters | Mapping[str, Any], run_id: str | None, completed_at: str, @@ -309,6 +246,9 @@ def write_calibration_package_contract( code_sha: str | None = None, package_version: str | None = None, contract_path: Path | None = None, + target_metadata_path: Path | None = None, + target_facets_path: Path | None = None, + target_selection_summary: Mapping[str, Any] | None = None, ) -> StageContract: """Write and return the Stage 2 calibration-package contract.""" @@ -325,6 +265,9 @@ def write_calibration_package_contract( duration_s=duration_s, code_sha=code_sha, package_version=package_version, + target_metadata_path=target_metadata_path, + target_facets_path=target_facets_path, + target_selection_summary=target_selection_summary, ) write_contract( contract, @@ -333,11 +276,30 @@ def write_calibration_package_contract( return contract +@pipeline_node( + PipelineNode( + id="stage2_calibration_package_contract_validator", + label="Stage 2 Contract Validator", + node_type="validation", + description="Validate that the persisted Stage 2 contract describes the calibration package and inputs.", + source_file="policyengine_us_data/stage_contracts/calibration_package.py", + status="current", + stability="moving", + pathways=["calibration_package"], + artifacts_in=[ + "calibration_package.pkl", + CALIBRATION_PACKAGE_CONTRACT_FILENAME, + ], + validation_commands=[ + "uv run pytest tests/unit/test_calibration_package_stage_contract.py" + ], + ) +) def validate_calibration_package_contract( *, package_path: Path, contract_path: Path | None = None, - package: Mapping[str, Any] | None = None, + package: CalibrationPackagePayload | Mapping[str, Any] | None = None, dataset_path: Path | None = None, db_path: Path | None = None, ) -> StageContract: @@ -419,29 +381,29 @@ def validate_persisted_calibration_package_contract( ) -def load_calibration_package_payload(package_path: Path) -> Mapping[str, Any]: - """Load a calibration package pickle for sidecar validation.""" +def load_calibration_package_payload(package_path: Path) -> CalibrationPackagePayload: + """Load a typed calibration package payload for sidecar validation.""" - with Path(package_path).open("rb") as handle: - package = pickle.load(handle) - if not isinstance(package, Mapping): - raise ValueError("Calibration package pickle must contain a mapping") - return package + return CalibrationPackageReader(package_path=Path(package_path)).read() -def _required_package_value(package: Mapping[str, Any], key: str) -> Any: - if key not in package: - raise ValueError(f"Calibration package missing required key: {key}") - return package[key] +def _calibration_package_payload( + package: CalibrationPackagePayload | Mapping[str, Any], + *, + require_core: bool = True, +) -> CalibrationPackagePayload: + if isinstance(package, CalibrationPackagePayload): + return package + return CalibrationPackagePayload.from_mapping( + package, + require_required_keys=require_core, + ) -def _package_metadata(package: Mapping[str, Any]) -> Mapping[str, Any]: - metadata = package.get("metadata", {}) - if metadata is None: - return {} - if not isinstance(metadata, Mapping): - raise ValueError("Calibration package metadata must be a mapping") - return metadata +def _package_metadata( + package: CalibrationPackagePayload | Mapping[str, Any], +) -> Mapping[str, Any]: + return _calibration_package_payload(package).metadata def _optional_metadata_string( @@ -463,23 +425,6 @@ def _optional_metadata_int(metadata: Mapping[str, Any], key: str) -> int | None: return int(value) -def _optional_len(value: Any) -> int | None: - if value is None: - return None - return int(len(value)) - - -def _one_dimensional_string_array(value: Any, key: str) -> Any: - import numpy as np - - array = np.asarray(value, dtype=str) - if array.ndim != 1: - raise ValueError(f"Calibration package geography {key} must be one-dimensional") - if np.any(array == ""): - raise ValueError(f"Calibration package geography {key} contains empty values") - return array - - def _calibration_package_parameters( parameters: CalibrationPackageParameters | Mapping[str, Any], ) -> CalibrationPackageParameters: @@ -488,6 +433,46 @@ def _calibration_package_parameters( return CalibrationPackageParameters.from_dict(parameters) +def _parameters_with_package_identity( + parameters: Mapping[str, Any], + metadata: Mapping[str, Any], +) -> dict[str, Any]: + payload = dict(parameters) + metadata_path = _optional_metadata_string(metadata, "target_config_path") + metadata_sha = _optional_metadata_string(metadata, "target_config_sha256") + metadata_mode = _optional_metadata_string(metadata, "target_config_mode") + + if metadata_path: + if payload.get("target_config") is None: + payload["target_config"] = metadata_path + if payload["target_config"] != metadata_path: + raise ValueError( + "Calibration package contract target_config does not match " + "package metadata" + ) + if metadata_sha: + if payload.get("target_config_sha256") is None: + payload["target_config_sha256"] = metadata_sha + if payload["target_config_sha256"] != metadata_sha: + raise ValueError( + "Calibration package contract target_config_sha256 does not match " + "package metadata" + ) + if metadata_mode: + if payload.get("target_config_mode") is None: + payload["target_config_mode"] = metadata_mode + if payload["target_config_mode"] != metadata_mode: + raise ValueError( + "Calibration package contract target_config_mode does not match " + "package metadata" + ) + if payload.get("target_config_mode") is None: + payload["target_config_mode"] = ( + "all_active_targets" if payload.get("target_config") is None else "explicit" + ) + return payload + + def _require_existing_file(path: Path, label: str) -> None: if not path.exists(): raise FileNotFoundError(f"Missing {label}: {path}") @@ -500,13 +485,14 @@ def _artifact_ref_from_path( logical_name: str, path: Path, metadata: Mapping[str, Any], + media_type: str | None = None, ) -> ArtifactRef: return ArtifactRef( logical_name=logical_name, uri=path.resolve().as_uri(), sha256=f"sha256:{sha256_file(path)}", size_bytes=path.stat().st_size, - media_type=_media_type_for_path(path), + media_type=media_type or _media_type_for_path(path), metadata=metadata, ) @@ -519,6 +505,8 @@ def _media_type_for_path(path: Path) -> str: return "application/vnd.sqlite3" if suffix == ".json": return "application/json" + if suffix == ".jsonl": + return "application/x-ndjson" if suffix == ".pkl": return "application/python-pickle" return "application/octet-stream" diff --git a/policyengine_us_data/stage_contracts/calibration_package_schema.py b/policyengine_us_data/stage_contracts/calibration_package_schema.py index 06030812e..7d00b9800 100644 --- a/policyengine_us_data/stage_contracts/calibration_package_schema.py +++ b/policyengine_us_data/stage_contracts/calibration_package_schema.py @@ -7,6 +7,10 @@ from math import isfinite from typing import Any +from policyengine_us_data.calibration_package.specs import ( + TARGET_CONFIG_IDENTITY_MODES, +) + GEOGRAPHY_ASSIGNMENT_SOURCE_KINDS = frozenset( { "calibration_package", @@ -39,6 +43,8 @@ "skip_source_impute", "skip_takeup_rerandomize", "target_config", + "target_config_mode", + "target_config_sha256", "workers", } ) @@ -207,6 +213,8 @@ class CalibrationPackageParameters: workers: int | None n_clones: int target_config: str | None + target_config_sha256: str | None + target_config_mode: str | None skip_county: bool skip_source_impute: bool skip_takeup_rerandomize: bool @@ -230,6 +238,25 @@ def __post_init__(self) -> None: _validate_bool(self.parallel_matrix, "parallel_matrix") if self.target_config is not None and not isinstance(self.target_config, str): raise ValueError("target_config must be a string or None") + if self.target_config_sha256 is not None and not isinstance( + self.target_config_sha256, + str, + ): + raise ValueError("target_config_sha256 must be a string or None") + if self.target_config_mode is not None: + if not isinstance(self.target_config_mode, str): + raise ValueError("target_config_mode must be a string or None") + if self.target_config_mode not in TARGET_CONFIG_IDENTITY_MODES: + raise ValueError( + "target_config_mode must be one of " + f"{sorted(TARGET_CONFIG_IDENTITY_MODES)}" + ) + if self.target_config_mode == "all_active_targets": + if self.target_config is not None or self.target_config_sha256 is not None: + raise ValueError( + "all_active_targets target config parameters cannot include " + "a path or checksum" + ) if self.chunked_matrix: if self.workers is not None: raise ValueError("workers must be None when chunked_matrix is true") @@ -265,14 +292,21 @@ def from_runtime_args( chunk_size: int, parallel: bool, num_matrix_workers: int, + target_config_sha256: str | None = None, + target_config_mode: str | None = None, ) -> "CalibrationPackageParameters": """Build canonical Stage 2 parameters from runtime CLI arguments.""" parallel_matrix = bool(chunked_matrix and parallel) + resolved_mode = target_config_mode or ( + "all_active_targets" if target_config_path is None else "explicit" + ) return cls( workers=workers if not chunked_matrix else None, n_clones=n_clones, target_config=target_config_path, + target_config_sha256=target_config_sha256, + target_config_mode=resolved_mode, skip_county=skip_county, skip_source_impute=skip_source_impute, skip_takeup_rerandomize=skip_takeup_rerandomize, @@ -291,15 +325,26 @@ def from_dict( if not isinstance(data, Mapping): raise ValueError("calibration package parameters must be a mapping") - _require_exact_keys( + _require_compatible_keys( data, "calibration package parameters", CALIBRATION_PACKAGE_PARAMETER_KEYS, + legacy_optional_keys=frozenset( + {"target_config_mode", "target_config_sha256"} + ), ) + target_config = _optional_string_field(data, "target_config") + target_config_mode = _optional_string_field(data, "target_config_mode") return cls( workers=_optional_int_field(data, "workers"), n_clones=_required_int_field(data, "n_clones"), - target_config=_optional_string_field(data, "target_config"), + target_config=target_config, + target_config_sha256=_optional_string_field( + data, + "target_config_sha256", + ), + target_config_mode=target_config_mode + or ("all_active_targets" if target_config is None else "explicit"), skip_county=_required_bool_field(data, "skip_county"), skip_source_impute=_required_bool_field(data, "skip_source_impute"), skip_takeup_rerandomize=_required_bool_field( @@ -325,6 +370,8 @@ def to_dict(self) -> dict[str, Any]: "skip_source_impute": self.skip_source_impute, "skip_takeup_rerandomize": self.skip_takeup_rerandomize, "target_config": self.target_config, + "target_config_mode": self.target_config_mode, + "target_config_sha256": self.target_config_sha256, "workers": self.workers, } @@ -463,9 +510,24 @@ def _require_exact_keys( data: Mapping[str, Any], label: str, expected_keys: frozenset[str], +) -> None: + _require_compatible_keys( + data, + label, + expected_keys, + legacy_optional_keys=frozenset(), + ) + + +def _require_compatible_keys( + data: Mapping[str, Any], + label: str, + expected_keys: frozenset[str], + *, + legacy_optional_keys: frozenset[str], ) -> None: keys = {str(key) for key in data} - missing = sorted(expected_keys - keys) + missing = sorted((expected_keys - legacy_optional_keys) - keys) unexpected = sorted(keys - expected_keys) if missing: raise ValueError(f"{label} missing required key: {missing[0]}") diff --git a/tests/unit/calibration/test_unified_calibration.py b/tests/unit/calibration/test_unified_calibration.py index 41022a17b..cb815ba56 100644 --- a/tests/unit/calibration/test_unified_calibration.py +++ b/tests/unit/calibration/test_unified_calibration.py @@ -54,6 +54,8 @@ def test_calibration_package_contract_parameters_track_effective_matrix_mode(): workers=8, n_clones=430, target_config_path="policyengine_us_data/calibration/target_config.yaml", + target_config_sha256="abc123", + target_config_mode="default", skip_county=True, skip_source_impute=True, skip_takeup_rerandomize=False, @@ -68,6 +70,8 @@ def test_calibration_package_contract_parameters_track_effective_matrix_mode(): "workers": None, "n_clones": 430, "target_config": "policyengine_us_data/calibration/target_config.yaml", + "target_config_sha256": "abc123", + "target_config_mode": "default", "skip_county": True, "skip_source_impute": True, "skip_takeup_rerandomize": False, @@ -83,6 +87,8 @@ def test_calibration_package_contract_parameters_ignore_unused_chunk_options(): workers=8, n_clones=430, target_config_path=None, + target_config_sha256=None, + target_config_mode="all_active_targets", skip_county=True, skip_source_impute=True, skip_takeup_rerandomize=False, diff --git a/tests/unit/calibration_package/test_payload.py b/tests/unit/calibration_package/test_payload.py new file mode 100644 index 000000000..ec88d627c --- /dev/null +++ b/tests/unit/calibration_package/test_payload.py @@ -0,0 +1,89 @@ +import json + +import pytest + +from tests.unit.fixtures.calibration_package_stage_contract import ( + calibration_package_contract, + calibration_package_payload, + calibration_package_payload_without_geography, +) + +from policyengine_us_data.calibration_package.payload import ( + LEGACY_MISSING_GEOGRAPHY_WARNING, + CalibrationPackagePayload, + CalibrationPackageReader, + CalibrationPackageWriter, +) +from policyengine_us_data.stage_contracts.calibration_package import ( + summarize_calibration_package, +) +from policyengine_us_data.stage_contracts.calibration_package_schema import ( + CalibrationPackageSummary, + GeographyAssignmentSummary, +) + + +def test_calibration_package_payload_read_write_round_trip(tmp_path): + package_path = tmp_path / "calibration_package.pkl" + payload = CalibrationPackagePayload.from_mapping(calibration_package_payload()) + + written = CalibrationPackageWriter(package_path=package_path).write(payload) + loaded = CalibrationPackageReader(package_path=package_path).read() + + assert written == package_path + assert loaded.summary() == payload.summary() + assert loaded.geography_summary() == payload.geography_summary() + assert ( + CalibrationPackageReader(package_path=package_path) + .checksum() + .startswith("sha256:") + ) + + +@pytest.mark.parametrize( + "missing_key", + ["X_sparse", "targets_df", "target_names", "metadata"], +) +def test_calibration_package_payload_rejects_missing_required_keys(missing_key): + package = calibration_package_payload() + package.pop(missing_key) + + with pytest.raises(ValueError, match=missing_key): + CalibrationPackagePayload.from_mapping(package) + + +def test_legacy_package_without_geography_records_compatibility_warning(): + payload = CalibrationPackagePayload.from_mapping( + calibration_package_payload_without_geography() + ) + + assert payload.compatibility_warnings == (LEGACY_MISSING_GEOGRAPHY_WARNING,) + geography = payload.geography_summary() + assert geography.source_kind == "unavailable" + + +def test_payload_summary_matches_existing_contract_summary(): + payload = CalibrationPackagePayload.from_mapping(calibration_package_payload()) + + assert payload.summary() == summarize_calibration_package(payload.to_mapping()) + + +def test_metadata_sidecar_uses_payload_and_contract(tmp_path): + package_path = tmp_path / "calibration_package.pkl" + payload = CalibrationPackagePayload.from_mapping(calibration_package_payload()) + contract = calibration_package_contract(tmp_path) + writer = CalibrationPackageWriter(package_path=package_path) + writer.write(payload) + + sidecar_path = writer.write_metadata_sidecar(payload, contract=contract) + sidecar = json.loads(sidecar_path.read_text(encoding="utf-8")) + + assert sidecar_path == tmp_path / "calibration_package_meta.json" + assert CalibrationPackageSummary.from_dict(sidecar["package_summary"]) == ( + payload.summary() + ) + assert GeographyAssignmentSummary.from_dict(sidecar["geography_assignment"]) == ( + payload.geography_summary() + ) + assert sidecar["contract"]["stage_id"] == "2_build_calibration_package" + assert sidecar["compatibility_warnings"] == [] diff --git a/tests/unit/calibration_package/test_specs.py b/tests/unit/calibration_package/test_specs.py new file mode 100644 index 000000000..2f124d835 --- /dev/null +++ b/tests/unit/calibration_package/test_specs.py @@ -0,0 +1,231 @@ +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from policyengine_us_data.calibration_package.specs import ( + CALIBRATION_PACKAGE_CONTRACT_FILENAME, + CALIBRATION_PACKAGE_FILENAME, + CALIBRATION_PACKAGE_METADATA_FILENAME, + CALIBRATION_TARGET_FACETS_FILENAME, + CALIBRATION_TARGETS_FILENAME, + CALIBRATION_REPORTS_DIRNAME, + DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, + DEFAULT_TARGET_CONFIG_PATH, + MATRIX_BUILD_DIRNAME, + SOURCE_DATASET_FILENAME, + TARGET_DATABASE_FILENAME, + TargetConfigIdentity, + calibration_package_artifact_paths, + resolve_target_config_identity, + stage2_build_context_for_run, + stage2_input_bundle_from_artifacts_dir, + stage2_input_bundle_from_stage1_contract, +) +from policyengine_us_data.stage_contracts.dataset_build import ( + build_dataset_build_output_contract, +) +from policyengine_us_data.stage_contracts.io import write_contract +from policyengine_us_data.utils.manifest import compute_file_checksum + + +def _write_default_target_config(repo_root: Path, body: str = "include: []\n") -> Path: + config_path = repo_root / DEFAULT_TARGET_CONFIG_PATH + config_path.parent.mkdir(parents=True) + config_path.write_text(body, encoding="utf-8") + return config_path + + +def test_default_target_config_identity_resolution(tmp_path): + config_path = _write_default_target_config(tmp_path) + + identity = resolve_target_config_identity(repo_root=tmp_path) + + assert identity == TargetConfigIdentity( + path=DEFAULT_TARGET_CONFIG_PATH, + sha256=compute_file_checksum(config_path), + mode="default", + resolved_path=str(config_path.resolve()), + ) + assert identity.to_parameters() == { + "target_config": DEFAULT_TARGET_CONFIG_PATH, + "target_config_sha256": compute_file_checksum(config_path), + "target_config_mode": "default", + } + + +def test_explicit_target_config_identity_resolution(tmp_path): + config_path = _write_default_target_config(tmp_path) + + identity = resolve_target_config_identity( + DEFAULT_TARGET_CONFIG_PATH, + repo_root=tmp_path, + ) + + assert identity.path == DEFAULT_TARGET_CONFIG_PATH + assert identity.sha256 == compute_file_checksum(config_path) + assert identity.mode == "explicit" + assert identity.resolved_path == str(config_path.resolve()) + + +def test_all_active_targets_identity_resolution(): + identity = resolve_target_config_identity(all_active_targets=True) + + assert identity.to_parameters() == { + "target_config": None, + "target_config_sha256": None, + "target_config_mode": "all_active_targets", + } + + +def test_all_active_targets_rejects_config_path(): + with pytest.raises(ValueError, match="all-active-targets"): + resolve_target_config_identity( + DEFAULT_TARGET_CONFIG_PATH, + all_active_targets=True, + ) + + +def test_calibration_package_artifact_paths(): + paths = calibration_package_artifact_paths("/pipeline/artifacts/run-a") + + assert paths.package == Path("/pipeline/artifacts/run-a") / ( + CALIBRATION_PACKAGE_FILENAME + ) + assert paths.metadata == Path("/pipeline/artifacts/run-a") / ( + CALIBRATION_PACKAGE_METADATA_FILENAME + ) + assert paths.contract == Path("/pipeline/artifacts/run-a") / ( + CALIBRATION_PACKAGE_CONTRACT_FILENAME + ) + assert paths.targets == Path("/pipeline/artifacts/run-a") / ( + CALIBRATION_TARGETS_FILENAME + ) + assert paths.target_facets == Path("/pipeline/artifacts/run-a") / ( + CALIBRATION_TARGET_FACETS_FILENAME + ) + assert paths.reports_dir == Path("/pipeline/artifacts/run-a") / ( + CALIBRATION_REPORTS_DIRNAME + ) + assert paths.matrix_build_dir == Path("/pipeline/artifacts/run-a") / ( + MATRIX_BUILD_DIRNAME + ) + assert paths.manifest_outputs == ( + paths.package, + paths.contract, + paths.targets, + paths.target_facets, + ) + + +def test_stage2_input_bundle_from_artifacts_dir(tmp_path): + (tmp_path / SOURCE_DATASET_FILENAME).write_bytes(b"h5") + (tmp_path / TARGET_DATABASE_FILENAME).write_bytes(b"db") + + bundle = stage2_input_bundle_from_artifacts_dir(tmp_path) + + assert bundle.source == "artifacts_dir_fallback" + assert bundle.compatibility_only is True + assert bundle.manifest_inputs == { + "dataset": tmp_path / SOURCE_DATASET_FILENAME, + "database": tmp_path / TARGET_DATABASE_FILENAME, + } + assert bundle.validation_report().status == "pass" + + +def test_stage2_input_bundle_from_fake_stage1_contract(tmp_path): + dataset = tmp_path / SOURCE_DATASET_FILENAME + database = tmp_path / TARGET_DATABASE_FILENAME + dataset.write_bytes(b"h5") + database.write_bytes(b"db") + contract = SimpleNamespace( + stage_id="1_build_datasets", + run_id="run-a", + outputs=( + SimpleNamespace( + logical_name="source_imputed_stratified_extended_cps", + uri=dataset.resolve().as_uri(), + ), + SimpleNamespace( + logical_name="policy_data_db", + uri=database.resolve().as_uri(), + ), + ), + ) + + bundle = stage2_input_bundle_from_stage1_contract( + contract, + artifacts_dir=tmp_path, + contract_path=tmp_path / DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, + ) + + assert bundle.source == "stage1_contract" + assert bundle.compatibility_only is False + assert bundle.stage1_contract_run_id == "run-a" + assert bundle.stage1_contract_path == ( + tmp_path / DATASET_BUILD_OUTPUT_CONTRACT_FILENAME + ) + assert bundle.source_dataset == dataset + assert bundle.target_database == database + assert bundle.validation_report().status == "pass" + + +def test_stage2_input_bundle_missing_required_artifacts_are_actionable(tmp_path): + (tmp_path / SOURCE_DATASET_FILENAME).write_bytes(b"h5") + bundle = stage2_input_bundle_from_artifacts_dir(tmp_path) + + report = bundle.validation_report() + + assert report.status == "fail" + assert [finding.check_id for finding in report.findings] == [ + "stage2_input_exists:database" + ] + assert str(tmp_path / TARGET_DATABASE_FILENAME) in report.findings[0].message + with pytest.raises(FileNotFoundError, match="database"): + bundle.require_existing() + + +def test_stage2_build_context_prefers_stage1_contract(tmp_path): + artifacts_dir = tmp_path / "artifacts" / "run-a" + artifacts_dir.mkdir(parents=True) + for filename in ( + "acs_2022.h5", + "irs_puf_2015.h5", + "cps_2024.h5", + "puf_2024.h5", + "extended_cps_2024.h5", + "enhanced_cps_2024.h5", + "small_enhanced_cps_2024.h5", + "stratified_extended_cps_2024.h5", + "source_imputed_stratified_extended_cps_2024.h5", + SOURCE_DATASET_FILENAME, + TARGET_DATABASE_FILENAME, + "build_log.txt", + "data_build_checkpoint_stats.json", + ): + (artifacts_dir / filename).write_bytes(filename.encode("utf-8")) + contract_path = artifacts_dir / DATASET_BUILD_OUTPUT_CONTRACT_FILENAME + write_contract( + build_dataset_build_output_contract( + artifacts_dir=artifacts_dir, + run_id="run-a", + code_sha="abc123", + package_version="1.0.0", + checkpoint_stats={}, + started_at="2026-01-01T00:00:00+00:00", + completed_at="2026-01-01T00:00:01+00:00", + duration_s=1.0, + ), + contract_path, + ) + + context = stage2_build_context_for_run(tmp_path, "run-a") + + assert context.input_bundle.source == "stage1_contract" + assert ( + context.input_bundle.source_dataset == artifacts_dir / SOURCE_DATASET_FILENAME + ) + assert ( + context.input_bundle.target_database == artifacts_dir / TARGET_DATABASE_FILENAME + ) + assert context.output_bundle.package == artifacts_dir / CALIBRATION_PACKAGE_FILENAME diff --git a/tests/unit/calibration_package/test_targets.py b/tests/unit/calibration_package/test_targets.py new file mode 100644 index 000000000..54090f0aa --- /dev/null +++ b/tests/unit/calibration_package/test_targets.py @@ -0,0 +1,178 @@ +import json + +import pytest +from sqlalchemy import create_engine, text + +from policyengine_us_data.calibration_package.specs import TargetConfigIdentity +from policyengine_us_data.calibration_package.targets import ( + TARGET_OVERVIEW_VIEW, + TargetCatalogReader, + TargetSelectionPolicy, + target_facets_from_rows, +) + + +@pytest.fixture +def target_db(tmp_path): + db_path = tmp_path / "targets.db" + engine = create_engine(f"sqlite:///{db_path}") + with engine.connect() as conn: + conn.execute( + text( + "CREATE TABLE strata (" + "stratum_id INTEGER PRIMARY KEY, " + "definition_hash VARCHAR(64), " + "parent_stratum_id INTEGER, " + "notes VARCHAR)" + ) + ) + conn.execute( + text( + "CREATE TABLE stratum_constraints (" + "constraint_id INTEGER PRIMARY KEY, " + "stratum_id INTEGER, " + "constraint_variable TEXT, " + "operation TEXT, " + "value TEXT)" + ) + ) + conn.execute( + text( + "CREATE TABLE targets (" + "target_id INTEGER PRIMARY KEY, " + "stratum_id INTEGER, " + "variable TEXT, " + "reform_id INTEGER DEFAULT 0, " + "value REAL, " + "period INTEGER, " + "active INTEGER DEFAULT 1, " + "source TEXT, " + "notes TEXT)" + ) + ) + conn.execute(text(TARGET_OVERVIEW_VIEW)) + conn.execute(text("INSERT INTO strata VALUES (1, NULL, NULL, 'national')")) + conn.execute(text("INSERT INTO strata VALUES (2, NULL, 1, 'state snap')")) + conn.execute(text("INSERT INTO strata VALUES (3, NULL, 1, 'national rent')")) + conn.execute( + text( + "INSERT INTO stratum_constraints VALUES " + "(1, 2, 'state_fips', '=', '6'), " + "(2, 2, 'snap', '>', '0'), " + "(3, 3, 'rent', '>', '0')" + ) + ) + conn.execute( + text( + "INSERT INTO targets " + "(target_id, stratum_id, variable, reform_id, value, period, active, source, notes) " + "VALUES " + "(1, 1, 'snap', 0, 100.0, 2022, 1, 'SOI', 'base snap'), " + "(2, 1, 'snap', 0, 200.0, 2024, 0, 'SOI', 'disabled newer snap'), " + "(3, 2, 'eitc+ctc', 0, 300.0, 2024, 1, 'IRS', 'additive'), " + "(4, 3, 'rent', 0, 400.0, 2024, 1, 'ACS', 'rent domain')" + ) + ) + conn.commit() + try: + yield f"sqlite:///{db_path}" + finally: + engine.dispose() + + +def _identity() -> TargetConfigIdentity: + return TargetConfigIdentity( + path="policyengine_us_data/calibration/target_config.yaml", + sha256="sha256:target-config", + mode="default", + resolved_path="/repo/policyengine_us_data/calibration/target_config.yaml", + ) + + +def test_target_catalog_reader_loads_active_and_disabled_targets(target_db): + catalog = TargetCatalogReader(db_uri=target_db, time_period=2024).load() + + assert catalog.targets["target_id"].tolist() == [1, 3, 4] + assert catalog.disabled_targets["target_id"].tolist() == [2] + assert catalog.constraints_for(2)[0]["variable"] == "state_fips" + + +def test_additive_target_expressions_require_valid_components(target_db): + catalog = TargetCatalogReader(db_uri=target_db, time_period=2024).load( + {"variables": ["eitc+ctc"]} + ) + policy = TargetSelectionPolicy.from_config({}) + + selected = policy.select(catalog, valid_variables={"eitc", "ctc"}) + + assert selected.targets_df["variable"].tolist() == ["eitc+ctc"] + with pytest.raises(ValueError, match="ctc"): + policy.select(catalog, valid_variables={"eitc"}) + + +def test_target_selection_policy_filters_config_and_reports_disabled(target_db): + catalog = TargetCatalogReader(db_uri=target_db, time_period=2024).load() + policy = TargetSelectionPolicy.from_config( + { + "include": [ + {"variable": "snap", "geo_level": "national"}, + {"variable": "rent", "geo_level": "national"}, + ], + "exclude": [{"variable": "rent", "geo_level": "national"}], + } + ) + + selected = policy.select(catalog, target_config_identity=_identity()) + + assert selected.target_ids == [1] + assert selected.disabled_rows()[0]["target_id"] == 2 + assert selected.summary()["target_config_sha256"] == "sha256:target-config" + + +def test_target_selection_order_and_checksum_change_with_config(target_db): + catalog = TargetCatalogReader(db_uri=target_db, time_period=2024).load() + all_targets = TargetSelectionPolicy.from_config({}).select(catalog) + snap_only = TargetSelectionPolicy.from_config( + {"include": [{"variable": "snap", "geo_level": "national"}]} + ).select(catalog) + + assert all_targets.target_ids == [1, 3, 4] + assert snap_only.target_ids == [1] + assert all_targets.checksum != snap_only.checksum + + +def test_target_metadata_artifacts_match_package_order_and_facets(target_db, tmp_path): + catalog = TargetCatalogReader(db_uri=target_db, time_period=2024).load() + selected = TargetSelectionPolicy.from_config({}).select( + catalog, + target_config_identity=_identity(), + ) + matrix_order = selected.targets_df.iloc[[1, 0, 2]].reset_index(drop=True) + selected = selected.with_matrix_order( + matrix_order, + ["state_6/eitc+ctc[snap>0]", "national/snap", "national/rent[rent>0]"], + ) + + targets_path, facets_path = selected.write_artifacts( + tmp_path / "calibration_targets.jsonl", + tmp_path / "calibration_target_facets.json", + ) + rows = [ + json.loads(line) + for line in targets_path.read_text(encoding="utf-8").splitlines() + ] + facets = json.loads(facets_path.read_text(encoding="utf-8")) + + assert [row["target_id"] for row in rows] == [3, 1, 4] + assert [row["target_index"] for row in rows] == [0, 1, 2] + assert rows[0]["target_expression"] == "eitc+ctc" + assert rows[0]["target_components"] == ["eitc", "ctc"] + assert facets == target_facets_from_rows(rows) + + fake_fit_rows = [{"target_id": 3, "target_index": 0, "fitted": 301.0}] + rows_by_id = {row["target_id"]: row for row in rows} + rows_by_index = {row["target_index"]: row for row in rows} + assert ( + rows_by_id[fake_fit_rows[0]["target_id"]]["target_name"] + == (rows_by_index[fake_fit_rows[0]["target_index"]]["target_name"]) + ) diff --git a/tests/unit/fixtures/calibration_package_stage_contract.py b/tests/unit/fixtures/calibration_package_stage_contract.py index f11640ee9..fe2e70524 100644 --- a/tests/unit/fixtures/calibration_package_stage_contract.py +++ b/tests/unit/fixtures/calibration_package_stage_contract.py @@ -67,6 +67,7 @@ def calibration_package_payload() -> dict[str, Any]: "db_sha256": "sha256:db", "target_config_path": TARGET_CONFIG_PATH, "target_config_sha256": "sha256:target-config", + "target_config_mode": "explicit", "n_clones": 3, "seed": 42, "base_n_records": 1, @@ -148,6 +149,8 @@ def calibration_package_parameters() -> dict[str, Any]: "workers": None, "n_clones": 3, "target_config": TARGET_CONFIG_PATH, + "target_config_sha256": "sha256:target-config", + "target_config_mode": "explicit", "skip_county": True, "skip_source_impute": True, "skip_takeup_rerandomize": False, diff --git a/tests/unit/test_calibration_package_stage_contract.py b/tests/unit/test_calibration_package_stage_contract.py index f00f646da..444b88062 100644 --- a/tests/unit/test_calibration_package_stage_contract.py +++ b/tests/unit/test_calibration_package_stage_contract.py @@ -55,11 +55,41 @@ def test_calibration_package_contract_records_stage_2_handoff(tmp_path): assert contract.outputs[0].media_type == "application/python-pickle" +def test_calibration_package_contract_references_target_metadata_artifacts(tmp_path): + dataset_path, db_path, package_path = contract_input_paths(tmp_path) + package = write_calibration_package_payload(package_path) + targets_path = tmp_path / "calibration_targets.jsonl" + facets_path = tmp_path / "calibration_target_facets.json" + targets_path.write_text('{"target_id":1,"target_index":0}\n', encoding="utf-8") + facets_path.write_text('{"target_count":1}\n', encoding="utf-8") + + contract = build_calibration_package_contract( + package_path=package_path, + dataset_path=dataset_path, + db_path=db_path, + package=package, + parameters=calibration_package_parameters(), + run_id="run-a", + completed_at="2026-05-08T12:02:00Z", + target_metadata_path=targets_path, + target_facets_path=facets_path, + target_selection_summary={"target_count": 1}, + ) + + outputs = {artifact.logical_name: artifact for artifact in contract.outputs} + assert outputs["calibration_targets"].media_type == "application/x-ndjson" + assert outputs["calibration_target_facets"].media_type == "application/json" + assert contract.metadata["target_selection"] == {"target_count": 1} + assert contract.execution.reuse_summary.expected_outputs == 3 + + def test_calibration_package_parameters_parse_runtime_args(): params = CalibrationPackageParameters.from_runtime_args( workers=8, n_clones=430, target_config_path=TARGET_CONFIG_PATH, + target_config_sha256="sha256:target-config", + target_config_mode="explicit", skip_county=True, skip_source_impute=True, skip_takeup_rerandomize=False, @@ -79,6 +109,8 @@ def test_calibration_package_parameters_parse_runtime_args(): "skip_source_impute": True, "skip_takeup_rerandomize": False, "target_config": TARGET_CONFIG_PATH, + "target_config_mode": "explicit", + "target_config_sha256": "sha256:target-config", "workers": None, } @@ -89,6 +121,8 @@ def test_calibration_package_parameters_reject_inconsistent_chunk_shape(): workers=8, n_clones=430, target_config=None, + target_config_sha256=None, + target_config_mode="all_active_targets", skip_county=True, skip_source_impute=True, skip_takeup_rerandomize=False, diff --git a/tests/unit/test_pipeline.py b/tests/unit/test_pipeline.py index 69f67bb82..7bcfaf540 100644 --- a/tests/unit/test_pipeline.py +++ b/tests/unit/test_pipeline.py @@ -9,6 +9,10 @@ modal = pytest.importorskip("modal") +from policyengine_us_data.calibration_package.specs import ( # noqa: E402 + DEFAULT_TARGET_CONFIG_PATH, +) +from policyengine_us_data.utils.manifest import compute_file_checksum # noqa: E402 from modal_app.pipeline import ( # noqa: E402 NATIONAL_FIT_LAMBDA_L0, _build_diagnostics_upload_script, @@ -44,6 +48,11 @@ def test_calibration_package_parameters_track_matrix_mode(): assert params["chunked_matrix"] is True assert "workers" not in params + assert params["target_config"] == DEFAULT_TARGET_CONFIG_PATH + assert params["target_config_sha256"] == compute_file_checksum( + DEFAULT_TARGET_CONFIG_PATH + ) + assert params["target_config_mode"] == "default" assert params["chunk_size"] == 10_000 assert params["parallel_matrix"] is True assert params["num_matrix_workers"] == 25 @@ -63,6 +72,11 @@ def test_calibration_package_parameters_ignore_unused_matrix_options(): assert params["chunked_matrix"] is False assert params["workers"] == 50 + assert params["target_config"] == DEFAULT_TARGET_CONFIG_PATH + assert params["target_config_sha256"] == compute_file_checksum( + DEFAULT_TARGET_CONFIG_PATH + ) + assert params["target_config_mode"] == "default" assert "chunk_size" not in params assert params["parallel_matrix"] is False assert "num_matrix_workers" not in params diff --git a/tests/unit/test_pipeline_docs_extractor.py b/tests/unit/test_pipeline_docs_extractor.py index 72c4daa4a..85d140e6d 100644 --- a/tests/unit/test_pipeline_docs_extractor.py +++ b/tests/unit/test_pipeline_docs_extractor.py @@ -122,3 +122,26 @@ def test_pipeline_map_manifest_validates(): assert bundle["metadata"]["mapped_decorated_node_count"] >= 45 assert sum(len(stage["nodes"]) for stage in bundle["stages"]) >= 160 assert sum(len(stage["edges"]) for stage in bundle["stages"]) >= 170 + stage2 = next( + stage + for stage in bundle["stages"] + if stage["id"] == "2a_matrix_build_calibration_target_construction" + ) + stage2_node_ids = {node["id"] for node in stage2["nodes"]} + assert { + "stage2_input_bundle", + "stage2_build_context", + "stage2_target_config_identity", + "stage2_target_catalog_load", + "stage2_target_catalog_reader", + "stage2_target_selection_policy", + "stage2_target_selection_result", + "build_matrix", + "build_matrix_chunked", + "stage2_payload_boundary", + "stage2_payload_reader", + "stage2_payload_writer", + "stage2_calibration_package_writer", + "stage2_calibration_package_contract_writer", + "stage2_calibration_package_contract_validator", + } <= stage2_node_ids diff --git a/tests/unit/test_pipeline_source_contracts.py b/tests/unit/test_pipeline_source_contracts.py index a022ef3fe..95d761fee 100644 --- a/tests/unit/test_pipeline_source_contracts.py +++ b/tests/unit/test_pipeline_source_contracts.py @@ -58,6 +58,31 @@ def test_run_pipeline_stage_1_stages_datasets_without_promoting() -> None: assert keywords["version"].id == "candidate_version" +def test_calibration_package_parameters_record_target_config_identity() -> None: + source_text = PIPELINE_SOURCE.read_text() + tree = ast.parse(source_text) + helper = _function_def(tree, "_calibration_package_parameters") + source = ast.get_source_segment(source_text, helper) + + assert "resolve_target_config_identity(" in source + assert '"target_config": target_config_identity.path' in source + assert '"target_config_sha256": target_config_identity.sha256' in source + assert '"target_config_mode": target_config_identity.mode' in source + + +def test_stage_2_manifest_records_package_and_contract_outputs() -> None: + source_text = PIPELINE_SOURCE.read_text() + tree = ast.parse(source_text) + run_pipeline = _function_def(tree, "run_pipeline") + source = ast.get_source_segment(source_text, run_pipeline) + + assert "package_context = stage2_build_context_for_run(" in source + assert "package_context.input_bundle.manifest_inputs" in source + assert 'package_inputs["input_validation"]' in source + assert "package_artifacts = package_context.output_bundle" in source + assert "package_artifacts.manifest_outputs" in source + + def test_promote_run_fails_closed_for_required_promotion_steps() -> None: tree = ast.parse(PIPELINE_SOURCE.read_text()) promote_run = _function_def(tree, "promote_run") diff --git a/tests/unit/test_remote_calibration_runner.py b/tests/unit/test_remote_calibration_runner.py index 77053dc78..db295ea6a 100644 --- a/tests/unit/test_remote_calibration_runner.py +++ b/tests/unit/test_remote_calibration_runner.py @@ -1,3 +1,4 @@ +import json import inspect import importlib import sys @@ -243,3 +244,32 @@ def fake_run_streaming(cmd, env=None, label=""): ensure_prereqs.assert_called_once() volume.reload.assert_called_once() volume.commit.assert_called_once() + + +def test_write_package_sidecar_reads_payload_and_contract(tmp_path): + remote_runner = _load_remote_calibration_runner_module() + from tests.unit.fixtures.calibration_package_stage_contract import ( + calibration_package_contract, + ) + + from policyengine_us_data.calibration_package.specs import ( + CALIBRATION_PACKAGE_CONTRACT_FILENAME, + CALIBRATION_PACKAGE_METADATA_FILENAME, + ) + from policyengine_us_data.stage_contracts.io import write_contract + + package_path = tmp_path / "calibration_package.pkl" + contract = calibration_package_contract(tmp_path) + write_contract(contract, tmp_path / CALIBRATION_PACKAGE_CONTRACT_FILENAME) + + assert remote_runner._write_package_sidecar(str(package_path)) is True + + sidecar = json.loads( + (tmp_path / CALIBRATION_PACKAGE_METADATA_FILENAME).read_text( + encoding="utf-8", + ) + ) + assert sidecar["package_sha256"].startswith("sha256:") + assert sidecar["package_summary"]["matrix_shape"] == [2, 3] + assert sidecar["geography_assignment"]["source_kind"] == "calibration_package" + assert sidecar["contract"]["stage_id"] == "2_build_calibration_package"