Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/1053.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add epsilon-insensitive calibration target tolerances, target-policy artifacts, and hard-fail versus warning enforcement for calibration diagnostics.
16 changes: 15 additions & 1 deletion modal_app/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def _calibration_package_parameters(
workers: int,
n_clones: int,
target_config: str | None,
target_policy: str | None,
skip_county: bool,
chunked_matrix: bool,
chunk_size: int,
Expand All @@ -174,6 +175,7 @@ def _calibration_package_parameters(
"workers": workers if not chunked_matrix else None,
"n_clones": n_clones,
"target_config": target_config,
"target_policy": target_policy,
"skip_county": skip_county,
"chunked_matrix": bool(chunked_matrix),
"chunk_size": chunk_size if chunked_matrix else None,
Expand Down Expand Up @@ -281,6 +283,8 @@ def archive_diagnostics(
"log": f"{prefix}unified_diagnostics.csv",
"cal_log": f"{prefix}calibration_log.csv",
"config": f"{prefix}unified_run_config.json",
"target_policy": f"{prefix}calibration_target_policy.jsonl",
"target_policy_summary": (f"{prefix}calibration_target_policy_summary.json"),
}

for key, filename in file_map.items():
Expand Down Expand Up @@ -1242,6 +1246,7 @@ def run_pipeline(
workers=num_workers,
n_clones=n_clones,
target_config=None,
target_policy="policyengine_us_data/calibration/target_policy.yaml",
skip_county=True,
chunked_matrix=chunked_matrix,
chunk_size=chunk_size,
Expand Down Expand Up @@ -1302,7 +1307,12 @@ def run_pipeline(
completed_package_manifest = _complete_step_manifest(
active_step_manifest,
outputs=collect_artifacts(
[_artifacts_dir(run_id) / "calibration_package.pkl"],
[
_artifacts_dir(run_id) / "calibration_package.pkl",
_artifacts_dir(run_id) / "calibration_target_policy.jsonl",
_artifacts_dir(run_id)
/ "calibration_target_policy_summary.json",
],
missing_ok=True,
),
vol=pipeline_volume,
Expand All @@ -1321,19 +1331,23 @@ def run_pipeline(
"gpu": gpu,
"epochs": epochs,
"target_config": "policyengine_us_data/calibration/target_config.yaml",
"target_policy": "policyengine_us_data/calibration/target_policy.yaml",
"beta": 0.65,
"lambda_l0": 1e-7,
"lambda_l2": 1e-8,
"log_freq": 100,
"loss_type": "relative_epsilon",
}
national_fit_parameters = {
"gpu": national_gpu,
"epochs": national_epochs,
"target_config": "policyengine_us_data/calibration/target_config.yaml",
"target_policy": "policyengine_us_data/calibration/target_policy.yaml",
"beta": 0.65,
"lambda_l0": NATIONAL_FIT_LAMBDA_L0,
"lambda_l2": 1e-12,
"log_freq": 100,
"loss_type": "relative_epsilon",
"skip_national": skip_national,
}
regional_fit_reuse = _step_reusable(
Expand Down
20 changes: 20 additions & 0 deletions modal_app/remote_calibration_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def _collect_outputs(cal_lines):
log_path = None
cal_log_path = None
config_path = None
target_policy_path = None
target_policy_summary_path = None
for line in cal_lines:
if "OUTPUT_PATH:" in line:
output_path = line.split("OUTPUT_PATH:")[1].strip()
Expand All @@ -110,6 +112,12 @@ def _collect_outputs(cal_lines):
cal_log_path = line.split("CAL_LOG_PATH:")[1].strip()
elif "LOG_PATH:" in line:
log_path = line.split("LOG_PATH:")[1].strip()
elif "TARGET_POLICY_PATH:" in line:
target_policy_path = line.split("TARGET_POLICY_PATH:")[1].strip()
elif "TARGET_POLICY_SUMMARY_PATH:" in line:
target_policy_summary_path = line.split("TARGET_POLICY_SUMMARY_PATH:")[
1
].strip()

with open(output_path, "rb") as f:
weights_bytes = f.read()
Expand All @@ -134,12 +142,24 @@ def _collect_outputs(cal_lines):
with open(config_path, "rb") as f:
config_bytes = f.read()

target_policy_bytes = None
if target_policy_path:
with open(target_policy_path, "rb") as f:
target_policy_bytes = f.read()

target_policy_summary_bytes = None
if target_policy_summary_path:
with open(target_policy_summary_path, "rb") as f:
target_policy_summary_bytes = f.read()

return {
"weights": weights_bytes,
"geography": geography_bytes,
"log": log_bytes,
"cal_log": cal_log_bytes,
"config": config_bytes,
"target_policy": target_policy_bytes,
"target_policy_summary": target_policy_summary_bytes,
}


Expand Down
2 changes: 1 addition & 1 deletion modal_app/step_manifests/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def artifact_identities(paths: dict[str, str | Path]) -> dict:
def collect_diagnostics(run_id: str) -> list[ArtifactReference]:
return collect_directory_artifacts(
run_dir(run_id) / "diagnostics",
patterns=("*.csv", "*.json", "*.txt"),
patterns=("*.csv", "*.json", "*.jsonl", "*.txt"),
role="diagnostic",
)

Expand Down
17 changes: 17 additions & 0 deletions policyengine_us_data/calibration/signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def build_checkpoint_signature(
lambda_l2: float,
learning_rate: float,
target_groups: np.ndarray | None = None,
target_weights: np.ndarray | None = None,
target_tolerances: np.ndarray | None = None,
target_scales: np.ndarray | None = None,
calibration_loss_type: str = "relative",
) -> dict:
"""Build a compact signature to validate calibration checkpoint resume."""
targets_arr = np.asarray(targets, dtype=np.float64)
Expand All @@ -116,20 +120,33 @@ def build_checkpoint_signature(
if target_groups is None
else np.asarray(target_groups, dtype=np.int64)
)
target_weights_arr = _optional_float_signature_array(target_weights)
target_tolerances_arr = _optional_float_signature_array(target_tolerances)
target_scales_arr = _optional_float_signature_array(target_scales)
return {
"n_features": int(X_sparse.shape[1]),
"n_targets": int(len(targets_arr)),
"x_sparse_sha256": hash_sparse_matrix(X_sparse),
"target_names_sha256": hash_string_list(target_names),
"targets_sha256": hashlib.sha256(targets_arr.tobytes()).hexdigest(),
"target_groups_sha256": hash_numpy_array(target_groups_arr),
"target_weights_sha256": hash_numpy_array(target_weights_arr),
"target_tolerances_sha256": hash_numpy_array(target_tolerances_arr),
"target_scales_sha256": hash_numpy_array(target_scales_arr),
"calibration_loss_type": str(calibration_loss_type),
"lambda_l0": float(lambda_l0),
"beta": float(beta),
"lambda_l2": float(lambda_l2),
"learning_rate": float(learning_rate),
}


def _optional_float_signature_array(values: np.ndarray | None) -> np.ndarray:
if values is None:
return np.array([], dtype=np.float64)
return np.asarray(values, dtype=np.float64)


def checkpoint_signature_mismatches(
expected: dict,
actual: dict,
Expand Down
Loading
Loading