Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Setup uv
uses: astral-sh/setup-uv@v5
- name: Check for changelog fragment
run: |
FRAGMENTS=$(find changelog.d -type f ! -name '.gitkeep' | wc -l)
if [ "$FRAGMENTS" -eq 0 ]; then
echo "::error::No changelog fragment found in changelog.d/"
echo "Add one with: echo 'Description.' > changelog.d/\$(git branch --show-current).<type>.md"
echo "Types: added, changed, fixed, removed, breaking"
exit 1
fi
run: uv run --with "towncrier>=24.8.0" towncrier check --compare-with origin/master
test_container_builds:
name: Docker
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions changelog.d/3394.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Attach PolicyEngine bundle metadata to economy results.
1 change: 0 additions & 1 deletion changelog.d/fixed/3394.md

This file was deleted.

6 changes: 6 additions & 0 deletions policyengine_api/libs/simulation_api_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ModalSimulationExecution:

job_id: str
status: str
run_id: Optional[str] = None
result: Optional[dict] = None
error: Optional[str] = None
policyengine_bundle: Optional[dict] = None
Expand Down Expand Up @@ -88,6 +89,7 @@ def run(self, payload: dict) -> ModalSimulationExecution:
{
"message": "Modal simulation job submitted",
"job_id": data.get("job_id"),
"run_id": data.get("run_id"),
"status": data.get("status"),
},
severity="INFO",
Expand All @@ -98,12 +100,14 @@ def run(self, payload: dict) -> ModalSimulationExecution:
status=data["status"],
policyengine_bundle=data.get("policyengine_bundle"),
resolved_app_name=data.get("resolved_app_name"),
run_id=data.get("run_id"),
)

