From 85b8488b3487d4723cee5770dee582eb6ac8e2db Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 21 May 2026 19:27:27 +0200 Subject: [PATCH] Stabilize chunked matrix execution contracts Fixes #1100 --- changelog.d/1100.changed | 1 + docs/pipeline_map.yaml | 35 ++ modal_app/matrix_chunk_worker.py | 145 ++++-- .../calibration/chunked_matrix_assembler.py | 75 ++- .../calibration/chunked_matrix_modal.py | 32 +- .../calibration/signatures.py | 5 +- .../calibration/unified_matrix_builder.py | 23 +- .../calibration_package/__init__.py | 18 + .../calibration_package/matrix.py | 482 +++++++++++++++++- .../test_chunked_matrix_assembler.py | 37 ++ .../calibration/test_chunked_matrix_modal.py | 151 ++++-- tests/unit/calibration_package/test_matrix.py | 131 +++++ tests/unit/test_pipeline_docs_extractor.py | 2 + 13 files changed, 1029 insertions(+), 108 deletions(-) create mode 100644 changelog.d/1100.changed diff --git a/changelog.d/1100.changed b/changelog.d/1100.changed new file mode 100644 index 000000000..61de46bf1 --- /dev/null +++ b/changelog.d/1100.changed @@ -0,0 +1 @@ +Stabilize Stage 2 chunked matrix execution request, result, and resume metadata. diff --git a/docs/pipeline_map.yaml b/docs/pipeline_map.yaml index 1e6da6058..e6d30d1ea 100644 --- a/docs/pipeline_map.yaml +++ b/docs/pipeline_map.yaml @@ -816,6 +816,8 @@ stages: - stage2_geography_assignment_result - stage2_matrix_build_spec - stage2_matrix_build_service + - stage2_chunk_build_request + - stage2_chunk_worker_result - stage2_matrix_build_result - target_resolve - stage2_target_config_apply @@ -838,6 +840,7 @@ stages: - out_targets - out_target_facets - out_geography_summary + - out_chunk_result_manifests - out_matrix_summary - stage2_calibration_package_contract_writer - out_contract @@ -911,6 +914,18 @@ stages: label: matrix_summary.json node_type: artifact description: Compact Stage 2 matrix shape, sparsity, target order, builder mode, and chunk lineage summary + - id: out_chunk_result_manifests + label: chunk_results/*.json + node_type: artifact + description: Per-chunk structured progress, cache, and error metadata for chunked matrix execution + - id: stage2_chunk_build_request + label: Chunk Build Request + node_type: library + description: Typed Modal worker request carrying run, chunk, state path, resume flag, and lineage signature material + - id: stage2_chunk_worker_result + label: Chunk Worker Result + node_type: library + description: Typed Modal worker result carrying per-chunk completion, cache, and error metadata - id: out_contract label: calibration_package_contract.json node_type: artifact @@ -1024,6 +1039,10 @@ stages: target: stage2_matrix_build_service edge_type: data_flow label: builder mode and chunk settings + - source: stage2_matrix_build_service + target: stage2_chunk_build_request + edge_type: data_flow + label: chunk ids and lineage signature - source: stage2_target_catalog_load target: stage2_target_config_apply edge_type: data_flow @@ -1077,6 +1096,22 @@ stages: target: stage2_matrix_build_result edge_type: data_flow label: chunk manifest and shards + - source: stage2_chunk_build_request + target: build_matrix_chunked + edge_type: data_flow + label: typed worker request + - source: build_matrix_chunked + target: stage2_chunk_worker_result + edge_type: data_flow + label: typed worker result + - source: stage2_chunk_worker_result + target: out_chunk_result_manifests + edge_type: produces_artifact + label: per-chunk progress and error metadata + - source: out_chunk_result_manifests + target: stage2_matrix_build_result + edge_type: data_flow + label: chunk execution diagnostics - source: stage2_matrix_build_service target: build_matrix edge_type: uses_library diff --git a/modal_app/matrix_chunk_worker.py b/modal_app/matrix_chunk_worker.py index b89f77502..6d1d69741 100644 --- a/modal_app/matrix_chunk_worker.py +++ b/modal_app/matrix_chunk_worker.py @@ -55,71 +55,140 @@ def _chunk_root(run_id: str) -> str: nonpreemptible=True, ) def build_matrix_chunk_worker( - run_id: str, - chunk_ids: List[int], + request: Dict | None = None, + run_id: str | None = None, + chunk_ids: List[int] | None = None, resume_chunks: bool = False, ) -> Dict: - """Materialize ``chunk_ids`` from the pickled ``SharedBuildState``. + """Materialize a typed chunk request from pickled ``SharedBuildState``. Args: - run_id: Pipeline run identifier; selects the volume path for - this worker's shared state and shard output directory. - chunk_ids: Chunk indices this worker is responsible for. - resume_chunks: Whether to trust matching pre-existing COO shards. - Fresh builds pass ``False`` so workers overwrite stale chunks. + request: Typed worker request material. Legacy ``run_id`` and + ``chunk_ids`` arguments are accepted for compatibility with + older local tests and undeployed call sites. + run_id: Legacy pipeline run identifier. + chunk_ids: Legacy chunk indices this worker is responsible for. + resume_chunks: Legacy resume flag. Returns: - Dict with ``chunk_ids``, ``nnz_per_chunk``, and ``errors`` - lists suitable for the coordinator to aggregate. + Structured worker result material suitable for the coordinator to + aggregate. """ from policyengine_us_data.calibration.chunked_matrix_assembler import ( ChunkedMatrixAssembler, ) + from policyengine_us_data.calibration.signatures import signature_mismatches + from policyengine_us_data.calibration_package.matrix import ( + CHUNK_EXECUTION_SCHEMA_VERSION, + ChunkBuildRequest, + ChunkExecutionResult, + ChunkWorkerResult, + write_chunk_result_manifest, + ) pipeline_vol.reload() - chunk_root = Path(_chunk_root(run_id)) - state_path = chunk_root / "chunk_build_state.pkl" + if request is None: + if run_id is None or chunk_ids is None: + raise ValueError("request or legacy run_id/chunk_ids are required") + chunk_root = Path(_chunk_root(run_id)) + request_obj = ChunkBuildRequest( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=run_id, + chunk_ids=tuple(chunk_ids), + chunk_root=str(chunk_root), + state_path=str(chunk_root / "chunk_build_state.pkl"), + resume_chunks=resume_chunks, + lineage_signature={}, + ) + else: + request_obj = ChunkBuildRequest.from_dict(request) + chunk_root = Path(request_obj.chunk_root) + state_path = Path(request_obj.state_path) if not state_path.exists(): - return { - "chunk_ids": list(chunk_ids), - "nnz_per_chunk": [], - "errors": [ - { - "chunk_ids": list(chunk_ids), - "error": f"Missing shared state at {state_path}", - } - ], - } + chunk_results = tuple( + ChunkExecutionResult.failure( + run_id=request_obj.run_id, + chunk_id=chunk_id, + error=f"Missing shared state at {state_path}", + ) + for chunk_id in request_obj.chunk_ids + ) + for result in chunk_results: + write_chunk_result_manifest(chunk_root, result) + pipeline_vol.commit() + return ChunkWorkerResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=request_obj.run_id, + chunk_ids=request_obj.chunk_ids, + chunk_results=chunk_results, + ).to_dict() with open(state_path, "rb") as f: shared_state = pickle.load(f) + if request_obj.lineage_signature: + state_lineage_signature = getattr(shared_state, "lineage_signature", {}) + fatal, _ = signature_mismatches( + state_lineage_signature, + request_obj.lineage_signature, + ) + if fatal: + error = "Chunk request lineage mismatch: " + "; ".join(fatal) + chunk_results = tuple( + ChunkExecutionResult.failure( + run_id=request_obj.run_id, + chunk_id=chunk_id, + error=error, + ) + for chunk_id in request_obj.chunk_ids + ) + for result in chunk_results: + write_chunk_result_manifest(chunk_root, result) + pipeline_vol.commit() + return ChunkWorkerResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=request_obj.run_id, + chunk_ids=request_obj.chunk_ids, + chunk_results=chunk_results, + ).to_dict() assembler = ChunkedMatrixAssembler( shared_state=shared_state, chunk_root=chunk_root, chunk_size=shared_state.chunk_size, - resume=resume_chunks, + resume=request_obj.resume_chunks, keep_chunks=False, ) - errors: List[Dict] = [] - nnz_per_chunk: List[int] = [] - for chunk_id in chunk_ids: + chunk_results: List[ChunkExecutionResult] = [] + for chunk_id in request_obj.chunk_ids: try: result = assembler.run_single_chunk(chunk_id) - nnz_per_chunk.append(result.nnz) + chunk_results.append( + ChunkExecutionResult.from_chunk_result( + run_id=request_obj.run_id, + result=result, + ) + ) except Exception as exc: - errors.append( - { - "chunk_id": chunk_id, - "error": str(exc), - "traceback": traceback.format_exc(), - } + traceback_text = traceback.format_exc() + assembler.record_chunk_error( + chunk_id=chunk_id, + error=str(exc), + traceback=traceback_text, + ) + chunk_results.append( + ChunkExecutionResult.failure( + run_id=request_obj.run_id, + chunk_id=chunk_id, + error=str(exc), + traceback=traceback_text, + ) ) pipeline_vol.commit() - return { - "chunk_ids": list(chunk_ids), - "nnz_per_chunk": nnz_per_chunk, - "errors": errors, - } + return ChunkWorkerResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=request_obj.run_id, + chunk_ids=request_obj.chunk_ids, + chunk_results=tuple(chunk_results), + ).to_dict() diff --git a/policyengine_us_data/calibration/chunked_matrix_assembler.py b/policyengine_us_data/calibration/chunked_matrix_assembler.py index 59b6699bb..1c5fb21ff 100644 --- a/policyengine_us_data/calibration/chunked_matrix_assembler.py +++ b/policyengine_us_data/calibration/chunked_matrix_assembler.py @@ -14,11 +14,16 @@ import time from dataclasses import dataclass from pathlib import Path -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple import numpy as np from scipy import sparse +from policyengine_us_data.calibration_package.matrix import ( + ChunkExecutionResult, + write_chunk_result_manifest, +) + logger = logging.getLogger(__name__) @@ -81,6 +86,7 @@ class SharedBuildState: cd_geoid: np.ndarray county_fips: np.ndarray state_fips: np.ndarray + lineage_signature: Dict[str, Any] @property def n_total(self) -> int: @@ -190,6 +196,26 @@ def stream_csr_from_shards( return X +def _format_duration(seconds: float) -> str: + seconds = max(0, int(round(seconds))) + hours, remainder = divmod(seconds, 3600) + minutes, seconds = divmod(remainder, 60) + if hours: + return f"{hours}h {minutes:02d}m {seconds:02d}s" + if minutes: + return f"{minutes}m {seconds:02d}s" + return f"{seconds}s" + + +def _current_rss_mb() -> Optional[float]: + try: + import psutil + + return psutil.Process().memory_info().rss / 1024**2 + except Exception: + return None + + class ChunkedMatrixAssembler: """Coordinate partitioning, per-chunk execution, and streaming assembly. @@ -267,11 +293,6 @@ def run_chunks(self, chunk_ids: Iterable[int]) -> List[ChunkResult]: cached_chunks, ) else: - from policyengine_us_data.calibration.unified_matrix_builder import ( - _current_rss_mb, - _format_duration, - ) - rss = _current_rss_mb() rss_part = f", rss={rss:,.0f} MB" if rss is not None else "" logger.info( @@ -322,7 +343,9 @@ def run_single_chunk(self, chunk_id: int) -> ChunkResult: f"{cached_col_start}-{cached_col_end - 1}, " f"expected {plan.col_start}-{plan.col_end - 1}" ) - return ChunkResult(chunk_id=chunk_id, nnz=cached_nnz, cached=True) + result = ChunkResult(chunk_id=chunk_id, nnz=cached_nnz, cached=True) + self.write_result_manifest(result) + return result # Imports are local so the module is import-safe in lightweight # environments (e.g., cold Modal containers that haven't yet @@ -520,7 +543,7 @@ def run_single_chunk(self, chunk_id: int) -> ChunkResult: if not self.keep_chunks and plan.h5_path.exists(): plan.h5_path.unlink() - return ChunkResult( + result = ChunkResult( chunk_id=chunk_id, nnz=int(vals.shape[0]), cached=False, @@ -530,6 +553,42 @@ def run_single_chunk(self, chunk_id: int) -> ChunkResult: unique_counties=getattr(summary, "unique_counties", None), unique_cds=getattr(summary, "unique_cds", None), ) + self.write_result_manifest(result) + return result + + def write_result_manifest(self, result: ChunkResult) -> Path: + """Persist structured progress metadata for one chunk.""" + + return write_chunk_result_manifest( + self.chunk_root, + ChunkExecutionResult.from_chunk_result( + run_id=self._manifest_run_id(), + result=result, + ), + ) + + def record_chunk_error( + self, + *, + chunk_id: int, + error: str, + traceback: str | None = None, + ) -> Path: + """Persist structured error metadata for one chunk.""" + + return write_chunk_result_manifest( + self.chunk_root, + ChunkExecutionResult.failure( + run_id=self._manifest_run_id(), + chunk_id=chunk_id, + error=error, + traceback=traceback, + ), + ) + + def _manifest_run_id(self) -> str: + lineage_signature = getattr(self.shared_state, "lineage_signature", {}) + return str(lineage_signature.get("run_id", "")) def assemble_final(self) -> sparse.csr_matrix: """Stream-assemble the final CSR matrix from all shards on disk.""" diff --git a/policyengine_us_data/calibration/chunked_matrix_modal.py b/policyengine_us_data/calibration/chunked_matrix_modal.py index ed280671d..2d2d01104 100644 --- a/policyengine_us_data/calibration/chunked_matrix_modal.py +++ b/policyengine_us_data/calibration/chunked_matrix_modal.py @@ -25,6 +25,11 @@ ChunkedMatrixAssembler, SharedBuildState, ) +from policyengine_us_data.calibration_package.matrix import ( + CHUNK_EXECUTION_SCHEMA_VERSION, + ChunkBuildRequest, + ChunkWorkerResult, +) logger = logging.getLogger(__name__) @@ -178,10 +183,17 @@ def dispatch_chunks_modal( t_dispatch = time.time() handles = [] for batch_idx, chunk_ids in enumerate(batches): - handle = worker_function.spawn( + request = ChunkBuildRequest( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, run_id=run_id, - chunk_ids=chunk_ids, + chunk_ids=tuple(chunk_ids), + chunk_root=str(chunk_root), + state_path=str(state_path), resume_chunks=resume_chunks, + lineage_signature=shared_state.lineage_signature, + ) + handle = worker_function.spawn( + request=request.to_dict(), ) logger.info( "Worker %d/%d: %d chunks (%d-%d), fc=%s", @@ -212,16 +224,26 @@ def dispatch_chunks_modal( {"batch": batch_idx, "error": "Worker returned None"} ) continue - errors = result.get("errors", []) + try: + worker_result = ChunkWorkerResult.from_dict(result) + except ValueError as exc: + aggregated_errors.append( + { + "batch": batch_idx, + "error": f"Worker returned invalid result: {exc}", + } + ) + continue + errors = worker_result.errors if errors: for err in errors: - err_copy = dict(err) + err_copy = err.to_dict() err_copy["batch"] = batch_idx aggregated_errors.append(err_copy) logger.info( "Worker %d done: %d chunks completed, %d errors", batch_idx, - len(result.get("chunk_ids", [])) - len(errors), + worker_result.completed_count, len(errors), ) diff --git a/policyengine_us_data/calibration/signatures.py b/policyengine_us_data/calibration/signatures.py index a161ecebe..d338b5372 100644 --- a/policyengine_us_data/calibration/signatures.py +++ b/policyengine_us_data/calibration/signatures.py @@ -152,6 +152,7 @@ def build_chunk_lineage_signature( target_names: list[str], chunk_size: int, rerandomize_takeup: bool, + run_id: str | None = None, ) -> dict: """Build a signature for validating chunk cache lineage.""" target_columns = [ @@ -170,7 +171,9 @@ def build_chunk_lineage_signature( ] target_frame = targets_df[target_columns].copy() return { - "format_version": 1, + "format_version": 2, + "run_id": run_id or "", + "matrix_builder": "chunked", "dataset_sha256": compute_file_checksum(Path(dataset_path)), "db_sha256": sqlite_checksum(db_uri), "time_period": int(time_period), diff --git a/policyengine_us_data/calibration/unified_matrix_builder.py b/policyengine_us_data/calibration/unified_matrix_builder.py index 0e8c523a9..76102b61e 100644 --- a/policyengine_us_data/calibration/unified_matrix_builder.py +++ b/policyengine_us_data/calibration/unified_matrix_builder.py @@ -10,7 +10,6 @@ Column ordering: index i = clone_idx * n_records + record_idx """ -import json import logging from collections import defaultdict from pathlib import Path @@ -27,7 +26,6 @@ from policyengine_us_data.utils.census import STATE_ABBREV_TO_FIPS, STATE_NAME_TO_FIPS from policyengine_us_data.calibration.signatures import ( build_chunk_lineage_signature, - signature_mismatches, ) from policyengine_us_data.calibration.calibration_utils import ( get_calculated_variables, @@ -38,6 +36,7 @@ TargetCatalogReader, TargetSelectionResult, ) +from policyengine_us_data.calibration_package.matrix import ChunkCacheManifest from policyengine_us_data.pipeline_metadata import pipeline_node from policyengine_us_data.pipeline_schema import PipelineNode from policyengine_us_data.utils.target_variables import ( @@ -99,15 +98,11 @@ def _current_rss_mb() -> Optional[float]: def _load_chunk_manifest(path: Path) -> dict: - with open(path, "r", encoding="utf-8") as f: - return json.load(f) + return ChunkCacheManifest.read(path).to_dict() def _save_chunk_manifest(path: Path, signature: dict) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - with open(path, "w", encoding="utf-8") as f: - json.dump({"signature": signature}, f, indent=2, sort_keys=True) - f.write("\n") + ChunkCacheManifest.from_signature(signature).write(path) def _validate_chunk_manifest(path: Path, expected_signature: dict) -> None: @@ -115,14 +110,8 @@ def _validate_chunk_manifest(path: Path, expected_signature: dict) -> None: raise ValueError( f"Cannot resume chunk cache at {path.parent}: missing chunk manifest" ) - stored = _load_chunk_manifest(path) - stored_signature = stored.get("signature") - if stored_signature is None: - raise ValueError(f"Chunk manifest at {path} is missing its signature") - fatal, _ = signature_mismatches(stored_signature, expected_signature) - if fatal: - joined = "; ".join(fatal) - raise ValueError(f"Chunk cache lineage mismatch for {path.parent}: {joined}") + stored = ChunkCacheManifest.read(path) + stored.validate_lineage(expected_signature, cache_root=path.parent) def _has_existing_chunk_cache(coo_dir: Path) -> bool: @@ -3370,6 +3359,7 @@ def build_matrix_chunked( target_names=target_names, chunk_size=chunk_size, rerandomize_takeup=rerandomize_takeup, + run_id=run_id, ) if resume_chunks: if chunk_manifest_path.exists(): @@ -3404,6 +3394,7 @@ def build_matrix_chunked( cd_geoid=np.asarray(geography.cd_geoid, dtype=str), county_fips=np.asarray(geography.county_fips, dtype=str), state_fips=np.asarray(geography.state_fips), + lineage_signature=chunk_signature, ) assembler = ChunkedMatrixAssembler( shared_state=shared_state, diff --git a/policyengine_us_data/calibration_package/__init__.py b/policyengine_us_data/calibration_package/__init__.py index ef9671142..aa2d2e52e 100644 --- a/policyengine_us_data/calibration_package/__init__.py +++ b/policyengine_us_data/calibration_package/__init__.py @@ -39,11 +39,20 @@ geography_summary_from_package, ) from .matrix import ( + CHUNK_CACHE_MANIFEST_SCHEMA_VERSION, + CHUNK_EXECUTION_SCHEMA_VERSION, + CHUNK_EXECUTION_STATUSES, MATRIX_BUILD_SCHEMA_VERSION, MATRIX_BUILDER_MODES, + ChunkBuildRequest, + ChunkCacheManifest, + ChunkExecutionResult, + ChunkWorkerResult, MatrixBuildResult, MatrixBuildService, MatrixBuildSpec, + chunk_result_manifest_path, + write_chunk_result_manifest, ) from .payload import ( LEGACY_MISSING_GEOGRAPHY_WARNING, @@ -68,6 +77,9 @@ "CALIBRATION_TARGET_FACETS_FILENAME", "CALIBRATION_TARGETS_FILENAME", "CALIBRATION_REPORTS_DIRNAME", + "CHUNK_CACHE_MANIFEST_SCHEMA_VERSION", + "CHUNK_EXECUTION_SCHEMA_VERSION", + "CHUNK_EXECUTION_STATUSES", "DATASET_BUILD_OUTPUT_CONTRACT_FILENAME", "DEFAULT_TARGET_CONFIG_PATH", "GEOGRAPHY_ASSIGNMENT_ORDERING", @@ -85,6 +97,10 @@ "CalibrationPackagePayload", "CalibrationPackageReader", "CalibrationPackageWriter", + "ChunkBuildRequest", + "ChunkCacheManifest", + "ChunkExecutionResult", + "ChunkWorkerResult", "GeographyAssignmentResult", "GeographyAssignmentSpec", "LEGACY_MISSING_GEOGRAPHY_WARNING", @@ -102,6 +118,7 @@ "TargetSelectionPolicy", "TargetSelectionResult", "calibration_package_artifact_paths", + "chunk_result_manifest_path", "geography_spec_from_metadata", "geography_summary_from_package", "resolve_target_config_identity", @@ -110,4 +127,5 @@ "stage2_input_bundle_from_stage1_contract", "stage2_input_bundle_from_stage1_contract_path", "target_facets_from_rows", + "write_chunk_result_manifest", ] diff --git a/policyengine_us_data/calibration_package/matrix.py b/policyengine_us_data/calibration_package/matrix.py index 9f14d818e..3bcb64e0e 100644 --- a/policyengine_us_data/calibration_package/matrix.py +++ b/policyengine_us_data/calibration_package/matrix.py @@ -6,7 +6,7 @@ import json from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, Mapping from policyengine_us_data.pipeline_metadata import pipeline_node from policyengine_us_data.pipeline_schema import PipelineNode @@ -20,8 +20,421 @@ ) MATRIX_BUILD_SCHEMA_VERSION = 1 +CHUNK_CACHE_MANIFEST_SCHEMA_VERSION = 2 +CHUNK_EXECUTION_SCHEMA_VERSION = 1 MatrixBuilderMode = Literal["precompute", "chunked"] MATRIX_BUILDER_MODES = frozenset({"precompute", "chunked"}) +ChunkExecutionStatus = Literal["completed", "cached", "failed"] +CHUNK_EXECUTION_STATUSES = frozenset({"completed", "cached", "failed"}) + + +@dataclass(frozen=True, kw_only=True) +class ChunkCacheManifest: + """Structured lineage manifest for resumable chunked matrix shards.""" + + schema_version: int + lineage_signature: Mapping[str, Any] + + def __post_init__(self) -> None: + _validate_positive_int(self.schema_version, "schema_version") + if not isinstance(self.lineage_signature, Mapping): + raise ValueError("lineage_signature must be a mapping") + + @classmethod + def from_signature(cls, signature: Mapping[str, Any]) -> "ChunkCacheManifest": + """Create a cache manifest from a computed chunk lineage signature.""" + + return cls( + schema_version=CHUNK_CACHE_MANIFEST_SCHEMA_VERSION, + lineage_signature=dict(signature), + ) + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "ChunkCacheManifest": + """Parse a chunk cache manifest from JSON-compatible material.""" + + if not isinstance(data, Mapping): + raise ValueError("chunk cache manifest must be a mapping") + if "signature" in data and "lineage_signature" not in data: + # PR-2f caches wrote {"signature": ...}; preserve compatibility + # while the schema name becomes explicit. + lineage = data["signature"] + else: + lineage = data.get("lineage_signature") + return cls( + schema_version=int(data.get("schema_version", 1)), + lineage_signature=_mapping_value(lineage, "lineage_signature"), + ) + + @classmethod + def read(cls, path: str | Path) -> "ChunkCacheManifest": + """Read a structured chunk cache manifest from disk.""" + + return cls.from_dict(json.loads(Path(path).read_text(encoding="utf-8"))) + + def validate_lineage( + self, + expected_signature: Mapping[str, Any], + *, + cache_root: str | Path | None = None, + ) -> None: + """Reject resume when stored lineage differs from expected lineage.""" + + fatal = _lineage_mismatches(self.lineage_signature, expected_signature) + if fatal: + root = Path(cache_root) if cache_root is not None else "chunk cache" + joined = "; ".join(fatal) + raise ValueError(f"Chunk cache lineage mismatch for {root}: {joined}") + + def to_dict(self) -> dict[str, Any]: + """Return deterministic JSON-compatible cache manifest material.""" + + return { + "lineage_signature": dict(self.lineage_signature), + "schema_version": self.schema_version, + } + + def write(self, path: str | Path) -> Path: + """Write the cache manifest and return the path.""" + + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text( + json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + return output_path + + +@dataclass(frozen=True, kw_only=True) +class ChunkBuildRequest: + """Typed request consumed by a chunked matrix worker.""" + + schema_version: int + run_id: str + chunk_ids: tuple[int, ...] + chunk_root: str + state_path: str + resume_chunks: bool + lineage_signature: Mapping[str, Any] + + def __post_init__(self) -> None: + _validate_positive_int(self.schema_version, "schema_version") + _validate_string(self.run_id, "run_id") + _validate_string(self.chunk_root, "chunk_root") + _validate_string(self.state_path, "state_path") + _validate_bool(self.resume_chunks, "resume_chunks") + if not isinstance(self.lineage_signature, Mapping): + raise ValueError("lineage_signature must be a mapping") + if not isinstance(self.chunk_ids, tuple) or not self.chunk_ids: + raise ValueError("chunk_ids must be a non-empty tuple") + for chunk_id in self.chunk_ids: + _validate_non_negative_int(chunk_id, "chunk_ids") + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "ChunkBuildRequest": + """Parse a chunk worker request from JSON-compatible material.""" + + if not isinstance(data, Mapping): + raise ValueError("chunk build request must be a mapping") + return cls( + schema_version=int(data.get("schema_version", 1)), + run_id=_string_value(data.get("run_id"), "run_id"), + chunk_ids=tuple( + _non_negative_int_value(chunk_id, "chunk_ids") + for chunk_id in data.get("chunk_ids", ()) + ), + chunk_root=_string_value(data.get("chunk_root"), "chunk_root"), + state_path=_string_value(data.get("state_path"), "state_path"), + resume_chunks=_bool_value(data.get("resume_chunks"), "resume_chunks"), + lineage_signature=_mapping_value( + data.get("lineage_signature"), + "lineage_signature", + ), + ) + + def to_dict(self) -> dict[str, Any]: + """Return deterministic JSON-compatible worker request material.""" + + return { + "chunk_ids": list(self.chunk_ids), + "chunk_root": self.chunk_root, + "lineage_signature": dict(self.lineage_signature), + "resume_chunks": self.resume_chunks, + "run_id": self.run_id, + "schema_version": self.schema_version, + "state_path": self.state_path, + } + + +@dataclass(frozen=True, kw_only=True) +class ChunkExecutionResult: + """Structured per-chunk progress or error metadata.""" + + schema_version: int + run_id: str + chunk_id: int + status: ChunkExecutionStatus + nnz: int | None = None + cached: bool = False + error: str | None = None + traceback: str | None = None + n_households: int | None = None + n_persons: int | None = None + unique_states: int | None = None + unique_counties: int | None = None + unique_cds: int | None = None + + def __post_init__(self) -> None: + _validate_positive_int(self.schema_version, "schema_version") + _validate_string(self.run_id, "run_id") + _validate_non_negative_int(self.chunk_id, "chunk_id") + _validate_bool(self.cached, "cached") + if self.status not in CHUNK_EXECUTION_STATUSES: + raise ValueError( + f"status must be one of {sorted(CHUNK_EXECUTION_STATUSES)}" + ) + _validate_optional_non_negative_int(self.nnz, "nnz") + for key in ( + "n_households", + "n_persons", + "unique_states", + "unique_counties", + "unique_cds", + ): + _validate_optional_non_negative_int(getattr(self, key), key) + for key in ("error", "traceback"): + value = getattr(self, key) + if value is not None and not isinstance(value, str): + raise ValueError(f"{key} must be a string or None") + if self.status == "failed": + if not self.error: + raise ValueError("failed chunk results require error") + if self.nnz is not None: + raise ValueError("failed chunk results must not report nnz") + else: + _validate_non_negative_int(self.nnz, "nnz") + if self.error is not None or self.traceback is not None: + raise ValueError("successful chunk results must not report errors") + if self.status == "cached" and not self.cached: + raise ValueError("cached chunk results require cached=True") + if self.status == "completed" and self.cached: + raise ValueError("completed chunk results require cached=False") + + @classmethod + def from_chunk_result( + cls, + *, + run_id: str, + result: Any, + ) -> "ChunkExecutionResult": + """Normalize an assembler `ChunkResult` into manifest metadata.""" + + cached = bool(getattr(result, "cached")) + return cls( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=run_id, + chunk_id=int(getattr(result, "chunk_id")), + status="cached" if cached else "completed", + nnz=int(getattr(result, "nnz")), + cached=cached, + n_households=getattr(result, "n_households", None), + n_persons=getattr(result, "n_persons", None), + unique_states=getattr(result, "unique_states", None), + unique_counties=getattr(result, "unique_counties", None), + unique_cds=getattr(result, "unique_cds", None), + ) + + @classmethod + def failure( + cls, + *, + run_id: str, + chunk_id: int, + error: str, + traceback: str | None = None, + ) -> "ChunkExecutionResult": + """Create failure metadata for a chunk worker error.""" + + return cls( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=run_id, + chunk_id=chunk_id, + status="failed", + cached=False, + error=error, + traceback=traceback, + ) + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "ChunkExecutionResult": + """Parse per-chunk result metadata.""" + + if not isinstance(data, Mapping): + raise ValueError("chunk execution result must be a mapping") + return cls( + schema_version=int(data.get("schema_version", 1)), + run_id=_string_value(data.get("run_id"), "run_id"), + chunk_id=_non_negative_int_value(data.get("chunk_id"), "chunk_id"), + status=_string_value(data.get("status"), "status"), + nnz=_optional_non_negative_int_value(data.get("nnz"), "nnz"), + cached=_bool_value(data.get("cached", False), "cached"), + error=data.get("error"), + traceback=data.get("traceback"), + n_households=_optional_non_negative_int_value( + data.get("n_households"), + "n_households", + ), + n_persons=_optional_non_negative_int_value( + data.get("n_persons"), + "n_persons", + ), + unique_states=_optional_non_negative_int_value( + data.get("unique_states"), + "unique_states", + ), + unique_counties=_optional_non_negative_int_value( + data.get("unique_counties"), + "unique_counties", + ), + unique_cds=_optional_non_negative_int_value( + data.get("unique_cds"), + "unique_cds", + ), + ) + + def to_dict(self) -> dict[str, Any]: + """Return deterministic JSON-compatible chunk result material.""" + + return { + "cached": self.cached, + "chunk_id": self.chunk_id, + "error": self.error, + "n_households": self.n_households, + "n_persons": self.n_persons, + "nnz": self.nnz, + "run_id": self.run_id, + "schema_version": self.schema_version, + "status": self.status, + "traceback": self.traceback, + "unique_cds": self.unique_cds, + "unique_counties": self.unique_counties, + "unique_states": self.unique_states, + } + + +@dataclass(frozen=True, kw_only=True) +class ChunkWorkerResult: + """Structured result returned by one chunked matrix worker.""" + + schema_version: int + run_id: str + chunk_ids: tuple[int, ...] + chunk_results: tuple[ChunkExecutionResult, ...] + + def __post_init__(self) -> None: + _validate_positive_int(self.schema_version, "schema_version") + _validate_string(self.run_id, "run_id") + if not isinstance(self.chunk_ids, tuple): + raise ValueError("chunk_ids must be a tuple") + for chunk_id in self.chunk_ids: + _validate_non_negative_int(chunk_id, "chunk_ids") + if not isinstance(self.chunk_results, tuple) or not all( + isinstance(result, ChunkExecutionResult) for result in self.chunk_results + ): + raise ValueError("chunk_results must contain ChunkExecutionResult entries") + result_ids = tuple(result.chunk_id for result in self.chunk_results) + if sorted(result_ids) != sorted(self.chunk_ids): + raise ValueError( + "chunk_results must include one result per requested chunk" + ) + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "ChunkWorkerResult": + """Parse worker result material.""" + + if not isinstance(data, Mapping): + raise ValueError("chunk worker result must be a mapping") + chunk_results = tuple( + ChunkExecutionResult.from_dict(item) + for item in data.get("chunk_results", ()) + ) + return cls( + schema_version=int(data.get("schema_version", 1)), + run_id=_string_value(data.get("run_id"), "run_id"), + chunk_ids=tuple( + _non_negative_int_value(chunk_id, "chunk_ids") + for chunk_id in data.get("chunk_ids", ()) + ), + chunk_results=chunk_results, + ) + + @property + def errors(self) -> tuple[ChunkExecutionResult, ...]: + """Return failed per-chunk results.""" + + return tuple( + result for result in self.chunk_results if result.status == "failed" + ) + + @property + def completed_count(self) -> int: + """Return the number of completed or cached chunks.""" + + return len(self.chunk_results) - len(self.errors) + + @property + def cached_count(self) -> int: + """Return the number of cached chunks.""" + + return sum(1 for result in self.chunk_results if result.status == "cached") + + def to_dict(self) -> dict[str, Any]: + """Return deterministic JSON-compatible worker result material.""" + + errors = [ + { + "chunk_id": result.chunk_id, + "error": result.error, + "traceback": result.traceback, + } + for result in self.errors + ] + return { + "cached_count": self.cached_count, + "chunk_ids": list(self.chunk_ids), + "chunk_results": [result.to_dict() for result in self.chunk_results], + "completed_count": self.completed_count, + "error_count": len(errors), + "errors": errors, + "nnz_per_chunk": [ + result.nnz for result in self.chunk_results if result.status != "failed" + ], + "run_id": self.run_id, + "schema_version": self.schema_version, + } + + +def chunk_result_manifest_path(chunk_root: str | Path, chunk_id: int) -> Path: + """Return the per-chunk result manifest path for a chunk root.""" + + _validate_non_negative_int(chunk_id, "chunk_id") + return Path(chunk_root) / "chunk_results" / f"chunk_{chunk_id:06d}.json" + + +def write_chunk_result_manifest( + chunk_root: str | Path, + result: ChunkExecutionResult, +) -> Path: + """Write one structured chunk result manifest and return the path.""" + + output_path = chunk_result_manifest_path(chunk_root, result.chunk_id) + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text( + json.dumps(result.to_dict(), indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + return output_path @pipeline_node( @@ -402,21 +815,88 @@ def _validate_bool(value: Any, key: str) -> None: raise ValueError(f"{key} must be a boolean") +def _validate_string(value: Any, key: str) -> None: + if not isinstance(value, str): + raise ValueError(f"{key} must be a string") + + def _validate_positive_int(value: Any, key: str) -> None: if isinstance(value, bool) or not isinstance(value, int) or value <= 0: raise ValueError(f"{key} must be a positive integer") +def _validate_non_negative_int(value: Any, key: str) -> None: + if isinstance(value, bool) or not isinstance(value, int) or value < 0: + raise ValueError(f"{key} must be a non-negative integer") + + def _validate_optional_positive_int(value: Any, key: str) -> None: if value is not None: _validate_positive_int(value, key) +def _validate_optional_non_negative_int(value: Any, key: str) -> None: + if value is not None: + _validate_non_negative_int(value, key) + + +def _string_value(value: Any, key: str) -> str: + _validate_string(value, key) + return value + + +def _bool_value(value: Any, key: str) -> bool: + _validate_bool(value, key) + return value + + +def _non_negative_int_value(value: Any, key: str) -> int: + _validate_non_negative_int(value, key) + return value + + +def _optional_non_negative_int_value(value: Any, key: str) -> int | None: + if value is None: + return None + _validate_non_negative_int(value, key) + return value + + +def _mapping_value(value: Any, key: str) -> Mapping[str, Any]: + if not isinstance(value, Mapping): + raise ValueError(f"{key} must be a mapping") + return value + + +def _lineage_mismatches( + stored_signature: Mapping[str, Any], + expected_signature: Mapping[str, Any], +) -> list[str]: + mismatches: list[str] = [] + for key, expected_value in expected_signature.items(): + stored_value = stored_signature.get(key) + if stored_value is None: + mismatches.append(f"{key} missing from stored signature") + elif stored_value != expected_value: + mismatches.append(f"{key} expected {stored_value}, got {expected_value}") + return mismatches + + __all__ = [ + "CHUNK_CACHE_MANIFEST_SCHEMA_VERSION", + "CHUNK_EXECUTION_SCHEMA_VERSION", + "CHUNK_EXECUTION_STATUSES", "MATRIX_BUILD_SCHEMA_VERSION", "MATRIX_BUILDER_MODES", + "ChunkBuildRequest", + "ChunkCacheManifest", + "ChunkExecutionResult", + "ChunkExecutionStatus", + "ChunkWorkerResult", "MatrixBuilderMode", "MatrixBuildResult", "MatrixBuildService", "MatrixBuildSpec", + "chunk_result_manifest_path", + "write_chunk_result_manifest", ] diff --git a/tests/unit/calibration/test_chunked_matrix_assembler.py b/tests/unit/calibration/test_chunked_matrix_assembler.py index 589867880..8d8a77caa 100644 --- a/tests/unit/calibration/test_chunked_matrix_assembler.py +++ b/tests/unit/calibration/test_chunked_matrix_assembler.py @@ -10,6 +10,7 @@ from __future__ import annotations import gc +import json from pathlib import Path from typing import List from unittest import mock @@ -25,6 +26,10 @@ partition_chunks, stream_csr_from_shards, ) +from policyengine_us_data.calibration_package.matrix import ( + ChunkExecutionResult, + chunk_result_manifest_path, +) def _write_shard( @@ -76,6 +81,7 @@ def _make_shared_state( cd_geoid=np.zeros(n_total, dtype="U4"), county_fips=np.zeros(n_total, dtype="U5"), state_fips=np.zeros(n_total, dtype=np.int32), + lineage_signature={"run_id": "run-test", "chunk_size": 10}, ) @@ -274,6 +280,11 @@ def test_assembler_skips_existing_shards_when_resume(tmp_path: Path) -> None: assert result.cached is True assert result.nnz == 2 assert result.chunk_id == 0 + manifest = ChunkExecutionResult.from_dict( + json.loads(chunk_result_manifest_path(tmp_path, 0).read_text(encoding="utf-8")) + ) + assert manifest.status == "cached" + assert manifest.run_id == "run-test" def test_assembler_rejects_shard_with_mismatched_range(tmp_path: Path) -> None: @@ -330,6 +341,7 @@ def test_shared_build_state_roundtrips_pickle() -> None: assert np.array_equal(restored.cd_geoid, state.cd_geoid) assert np.array_equal(restored.county_fips, state.county_fips) assert np.array_equal(restored.state_fips, state.state_fips) + assert restored.lineage_signature == state.lineage_signature def test_assembler_run_chunks_dispatches_each_id(tmp_path: Path) -> None: @@ -354,3 +366,28 @@ def fake_run(chunk_id: int) -> ChunkResult: assert observed == [0, 2] assert [r.chunk_id for r in results] == [0, 2] + + +def test_assembler_records_chunk_error_manifest(tmp_path: Path) -> None: + state = _make_shared_state(n_records=10, n_clones=2, n_targets=3) + assembler = ChunkedMatrixAssembler( + shared_state=state, + chunk_root=tmp_path, + chunk_size=10, + resume=False, + keep_chunks=False, + ) + + manifest_path = assembler.record_chunk_error( + chunk_id=1, + error="worker failed", + traceback="traceback text", + ) + + assert manifest_path == chunk_result_manifest_path(tmp_path, 1) + manifest = ChunkExecutionResult.from_dict( + json.loads(manifest_path.read_text(encoding="utf-8")) + ) + assert manifest.status == "failed" + assert manifest.error == "worker failed" + assert manifest.traceback == "traceback text" diff --git a/tests/unit/calibration/test_chunked_matrix_modal.py b/tests/unit/calibration/test_chunked_matrix_modal.py index 641ede502..ec61d4233 100644 --- a/tests/unit/calibration/test_chunked_matrix_modal.py +++ b/tests/unit/calibration/test_chunked_matrix_modal.py @@ -19,6 +19,12 @@ from policyengine_us_data.calibration.chunked_matrix_assembler import ( SharedBuildState, ) +from policyengine_us_data.calibration_package.matrix import ( + CHUNK_EXECUTION_SCHEMA_VERSION, + ChunkBuildRequest, + ChunkExecutionResult, + ChunkWorkerResult, +) from policyengine_us_data.calibration.chunked_matrix_modal import ( _lookup_worker_function, dispatch_chunks_modal, @@ -112,6 +118,7 @@ def _minimal_shared_state( cd_geoid=np.zeros(n_total, dtype="U4"), county_fips=np.zeros(n_total, dtype="U5"), state_fips=np.zeros(n_total, dtype=np.int32), + lineage_signature={"run_id": "run-test", "chunk_size": chunk_size}, ) @@ -162,26 +169,37 @@ def test_dispatch_spawns_per_batch_and_assembles(tmp_path: Path) -> None: # by the time assemble_final() runs, the shard files exist. spawn_calls: List[Dict] = [] - def fake_spawn( - *, run_id: str, chunk_ids: List[int], resume_chunks: bool - ) -> _FakeHandle: + def fake_spawn(*, request: Dict) -> _FakeHandle: + request_obj = ChunkBuildRequest.from_dict(request) spawn_calls.append( { - "run_id": run_id, - "chunk_ids": list(chunk_ids), - "resume_chunks": resume_chunks, + "run_id": request_obj.run_id, + "chunk_ids": list(request_obj.chunk_ids), + "resume_chunks": request_obj.resume_chunks, + "lineage_signature": dict(request_obj.lineage_signature), } ) - for chunk_id in chunk_ids: + results = [] + for chunk_id in request_obj.chunk_ids: col_start = chunk_id * state.chunk_size col_end = col_start + state.chunk_size _write_fake_shard(tmp_path / "coo", chunk_id, col_start, col_end) + results.append( + ChunkExecutionResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=request_obj.run_id, + chunk_id=chunk_id, + status="completed", + nnz=1, + ) + ) return _FakeHandle( - { - "chunk_ids": list(chunk_ids), - "nnz_per_chunk": [1] * len(chunk_ids), - "errors": [], - } + ChunkWorkerResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=request_obj.run_id, + chunk_ids=request_obj.chunk_ids, + chunk_results=tuple(results), + ).to_dict() ) fake_worker = mock.MagicMock() @@ -201,6 +219,7 @@ def fake_spawn( assert [c["chunk_ids"] for c in spawn_calls] == [[0, 1], [2, 3]] # Every spawn carried the run_id. assert all(c["run_id"] == "run-test" for c in spawn_calls) + assert all(c["lineage_signature"] == state.lineage_signature for c in spawn_calls) # Fresh dispatch must not let workers trust stale shards. assert all(c["resume_chunks"] is False for c in spawn_calls) # Final CSR covers all 4 chunks' nnz. @@ -227,12 +246,28 @@ def fake_volume_from_name(name: str, **kwargs): ) monkeypatch.setenv("US_DATA_PIPELINE_VOLUME_NAME", "pipeline-artifacts-run") - def fake_spawn( - *, run_id: str, chunk_ids: List[int], resume_chunks: bool - ) -> _FakeHandle: - for chunk_id in chunk_ids: + def fake_spawn(*, request: Dict) -> _FakeHandle: + request_obj = ChunkBuildRequest.from_dict(request) + results = [] + for chunk_id in request_obj.chunk_ids: _write_fake_shard(tmp_path / "coo", chunk_id, 0, 5) - return _FakeHandle({"chunk_ids": chunk_ids, "nnz_per_chunk": [1], "errors": []}) + results.append( + ChunkExecutionResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=request_obj.run_id, + chunk_id=chunk_id, + status="completed", + nnz=1, + ) + ) + return _FakeHandle( + ChunkWorkerResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=request_obj.run_id, + chunk_ids=request_obj.chunk_ids, + chunk_results=tuple(results), + ).to_dict() + ) fake_worker = mock.MagicMock() fake_worker.spawn.side_effect = fake_spawn @@ -274,17 +309,23 @@ def test_dispatch_aggregates_worker_errors(tmp_path: Path) -> None: state = _minimal_shared_state(n_records=10, n_clones=2, chunk_size=10) # n_total=20, chunk_size=10 -> 2 chunks, 2 workers. - def fake_spawn( - *, run_id: str, chunk_ids: List[int], resume_chunks: bool - ) -> _FakeHandle: + def fake_spawn(*, request: Dict) -> _FakeHandle: + request_obj = ChunkBuildRequest.from_dict(request) # First worker returns a per-chunk error; second crashes in .get(). - if chunk_ids == [0]: + if request_obj.chunk_ids == (0,): return _FakeHandle( - { - "chunk_ids": chunk_ids, - "nnz_per_chunk": [], - "errors": [{"chunk_id": 0, "error": "boom"}], - } + ChunkWorkerResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=request_obj.run_id, + chunk_ids=request_obj.chunk_ids, + chunk_results=( + ChunkExecutionResult.failure( + run_id=request_obj.run_id, + chunk_id=0, + error="boom", + ), + ), + ).to_dict() ) return _FakeHandle(None, raise_on_get=RuntimeError("worker oom")) @@ -309,12 +350,28 @@ def test_dispatch_writes_shared_state_pickle(tmp_path: Path) -> None: state = _minimal_shared_state(n_records=5, n_clones=1, chunk_size=10) # n_total=5, chunk_size=10 -> 1 chunk, 1 worker. - def fake_spawn( - *, run_id: str, chunk_ids: List[int], resume_chunks: bool - ) -> _FakeHandle: - for chunk_id in chunk_ids: + def fake_spawn(*, request: Dict) -> _FakeHandle: + request_obj = ChunkBuildRequest.from_dict(request) + results = [] + for chunk_id in request_obj.chunk_ids: _write_fake_shard(tmp_path / "coo", chunk_id, 0, 5) - return _FakeHandle({"chunk_ids": chunk_ids, "nnz_per_chunk": [1], "errors": []}) + results.append( + ChunkExecutionResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=request_obj.run_id, + chunk_id=chunk_id, + status="completed", + nnz=1, + ) + ) + return _FakeHandle( + ChunkWorkerResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=request_obj.run_id, + chunk_ids=request_obj.chunk_ids, + chunk_results=tuple(results), + ).to_dict() + ) fake_worker = mock.MagicMock() fake_worker.spawn.side_effect = fake_spawn @@ -342,19 +399,35 @@ def test_dispatch_forwards_resume_chunks_to_workers(tmp_path: Path) -> None: state = _minimal_shared_state(n_records=5, n_clones=1, chunk_size=10) spawn_calls: List[Dict] = [] - def fake_spawn( - *, run_id: str, chunk_ids: List[int], resume_chunks: bool - ) -> _FakeHandle: + def fake_spawn(*, request: Dict) -> _FakeHandle: + request_obj = ChunkBuildRequest.from_dict(request) spawn_calls.append( { - "run_id": run_id, - "chunk_ids": list(chunk_ids), - "resume_chunks": resume_chunks, + "run_id": request_obj.run_id, + "chunk_ids": list(request_obj.chunk_ids), + "resume_chunks": request_obj.resume_chunks, } ) - for chunk_id in chunk_ids: + results = [] + for chunk_id in request_obj.chunk_ids: _write_fake_shard(tmp_path / "coo", chunk_id, 0, 5) - return _FakeHandle({"chunk_ids": chunk_ids, "nnz_per_chunk": [1], "errors": []}) + results.append( + ChunkExecutionResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=request_obj.run_id, + chunk_id=chunk_id, + status="completed", + nnz=1, + ) + ) + return _FakeHandle( + ChunkWorkerResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id=request_obj.run_id, + chunk_ids=request_obj.chunk_ids, + chunk_results=tuple(results), + ).to_dict() + ) fake_worker = mock.MagicMock() fake_worker.spawn.side_effect = fake_spawn diff --git a/tests/unit/calibration_package/test_matrix.py b/tests/unit/calibration_package/test_matrix.py index d8ea5397f..5b0e7aa92 100644 --- a/tests/unit/calibration_package/test_matrix.py +++ b/tests/unit/calibration_package/test_matrix.py @@ -5,9 +5,15 @@ from scipy import sparse from policyengine_us_data.calibration_package.matrix import ( + CHUNK_EXECUTION_SCHEMA_VERSION, + ChunkBuildRequest, + ChunkCacheManifest, + ChunkExecutionResult, + ChunkWorkerResult, MatrixBuildResult, MatrixBuildService, MatrixBuildSpec, + write_chunk_result_manifest, ) from policyengine_us_data.stage_contracts.calibration_package_schema import ( MatrixBuildSummary, @@ -196,3 +202,128 @@ def test_matrix_build_service_normalizes_standard_and_chunked_outputs(tmp_path): assert standard.summary().matrix_builder == "precompute" assert chunked.summary().matrix_builder == "chunked" assert chunked.summary().chunk_shard_count == 1 + + +def test_chunk_cache_manifest_round_trips_and_rejects_lineage_mismatch(tmp_path): + signature = { + "format_version": 2, + "run_id": "run-a", + "matrix_builder": "chunked", + "dataset_sha256": "dataset-a", + "db_sha256": "db-a", + "target_names_sha256": "targets-a", + "targets_sha256": "target-frame-a", + "state_fips_sha256": "states-a", + "county_fips_sha256": "counties-a", + "cd_geoid_sha256": "districts-a", + "block_geoid_sha256": "blocks-a", + "chunk_size": 10, + } + manifest_path = ChunkCacheManifest.from_signature(signature).write( + tmp_path / "chunk_manifest.json" + ) + + restored = ChunkCacheManifest.read(manifest_path) + restored.validate_lineage(signature) + + for key in ( + "run_id", + "dataset_sha256", + "db_sha256", + "target_names_sha256", + "targets_sha256", + "state_fips_sha256", + "county_fips_sha256", + "cd_geoid_sha256", + "block_geoid_sha256", + "chunk_size", + ): + expected = dict(signature) + expected[key] = "different" if key != "chunk_size" else 25 + with pytest.raises(ValueError, match=key): + restored.validate_lineage(expected) + + +def test_chunk_build_request_round_trips(): + request = ChunkBuildRequest( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id="run-a", + chunk_ids=(0, 2), + chunk_root="/pipeline/artifacts/run-a/matrix_build", + state_path="/pipeline/artifacts/run-a/matrix_build/chunk_build_state.pkl", + resume_chunks=True, + lineage_signature={"run_id": "run-a", "chunk_size": 10}, + ) + + restored = ChunkBuildRequest.from_dict(request.to_dict()) + + assert restored == request + assert restored.to_dict()["chunk_ids"] == [0, 2] + + +def test_chunk_execution_result_manifest_round_trips(tmp_path): + result = ChunkExecutionResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id="run-a", + chunk_id=3, + status="completed", + nnz=12, + n_households=5, + n_persons=9, + unique_states=2, + ) + + manifest_path = write_chunk_result_manifest(tmp_path, result) + restored = ChunkExecutionResult.from_dict( + json.loads(manifest_path.read_text(encoding="utf-8")) + ) + + assert restored == result + + +def test_chunk_worker_result_round_trips_errors(): + worker_result = ChunkWorkerResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id="run-a", + chunk_ids=(0, 1), + chunk_results=( + ChunkExecutionResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id="run-a", + chunk_id=0, + status="completed", + nnz=3, + ), + ChunkExecutionResult.failure( + run_id="run-a", + chunk_id=1, + error="boom", + traceback="traceback", + ), + ), + ) + + restored = ChunkWorkerResult.from_dict(worker_result.to_dict()) + + assert restored == worker_result + assert restored.completed_count == 1 + assert len(restored.errors) == 1 + assert restored.to_dict()["errors"][0]["error"] == "boom" + + +def test_chunk_worker_result_rejects_missing_chunk_result(): + with pytest.raises(ValueError, match="one result per requested chunk"): + ChunkWorkerResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id="run-a", + chunk_ids=(0, 1), + chunk_results=( + ChunkExecutionResult( + schema_version=CHUNK_EXECUTION_SCHEMA_VERSION, + run_id="run-a", + chunk_id=0, + status="completed", + nnz=3, + ), + ), + ) diff --git a/tests/unit/test_pipeline_docs_extractor.py b/tests/unit/test_pipeline_docs_extractor.py index 221dce796..19aa75ad4 100644 --- a/tests/unit/test_pipeline_docs_extractor.py +++ b/tests/unit/test_pipeline_docs_extractor.py @@ -140,6 +140,8 @@ def test_pipeline_map_manifest_validates(): "stage2_geography_assignment_result", "stage2_matrix_build_spec", "stage2_matrix_build_service", + "stage2_chunk_build_request", + "stage2_chunk_worker_result", "stage2_matrix_build_result", "build_matrix", "build_matrix_chunked",