Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/fixed/3394.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Record the resolved PolicyEngine bundle metadata reported by the runtime that actually executed each society-wide simulation, and base reproduce/cache behavior on that resolved dataset bundle rather than on caller-side defaults.
156 changes: 103 additions & 53 deletions policyengine_api/country.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import importlib
from flask import Response
import inspect
import json
from policyengine_core.taxbenefitsystems import TaxBenefitSystem
from typing import Union, Optional
Expand All @@ -22,14 +22,6 @@
build_congressional_district_metadata,
)

# Note: The following policyengine_[xx] imports are probably redundant.
# These modules are imported dynamically in the __init__ function below.
import policyengine_uk
import policyengine_us
import policyengine_canada
import policyengine_ng
import policyengine_il

from policyengine_api.data import local_database
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS

Expand All @@ -45,24 +37,40 @@ def __init__(self, country_package_name: str, country_id: str):
self.build_metadata()

def build_metadata(self):
self.metadata = dict(
variables=self.build_variables(),
parameters=self.build_parameters(),
entities=self.build_entities(),
variableModules=self.tax_benefit_system.variable_module_metadata,
economy_options=self.build_microsimulation_options(),
current_law_id={
"uk": 1,
"us": 2,
"ca": 3,
"ng": 4,
"il": 5,
}[self.country_id],
basicInputs=self.tax_benefit_system.basic_inputs,
modelled_policies=self.tax_benefit_system.modelled_policies,
version=get_package_version(self.country_package_name.replace("_", "-")),
self.metadata = self._json_safe(
dict(
variables=self.build_variables(),
parameters=self.build_parameters(),
entities=self.build_entities(),
variableModules=self.tax_benefit_system.variable_module_metadata,
economy_options=self.build_microsimulation_options(),
current_law_id={
"uk": 1,
"us": 2,
"ca": 3,
"ng": 4,
"il": 5,
}[self.country_id],
basicInputs=self.tax_benefit_system.basic_inputs,
modelled_policies=self.tax_benefit_system.modelled_policies,
version=get_package_version(
self.country_package_name.replace("_", "-")
),
)
)

def _json_safe(self, value):
if isinstance(value, Path):
return str(value)
if isinstance(value, dict):
return {
key: self._json_safe(nested_value)
for key, nested_value in value.items()
}
if isinstance(value, list):
return [self._json_safe(nested_value) for nested_value in value]
return value

def build_microsimulation_options(self) -> dict:
# { region: [{ name: "uk", label: "the UK" }], time_period: [{ name: 2022, label: "2022", ... }] }
options = dict()
Expand Down Expand Up @@ -363,31 +371,7 @@ def calculate(
household_id: Optional[int] = None,
policy_id: Optional[int] = None,
):
if reform is not None and len(reform.keys()) > 0:
system = self.tax_benefit_system.clone()
for parameter_name in reform:
for time_period, value in reform[parameter_name].items():
start_instant, end_instant = time_period.split(".")
parameter = get_parameter(system.parameters, parameter_name)
node_type = type(parameter.values_list[-1].value)
if node_type == int:
node_type = float
try:
value = float(value)
except:
pass
parameter.update(
start=instant(start_instant),
stop=instant(end_instant),
value=node_type(value),
)
else:
system = self.tax_benefit_system

simulation = self.country_package.Simulation(
tax_benefit_system=system,
situation=household,
)
simulation, system = self._create_simulation(household, reform)

household = json.loads(json.dumps(household))

Expand Down Expand Up @@ -429,14 +413,14 @@ def calculate(
entity_index = population.get_index(entity_id)
if variable.value_type == Enum:
entity_result = result.decode()[entity_index].name
elif variable.value_type == float:
elif variable.value_type is float:
entity_result = float(str(result[entity_index]))
# Convert infinities to JSON infinities
if entity_result == float("inf"):
entity_result = "Infinity"
elif entity_result == float("-inf"):
entity_result = "-Infinity"
elif variable.value_type == str:
elif variable.value_type is str:
entity_result = str(result[entity_index])
else:
entity_result = result.tolist()[entity_index]
Expand Down Expand Up @@ -473,6 +457,72 @@ def calculate(

return household

def _create_simulation(
self,
household: dict,
reform: Union[dict, None],
):
normalized_reform = None
if reform:
system = self.tax_benefit_system.clone()
normalized_reform = self._normalize_reform_values(reform, system)
else:
system = self.tax_benefit_system

if self._simulation_accepts_tax_benefit_system():
if normalized_reform:
self._apply_reform_to_system(system, normalized_reform)
simulation = self.country_package.Simulation(
tax_benefit_system=system,
situation=household,
)
return simulation, system

simulation_kwargs = {"situation": household}
if normalized_reform:
simulation_kwargs["reform"] = normalized_reform
simulation = self.country_package.Simulation(**simulation_kwargs)
return simulation, simulation.tax_benefit_system

def _simulation_accepts_tax_benefit_system(self) -> bool:
simulation_signature = inspect.signature(self.country_package.Simulation)
return "tax_benefit_system" in simulation_signature.parameters

def _normalize_reform_values(
    self,
    reform: dict,
    system: TaxBenefitSystem,
) -> dict:
    """
    Cast every reform value to the type of the parameter it overrides.

    For each parameter named in ``reform`` the type of the parameter's
    latest listed value (``values_list[-1].value``) is used as the target
    type, with two adjustments:

    * ``int`` is widened to ``float``;
    * a latest value of ``None`` is also treated as ``float``, because
      ``type(None)`` cannot be called with an argument and would raise
      ``TypeError``.

    Args:
        reform: Mapping of parameter name to ``{time_period: value}``.
        system: Tax-benefit system whose parameters supply the target
            types; it is only read, not modified.

    Returns:
        A new dict with the same keys as ``reform`` and each value cast to
        the target type.
    """
    normalized_reform = {}
    for parameter_name, parameter_updates in reform.items():
        parameter = get_parameter(system.parameters, parameter_name)
        # The target type depends only on the parameter, so resolve it once
        # per parameter rather than once per time period.
        current_value = parameter.values_list[-1].value
        node_type = type(current_value)
        if node_type is int or current_value is None:
            node_type = float
        normalized_reform[parameter_name] = {
            time_period: node_type(value)
            for time_period, value in parameter_updates.items()
        }
    return normalized_reform

def _apply_reform_to_system(
    self,
    system: TaxBenefitSystem,
    reform: dict,
) -> None:
    """
    Write every update in ``reform`` into the parameters of ``system``.

    Args:
        system: The tax-benefit system whose parameters are updated in
            place.
        reform: Mapping of parameter name to ``{time_period: value}``,
            where each ``time_period`` is ``"<start>.<stop>"``.
    """
    for parameter_name in reform:
        target = get_parameter(system.parameters, parameter_name)
        for window, new_value in reform[parameter_name].items():
            # A window is two instants joined by a single dot.
            window_start, window_stop = window.split(".")
            target.update(
                start=instant(window_start),
                stop=instant(window_stop),
                value=new_value,
            )


def create_policy_reform(policy_data: dict) -> dict:
"""
Expand All @@ -498,7 +548,7 @@ def modify_parameters(parameters: ParameterNode) -> ParameterNode:
for period, value in values.items():
start, end = period.split(".")
node_type = type(node.values_list[-1].value)
if node_type == int:
if node_type is int:
node_type = float # '0' is of type int by default, but usually we want to cast to float.
if node.values_list[-1].value is None:
node_type = float
Expand Down
37 changes: 7 additions & 30 deletions policyengine_api/data/model_setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
ENHANCED_FRS = "hf://policyengine/policyengine-uk-data/enhanced_frs_2023_24.h5"
FRS = "hf://policyengine/policyengine-uk-data/frs_2023_24.h5"
ENHANCED_FRS = (
"hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3"
)
FRS = "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.40.3"

ENHANCED_CPS = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5"
CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5"
POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5"
ENHANCED_CPS = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0"
CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5@1.77.0"
POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.77.0"

datasets = {
"uk": {
Expand All @@ -16,28 +18,3 @@
"pooled_cps": POOLED_CPS,
},
}


def get_dataset_version(country_id: str) -> str | None:
"""
Get the dataset version for the specified country. If PolicyEngine does not
publish data for the country, raise a ValueError.

By returning None for all valid countries, we allow policyengine.py to use
whatever default dataset version it has available, without imposing version
validation constraints from the API layer.
"""
match country_id:
case "uk":
return None
case "us":
return None
case _:
raise ValueError(f"Unknown country ID: {country_id}")


for dataset in datasets["uk"]:
datasets["uk"][dataset] = f"{datasets['uk'][dataset]}@{get_dataset_version('uk')}"

for dataset in datasets["us"]:
datasets["us"][dataset] = f"{datasets['us'][dataset]}@{get_dataset_version('us')}"
22 changes: 22 additions & 0 deletions policyengine_api/libs/simulation_api_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class ModalSimulationExecution:
status: str
result: Optional[dict] = None
error: Optional[str] = None
policyengine_bundle: Optional[dict] = None
resolved_app_name: Optional[str] = None

@property
def name(self) -> str:
Expand Down Expand Up @@ -94,6 +96,8 @@ def run(self, payload: dict) -> ModalSimulationExecution:
return ModalSimulationExecution(
job_id=data["job_id"],
status=data["status"],
policyengine_bundle=data.get("policyengine_bundle"),
resolved_app_name=data.get("resolved_app_name"),
)

except httpx.HTTPStatusError as e:
Expand All @@ -115,6 +119,22 @@ def run(self, payload: dict) -> ModalSimulationExecution:
)
raise

def resolve_app_name(
    self, country: str, version: Optional[str] = None
) -> tuple[str, str]:
    """
    Look up the gateway app name for a country and model version.

    Fetches ``{base_url}/versions/{country}`` and reads the app name for
    ``version`` from the returned mapping. When ``version`` is empty or
    ``None``, the mapping's ``"latest"`` entry supplies the version.

    Args:
        country: Country identifier used in the request path.
        version: Model version to resolve; falsy means "use latest".

    Returns:
        ``(app_name, resolved_version)``.

    Raises:
        ValueError: If the resolved version is not a key of the mapping.
    """
    versions_reply = self.client.get(f"{self.base_url}/versions/{country}")
    versions_reply.raise_for_status()
    apps_by_version = versions_reply.json()

    if version:
        resolved_version = version
    else:
        resolved_version = apps_by_version["latest"]

    try:
        app_name = apps_by_version[resolved_version]
    except KeyError as exc:
        raise ValueError(
            f"Unknown version {resolved_version} for country {country}"
        ) from exc
    else:
        return app_name, resolved_version

def get_execution_id(self, execution: ModalSimulationExecution) -> str:
"""
Get the job ID from an execution.
Expand Down Expand Up @@ -156,6 +176,8 @@ def get_execution_by_id(self, job_id: str) -> ModalSimulationExecution:
status=data["status"],
result=data.get("result"),
error=data.get("error"),
policyengine_bundle=data.get("policyengine_bundle"),
resolved_app_name=data.get("resolved_app_name"),
)

except httpx.HTTPStatusError as e:
Expand Down
Loading
Loading