except httpx.HTTPStatusError as e:
logger.log_struct(
{
"message": f"Modal API HTTP error: {e.response.status_code}",
"run_id": (payload.get("_telemetry") or {}).get("run_id"),
"response_text": e.response.text[:500],
},
severity="ERROR",
Expand All @@ -114,6 +118,7 @@ def run(self, payload: dict) -> ModalSimulationExecution:
logger.log_struct(
{
"message": f"Modal API request error: {str(e)}",
"run_id": (payload.get("_telemetry") or {}).get("run_id"),
},
severity="ERROR",
)
Expand Down Expand Up @@ -174,6 +179,7 @@ def get_execution_by_id(self, job_id: str) -> ModalSimulationExecution:
return ModalSimulationExecution(
job_id=job_id,
status=data["status"],
run_id=data.get("run_id"),
result=data.get("result"),
error=data.get("error"),
policyengine_bundle=data.get("policyengine_bundle"),
Expand Down
86 changes: 80 additions & 6 deletions policyengine_api/services/economy_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from policyengine.utils.data.datasets import get_default_dataset
import json
import datetime
import hashlib
import uuid
from typing import Literal, Any, Optional, Annotated
from dotenv import load_dotenv
from pydantic import BaseModel
Expand Down Expand Up @@ -357,7 +359,6 @@ def _determine_impact_action(
self,
most_recent_impact: dict | None,
) -> ImpactAction:

if not most_recent_impact:
return ImpactAction.CREATE

Expand Down Expand Up @@ -448,7 +449,6 @@ def _handle_computing_impact(
setup_options: EconomicImpactSetupOptions,
most_recent_impact: dict,
) -> EconomicImpactResult:

execution = simulation_api.get_execution_by_id(
most_recent_impact["execution_id"]
)
Expand Down Expand Up @@ -484,17 +484,21 @@ def _handle_create_impact(
data_version=setup_options.data_version,
)

sim_params = sim_config.model_dump(mode="json")
telemetry = self._build_simulation_telemetry(
setup_options=setup_options,
sim_config=sim_params,
)

logger.log_struct(
{
"message": "Setting up sim API job",
"run_id": telemetry["run_id"],
**setup_options.model_dump(),
}
)

# Build params with metadata for Logfire tracing in the simulation API.
# The _metadata field will be captured by the Logfire span before
# SimulationOptions validation (which silently ignores extra fields).
sim_params = sim_config.model_dump()
# Preserve both legacy metadata and the new telemetry envelope.
sim_params["_metadata"] = {
"reform_policy_id": setup_options.reform_policy_id,
"baseline_policy_id": setup_options.baseline_policy_id,
Expand All @@ -505,14 +509,17 @@ def _handle_create_impact(
"dataset": setup_options.dataset,
"resolved_app_name": setup_options.runtime_app_name,
}
sim_params["_telemetry"] = telemetry

sim_api_execution = simulation_api.run(sim_params)
execution_id = simulation_api.get_execution_id(sim_api_execution)
run_id = getattr(sim_api_execution, "run_id", None) or telemetry["run_id"]

progress_log = {
**setup_options.model_dump(),
"message": "Sim API job started",
"execution_id": execution_id,
"run_id": run_id,
}
logger.log_struct(progress_log, severity="INFO")

Expand Down Expand Up @@ -759,6 +766,73 @@ def _setup_data(
)
raise

def _build_simulation_telemetry(
    self,
    setup_options: EconomicImpactSetupOptions,
    sim_config: dict[str, Any],
) -> dict[str, Any]:
    """Assemble the telemetry envelope sent alongside a simulation request.

    Args:
        setup_options: Request options; ``country_id``, ``region`` and
            ``process_id`` are read from it.
        sim_config: Simulation configuration dict, used only to derive
            ``config_hash``.

    Returns:
        A dict with a freshly generated ``run_id`` (UUID4 string), the
        caller's ``process_id``, the current ``traceparent`` (or ``None``),
        a UTC ISO-8601 ``requested_at`` timestamp, the geography
        classification, the config hash, and ``capture_mode`` fixed to
        ``"disabled"``.
    """
    kind, geo_type, geo_code = self._classify_simulation_geography(
        country_id=setup_options.country_id,
        region=setup_options.region,
    )

    # Values are computed in the same order the envelope lists them.
    run_id = str(uuid.uuid4())
    process_id = setup_options.process_id
    traceparent = self._get_current_traceparent()
    requested_at = datetime.datetime.now(datetime.UTC).isoformat()
    config_hash = self._stable_config_hash(sim_config)

    return dict(
        run_id=run_id,
        process_id=process_id,
        traceparent=traceparent,
        requested_at=requested_at,
        simulation_kind=kind,
        geography_code=geo_code,
        geography_type=geo_type,
        config_hash=config_hash,
        capture_mode="disabled",
    )

def _classify_simulation_geography(
    self,
    country_id: str,
    region: str,
) -> tuple[str, str, str]:
    """Derive ``(simulation_kind, geography_type, geography_code)`` from a region.

    Rules, in order:
      * ``region == country_id``  -> ``("national", "national", country_id)``
      * no ``"/"`` in ``region``  -> ``("other", "other", region)``
      * ``"<type>/<code>"``       -> ``(<kind>, <type>, <code>)`` where the
        kind is ``"district"`` when the type is ``"congressional_district"``
        and the type itself otherwise. Only the first ``"/"`` splits.
    """
    if region == country_id:
        return "national", "national", country_id

    if "/" not in region:
        return "other", "other", region

    prefix, _, code = region.partition("/")
    if prefix == "congressional_district":
        kind = "district"
    else:
        kind = prefix
    return kind, prefix, code

def _stable_config_hash(self, payload: dict[str, Any]) -> str:
    """Return a ``sha256:``-prefixed hex digest of ``payload``.

    The payload is serialised to compact JSON with sorted keys so that
    key order does not affect the hash; values that are not JSON-native
    are passed through ``str``.
    """
    canonical_json = json.dumps(
        payload,
        sort_keys=True,
        separators=(",", ":"),
        default=str,
    )
    digest = hashlib.sha256(canonical_json.encode("utf-8")).hexdigest()
    return "sha256:" + digest

def _get_current_traceparent(self) -> str | None:
try:
from opentelemetry import trace
except Exception:
return None

span = trace.get_current_span()
span_context = span.get_span_context()
if not getattr(span_context, "is_valid", False):
return None

trace_flags = int(getattr(span_context, "trace_flags", 0))
return (
f"00-{span_context.trace_id:032x}-"
f"{span_context.span_id:016x}-{trace_flags:02x}"
)

# Note: The following methods that interface with the ReformImpactsService
# are written separately because the service relies upon mutating an original
# 'computing' record to 'ok' or 'error' status, rather than creating a new record.
Expand Down
11 changes: 11 additions & 0 deletions tests/fixtures/libs/simulation_api_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

# Mock data constants
MOCK_MODAL_JOB_ID = "fc-abc123xyz"
MOCK_RUN_ID = "run-abc123xyz"
MOCK_MODAL_BASE_URL = "https://test-modal-api.modal.run"

MOCK_SIMULATION_PAYLOAD = {
Expand All @@ -31,6 +32,15 @@
"include_cliffs": False,
}

# Base simulation payload plus a `_telemetry` envelope.
MOCK_SIMULATION_PAYLOAD_WITH_TELEMETRY = dict(
    MOCK_SIMULATION_PAYLOAD,
    _telemetry={
        "run_id": MOCK_RUN_ID,
        "process_id": "job_20250626120000_1234",
        "capture_mode": "disabled",
    },
)

MOCK_SIMULATION_RESULT = {
"poverty_impact": {"baseline": 0.12, "reform": 0.10},
"budget_impact": {"baseline": 1000, "reform": 1200},
Expand All @@ -46,6 +56,7 @@

MOCK_SUBMIT_RESPONSE_SUCCESS = {
"job_id": MOCK_MODAL_JOB_ID,
"run_id": MOCK_RUN_ID,
"status": MODAL_EXECUTION_STATUS_SUBMITTED,
"poll_url": f"/jobs/{MOCK_MODAL_JOB_ID}",
"country": "us",
Expand Down
2 changes: 2 additions & 0 deletions tests/fixtures/services/economy_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
MOCK_MODAL_JOB_ID = "fc-test123xyz"
MOCK_EXECUTION_ID = MOCK_MODAL_JOB_ID # Alias for test compatibility
MOCK_RUN_ID = "run-test123xyz"
MOCK_PROCESS_ID = "job_20250626120000_1234"
MOCK_MODEL_VERSION = "1.2.3"
MOCK_POLICYENGINE_VERSION = "3.4.0"
Expand Down Expand Up @@ -248,6 +249,7 @@ def create_mock_modal_execution(
"""
mock_execution = MagicMock()
mock_execution.job_id = job_id
mock_execution.run_id = MOCK_RUN_ID
mock_execution.name = job_id # Alias for compatibility
mock_execution.status = status
mock_execution.result = result
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/libs/test_simulation_api_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
MOCK_MODAL_JOB_ID,
MOCK_MODAL_BASE_URL,
MOCK_SIMULATION_PAYLOAD,
MOCK_SIMULATION_PAYLOAD_WITH_TELEMETRY,
MOCK_RUN_ID,
MOCK_SIMULATION_RESULT,
MOCK_POLICYENGINE_BUNDLE,
MOCK_RESOLVED_APP_NAME,
Expand Down Expand Up @@ -136,6 +138,7 @@ def test__given_valid_payload__then_returns_execution_with_job_id(

# Then
assert execution.job_id == MOCK_MODAL_JOB_ID
assert execution.run_id == MOCK_RUN_ID
assert execution.status == MODAL_EXECUTION_STATUS_SUBMITTED
assert execution.policyengine_bundle == MOCK_POLICYENGINE_BUNDLE
assert execution.resolved_app_name == MOCK_RESOLVED_APP_NAME
Expand All @@ -161,6 +164,22 @@ def test__given_valid_payload__then_posts_to_correct_endpoint(
assert "/simulate/economy/comparison" in call_args[0][0]
assert call_args[1]["json"] == MOCK_SIMULATION_PAYLOAD

def test__given_telemetry_payload__then_preserves_it_in_post_body(
    self,
    mock_httpx_client,
    mock_modal_logger,
):
    """The `_telemetry` envelope must reach the POST body unchanged."""
    # Given
    submit_response = create_mock_httpx_response(
        status_code=202,
        json_data=MOCK_SUBMIT_RESPONSE_SUCCESS,
    )
    mock_httpx_client.post.return_value = submit_response
    client = SimulationAPIModal()

    # When
    client.run(MOCK_SIMULATION_PAYLOAD_WITH_TELEMETRY)

    # Then
    post_kwargs = mock_httpx_client.post.call_args[1]
    posted_telemetry = post_kwargs["json"]["_telemetry"]
    assert posted_telemetry["run_id"] == MOCK_RUN_ID

def test__given_http_error__then_raises_exception(
self,
mock_httpx_client,
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/services/test_economy_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
MOCK_OPTIONS_HASH,
MOCK_EXECUTION_ID,
MOCK_PROCESS_ID,
MOCK_RUN_ID,
MOCK_REFORM_IMPACT_DATA,
MOCK_RESOLVED_DATASET,
MOCK_RESOLVED_APP_NAME,
Expand Down Expand Up @@ -276,6 +277,38 @@ def test__given_no_previous_impact__includes_metadata_in_simulation_params(
sim_params["_metadata"]["resolved_app_name"] == MOCK_RESOLVED_APP_NAME
)

def test__given_no_previous_impact__includes_telemetry_in_simulation_params(
    self,
    economy_service,
    base_params,
    mock_country_package_versions,
    mock_get_dataset_version,
    mock_policy_service,
    mock_reform_impacts_service,
    mock_simulation_api,
    mock_logger,
    mock_datetime,
    mock_numpy_random,
):
    """A newly created impact sends a `_telemetry` envelope and logs the run id."""
    # Given
    mock_reform_impacts_service.get_all_reform_impacts.return_value = []

    # When
    economy_service.get_economic_impact(**base_params)

    # Then: the payload handed to the simulation API carries telemetry.
    submitted_params = mock_simulation_api.run.call_args[0][0]
    telemetry = submitted_params["_telemetry"]

    assert telemetry["run_id"]
    assert telemetry["process_id"] == MOCK_PROCESS_ID
    assert telemetry["simulation_kind"] == "national"
    assert telemetry["geography_type"] == "national"
    assert telemetry["geography_code"] == MOCK_COUNTRY_ID
    assert telemetry["capture_mode"] == "disabled"
    assert telemetry["config_hash"].startswith("sha256:")

    # Then: the last structured log entry reports the run id at INFO.
    final_log_call = mock_logger.log_struct.call_args_list[-1]
    assert final_log_call.args[0]["run_id"] == MOCK_RUN_ID
    assert final_log_call.kwargs["severity"] == "INFO"

def test__given_runtime_cache_version__uses_versioned_economy_cache_key(
self,
economy_service,
Expand Down
Loading