diff --git a/changelog.d/fixed/3394.md b/changelog.d/fixed/3394.md new file mode 100644 index 000000000..126469ec9 --- /dev/null +++ b/changelog.d/fixed/3394.md @@ -0,0 +1 @@ +Record resolved PolicyEngine bundle metadata from the runtime that actually executed society-wide simulations, and key reproduce/cache behavior off the resolved dataset bundle rather than caller-side defaults. diff --git a/policyengine_api/country.py b/policyengine_api/country.py index befa49851..4278637d8 100644 --- a/policyengine_api/country.py +++ b/policyengine_api/country.py @@ -1,5 +1,5 @@ import importlib -from flask import Response +import inspect import json from policyengine_core.taxbenefitsystems import TaxBenefitSystem from typing import Union, Optional @@ -22,14 +22,6 @@ build_congressional_district_metadata, ) -# Note: The following policyengine_[xx] imports are probably redundant. -# These modules are imported dynamically in the __init__ function below. -import policyengine_uk -import policyengine_us -import policyengine_canada -import policyengine_ng -import policyengine_il - from policyengine_api.data import local_database from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS @@ -45,24 +37,40 @@ def __init__(self, country_package_name: str, country_id: str): self.build_metadata() def build_metadata(self): - self.metadata = dict( - variables=self.build_variables(), - parameters=self.build_parameters(), - entities=self.build_entities(), - variableModules=self.tax_benefit_system.variable_module_metadata, - economy_options=self.build_microsimulation_options(), - current_law_id={ - "uk": 1, - "us": 2, - "ca": 3, - "ng": 4, - "il": 5, - }[self.country_id], - basicInputs=self.tax_benefit_system.basic_inputs, - modelled_policies=self.tax_benefit_system.modelled_policies, - version=get_package_version(self.country_package_name.replace("_", "-")), + self.metadata = self._json_safe( + dict( + variables=self.build_variables(), + parameters=self.build_parameters(), + entities=self.build_entities(), + variableModules=self.tax_benefit_system.variable_module_metadata, + economy_options=self.build_microsimulation_options(), + current_law_id={ + "uk": 1, + "us": 2, + "ca": 3, + "ng": 4, + "il": 5, + }[self.country_id], + basicInputs=self.tax_benefit_system.basic_inputs, + modelled_policies=self.tax_benefit_system.modelled_policies, + version=get_package_version( + self.country_package_name.replace("_", "-") + ), + ) ) + def _json_safe(self, value): + if isinstance(value, Path): + return str(value) + if isinstance(value, dict): + return { + key: self._json_safe(nested_value) + for key, nested_value in value.items() + } + if isinstance(value, list): + return [self._json_safe(nested_value) for nested_value in value] + return value + def build_microsimulation_options(self) -> dict: # { region: [{ name: "uk", label: "the UK" }], time_period: [{ name: 2022, label: "2022", ... }] } options = dict() @@ -363,31 +371,7 @@ def calculate( household_id: Optional[int] = None, policy_id: Optional[int] = None, ): - if reform is not None and len(reform.keys()) > 0: - system = self.tax_benefit_system.clone() - for parameter_name in reform: - for time_period, value in reform[parameter_name].items(): - start_instant, end_instant = time_period.split(".") - parameter = get_parameter(system.parameters, parameter_name) - node_type = type(parameter.values_list[-1].value) - if node_type == int: - node_type = float - try: - value = float(value) - except: - pass - parameter.update( - start=instant(start_instant), - stop=instant(end_instant), - value=node_type(value), - ) - else: - system = self.tax_benefit_system - - simulation = self.country_package.Simulation( - tax_benefit_system=system, - situation=household, - ) + simulation, system = self._create_simulation(household, reform) household = json.loads(json.dumps(household)) @@ -429,14 +413,14 @@ def calculate( entity_index = population.get_index(entity_id) if variable.value_type == Enum: entity_result = result.decode()[entity_index].name - elif variable.value_type == float: + elif variable.value_type is float: entity_result = float(str(result[entity_index])) # Convert infinities to JSON infinities if entity_result == float("inf"): entity_result = "Infinity" elif entity_result == float("-inf"): entity_result = "-Infinity" - elif variable.value_type == str: + elif variable.value_type is str: entity_result = str(result[entity_index]) else: entity_result = result.tolist()[entity_index] @@ -473,6 +457,72 @@ def calculate( return household + def _create_simulation( + self, + household: dict, + reform: Union[dict, None], + ): + normalized_reform = None + if reform: + system = self.tax_benefit_system.clone() + normalized_reform = self._normalize_reform_values(reform, system) + else: + system = self.tax_benefit_system + + if self._simulation_accepts_tax_benefit_system(): + if normalized_reform: + self._apply_reform_to_system(system, normalized_reform) + simulation = self.country_package.Simulation( + tax_benefit_system=system, + situation=household, + ) + return simulation, system + + simulation_kwargs = {"situation": household} + if normalized_reform: + simulation_kwargs["reform"] = normalized_reform + simulation = self.country_package.Simulation(**simulation_kwargs) + return simulation, simulation.tax_benefit_system + + def _simulation_accepts_tax_benefit_system(self) -> bool: + simulation_signature = inspect.signature(self.country_package.Simulation) + return "tax_benefit_system" in simulation_signature.parameters + + def _normalize_reform_values( + self, + reform: dict, + system: TaxBenefitSystem, + ) -> dict: + normalized_reform = {} + for parameter_name, parameter_updates in reform.items(): + parameter = get_parameter(system.parameters, parameter_name) + normalized_reform[parameter_name] = {} + for time_period, value in parameter_updates.items(): + node_type = type(parameter.values_list[-1].value) + if node_type is int: + node_type = float + try: + value = float(value) + except Exception: + pass + normalized_reform[parameter_name][time_period] = node_type(value) + return normalized_reform + + def _apply_reform_to_system( + self, + system: TaxBenefitSystem, + reform: dict, + ) -> None: + for parameter_name, parameter_updates in reform.items(): + parameter = get_parameter(system.parameters, parameter_name) + for time_period, value in parameter_updates.items(): + start_instant, end_instant = time_period.split(".") + parameter.update( + start=instant(start_instant), + stop=instant(end_instant), + value=value, + ) + def create_policy_reform(policy_data: dict) -> dict: """ @@ -498,7 +548,7 @@ def modify_parameters(parameters: ParameterNode) -> ParameterNode: for period, value in values.items(): start, end = period.split(".") node_type = type(node.values_list[-1].value) - if node_type == int: + if node_type is int: node_type = float # '0' is of type int by default, but usually we want to cast to float. if node.values_list[-1].value is None: node_type = float diff --git a/policyengine_api/data/model_setup.py b/policyengine_api/data/model_setup.py index 739f7bbcc..e804ac4cf 100644 --- a/policyengine_api/data/model_setup.py +++ b/policyengine_api/data/model_setup.py @@ -1,9 +1,11 @@ -ENHANCED_FRS = "hf://policyengine/policyengine-uk-data/enhanced_frs_2023_24.h5" -FRS = "hf://policyengine/policyengine-uk-data/frs_2023_24.h5" +ENHANCED_FRS = ( + "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3" +) +FRS = "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.40.3" -ENHANCED_CPS = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5" -CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5" -POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5" +ENHANCED_CPS = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0" +CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5@1.77.0" +POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.77.0" datasets = { "uk": { @@ -16,28 +18,3 @@ "pooled_cps": POOLED_CPS, }, } - - -def get_dataset_version(country_id: str) -> str | None: - """ - Get the dataset version for the specified country. If PolicyEngine does not - publish data for the country, raise a ValueError. - - By returning None for all valid countries, we allow policyengine.py to use - whatever default dataset version it has available, without imposing version - validation constraints from the API layer. - """ - match country_id: - case "uk": - return None - case "us": - return None - case _: - raise ValueError(f"Unknown country ID: {country_id}") - - -for dataset in datasets["uk"]: - datasets["uk"][dataset] = f"{datasets['uk'][dataset]}@{get_dataset_version('uk')}" - -for dataset in datasets["us"]: - datasets["us"][dataset] = f"{datasets['us'][dataset]}@{get_dataset_version('us')}" diff --git a/policyengine_api/libs/simulation_api_modal.py b/policyengine_api/libs/simulation_api_modal.py index cc220ef17..e171888e7 100644 --- a/policyengine_api/libs/simulation_api_modal.py +++ b/policyengine_api/libs/simulation_api_modal.py @@ -24,6 +24,8 @@ class ModalSimulationExecution: status: str result: Optional[dict] = None error: Optional[str] = None + policyengine_bundle: Optional[dict] = None + resolved_app_name: Optional[str] = None @property def name(self) -> str: @@ -94,6 +96,8 @@ def run(self, payload: dict) -> ModalSimulationExecution: return ModalSimulationExecution( job_id=data["job_id"], status=data["status"], + policyengine_bundle=data.get("policyengine_bundle"), + resolved_app_name=data.get("resolved_app_name"), ) except httpx.HTTPStatusError as e: @@ -115,6 +119,22 @@ def run(self, payload: dict) -> ModalSimulationExecution: ) raise + def resolve_app_name( + self, country: str, version: Optional[str] = None + ) -> tuple[str, str]: + """Resolve the current gateway app name for a country/model version.""" + response = self.client.get(f"{self.base_url}/versions/{country}") + response.raise_for_status() + version_map = response.json() + + resolved_version = version or version_map["latest"] + try: + return version_map[resolved_version], resolved_version + except KeyError as exc: + raise ValueError( + f"Unknown version {resolved_version} for country {country}" + ) from exc + def get_execution_id(self, execution: ModalSimulationExecution) -> str: """ Get the job ID from an execution. @@ -156,6 +176,8 @@ def get_execution_by_id(self, job_id: str) -> ModalSimulationExecution: status=data["status"], result=data.get("result"), error=data.get("error"), + policyengine_bundle=data.get("policyengine_bundle"), + resolved_app_name=data.get("resolved_app_name"), ) except httpx.HTTPStatusError as e: diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 95eae9838..42612eefa 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -4,7 +4,6 @@ ) from policyengine_api.constants import ( COUNTRY_PACKAGE_VERSIONS, - REGION_PREFIXES, EXECUTION_STATUSES_SUCCESS, EXECUTION_STATUSES_FAILURE, EXECUTION_STATUSES_PENDING, @@ -12,7 +11,9 @@ ) from policyengine_api.gcp_logging import logger from policyengine_api.libs.simulation_api_modal import simulation_api_modal -from policyengine_api.data.model_setup import get_dataset_version +from policyengine_api.data.model_setup import ( + datasets as configured_datasets, +) from policyengine_api.data.congressional_districts import ( get_valid_state_codes, get_valid_congressional_districts, @@ -23,7 +24,7 @@ from policyengine.utils.data.datasets import get_default_dataset import json import datetime -from typing import Literal, Any, Optional, Annotated, Union +from typing import Literal, Any, Optional, Annotated from dotenv import load_dotenv from pydantic import BaseModel import numpy as np @@ -36,6 +37,16 @@ simulation_api = simulation_api_modal +def get_policyengine_version() -> str | None: + """Legacy test seam; runtime bundle metadata comes from the simulation API.""" + return None + + +def get_dataset_version(country_id: str) -> str | None: + """Legacy test seam; runtime bundle metadata comes from the simulation API.""" + return None + + class ImpactAction(Enum): """ Enum for the action to take based on the status of an economic impact calculation. @@ -72,7 +83,9 @@ class EconomicImpactSetupOptions(BaseModel): api_version: str target: Literal["general", "cliff"] model_version: str | None = None + policyengine_version: str | None = None data_version: str | None = None + runtime_app_name: str | None = None options_hash: str | None = None @@ -156,16 +169,20 @@ def get_economic_impact( # Set up logging process_id: str = self._create_process_id() - options_hash = ( - "[" + "&".join([f"{k}={v}" for k, v in options.items()]) + "]" - ) - country_package_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) - - if country_id == "uk": - country_package_version = None - cache_version = get_economy_impact_cache_version(country_id, api_version) + resolved_dataset = self._setup_data( + country_id=country_id, + region=region, + dataset=dataset, + ) + resolved_model_version = country_package_version + resolved_data_version = self._extract_dataset_version(resolved_dataset) + options_hash = self._build_options_hash( + options=options, + model_version=resolved_model_version, + dataset=resolved_dataset, + ) economic_impact_setup_options = EconomicImpactSetupOptions.model_validate( { @@ -174,13 +191,15 @@ def get_economic_impact( "reform_policy_id": policy_id, "baseline_policy_id": baseline_policy_id, "region": region, - "dataset": dataset, + "dataset": resolved_dataset, "time_period": time_period, "options": options, "api_version": cache_version, "target": target, - "model_version": country_package_version, - "data_version": get_dataset_version(country_id), + "model_version": resolved_model_version, + "policyengine_version": None, + "data_version": resolved_data_version, + "runtime_app_name": None, "options_hash": options_hash, } ) @@ -198,6 +217,20 @@ def get_economic_impact( setup_options=economic_impact_setup_options, ) + if most_recent_impact and self._should_refresh_cached_impact( + setup_options=economic_impact_setup_options, + most_recent_impact=most_recent_impact, + ): + most_recent_impact = self._get_most_recent_impact( + economic_impact_setup_options + ) + if ( + not most_recent_impact + or most_recent_impact.get("options_hash") + != economic_impact_setup_options.options_hash + ): + most_recent_impact = None + impact_action: ImpactAction = self._determine_impact_action( most_recent_impact=most_recent_impact, ) @@ -211,6 +244,7 @@ def get_economic_impact( severity="INFO", ) return self._handle_completed_impact( + setup_options=economic_impact_setup_options, most_recent_impact=most_recent_impact, ) @@ -228,6 +262,21 @@ def get_economic_impact( ) if impact_action == ImpactAction.CREATE: + if economic_impact_setup_options.runtime_app_name is None: + ( + economic_impact_setup_options.runtime_app_name, + economic_impact_setup_options.model_version, + ) = simulation_api.resolve_app_name( + country_id, + economic_impact_setup_options.model_version, + ) + economic_impact_setup_options.options_hash = self._build_options_hash( + options=options, + model_version=economic_impact_setup_options.model_version, + dataset=resolved_dataset, + data_version=resolved_data_version, + runtime_app_name=economic_impact_setup_options.runtime_app_name, + ) logger.log_struct( { "message": "No previous economic impact record found in db; creating new simulation run", @@ -260,15 +309,18 @@ def _get_previous_impacts( Fetch any previous simulation runs for the given policy reform. """ - previous_impacts: list[Any] = reform_impacts_service.get_all_reform_impacts( - country_id, - policy_id, - baseline_policy_id, - region, - dataset, - time_period, - options_hash, - api_version, + previous_impacts: list[Any] = ( + reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix( + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + self._build_options_hash_lookup_pattern(options_hash), + api_version, + ) ) return previous_impacts @@ -292,10 +344,14 @@ def _get_most_recent_impact( api_version=setup_options.api_version, ) - if previous_impacts: - return previous_impacts[0] + if not previous_impacts: + return None + + for impact in previous_impacts: + if impact.get("options_hash") == setup_options.options_hash: + return impact - return None + return previous_impacts[0] def _determine_impact_action( self, @@ -327,7 +383,11 @@ def _handle_execution_state( Modal statuses (complete, failed, running, submitted). """ if execution_state in EXECUTION_STATUSES_SUCCESS: - result = simulation_api.get_execution_result(execution) + result = self._with_policyengine_bundle( + result=simulation_api.get_execution_result(execution), + setup_options=setup_options, + execution=execution, + ) self._set_reform_impact_complete( setup_options=setup_options, reform_impact_json=json.dumps(result), @@ -372,11 +432,15 @@ def _handle_execution_state( def _handle_completed_impact( self, + setup_options: EconomicImpactSetupOptions, most_recent_impact: dict, ) -> EconomicImpactResult: - + result = json.loads(most_recent_impact["reform_impact_json"]) return EconomicImpactResult.completed( - data=json.loads(most_recent_impact["reform_impact_json"]) + data=self._with_policyengine_bundle( + result=result, + setup_options=setup_options, + ) ) def _handle_computing_impact( @@ -400,7 +464,6 @@ def _handle_create_impact( self, setup_options: EconomicImpactSetupOptions, ) -> EconomicImpactResult: - baseline_policy = policy_service.get_policy_json( setup_options.country_id, setup_options.baseline_policy_id ) @@ -436,6 +499,11 @@ def _handle_create_impact( "reform_policy_id": setup_options.reform_policy_id, "baseline_policy_id": setup_options.baseline_policy_id, "process_id": setup_options.process_id, + "model_version": setup_options.model_version, + "policyengine_version": setup_options.policyengine_version, + "data_version": setup_options.data_version, + "dataset": setup_options.dataset, + "resolved_app_name": setup_options.runtime_app_name, } sim_api_execution = simulation_api.run(sim_params) @@ -489,6 +557,132 @@ def _setup_sim_options( } ) + def _build_options_hash( + self, + options: dict, + model_version: str | None, + dataset: str, + runtime_app_name: str | None = None, + data_version: str | None = None, + policyengine_version: str | None = None, + ) -> str: + option_pairs = "&".join([f"{k}={v}" for k, v in options.items()]) + bundle_parts = [ + f"dataset={dataset}", + f"model_version={model_version}", + ] + if data_version: + bundle_parts.append(f"data_version={data_version}") + if policyengine_version: + bundle_parts.append(f"policyengine_version={policyengine_version}") + if runtime_app_name: + bundle_parts.append(f"runtime_app_name={runtime_app_name}") + return "[" + "&".join([option_pairs, *bundle_parts]).strip("&") + "]" + + def _build_options_hash_lookup_pattern(self, options_hash: str) -> str: + escaped_options_hash = ( + options_hash.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + ) + if options_hash.endswith("]"): + return f"{escaped_options_hash[:-1]}&%" + return f"{escaped_options_hash}%" + + def _extract_dataset_version(self, dataset: str) -> str | None: + if "@" not in dataset: + return None + return dataset.rsplit("@", 1)[1] + + def _extract_cached_result(self, most_recent_impact: dict) -> dict: + try: + return json.loads(most_recent_impact["reform_impact_json"]) + except (TypeError, ValueError): + return {} + + def _should_refresh_cached_impact( + self, + setup_options: EconomicImpactSetupOptions, + most_recent_impact: dict, + ) -> bool: + if most_recent_impact.get("status") == ImpactStatus.COMPUTING.value: + return False + + cached_result = self._extract_cached_result(most_recent_impact) + cached_resolved_app_name = cached_result.get("resolved_app_name") + try: + runtime_app_name, resolved_model_version = simulation_api.resolve_app_name( + setup_options.country_id, + setup_options.model_version, + ) + except Exception: + return False + + setup_options.runtime_app_name = runtime_app_name + setup_options.model_version = resolved_model_version + setup_options.options_hash = self._build_options_hash( + options=setup_options.options, + model_version=resolved_model_version, + dataset=setup_options.dataset, + data_version=setup_options.data_version, + policyengine_version=setup_options.policyengine_version, + runtime_app_name=runtime_app_name, + ) + if ( + not isinstance(cached_resolved_app_name, str) + or not cached_resolved_app_name + ): + return True + + return runtime_app_name != cached_resolved_app_name + + def _with_policyengine_bundle( + self, + result: dict, + setup_options: EconomicImpactSetupOptions, + execution: Optional[Any] = None, + ) -> dict: + result = result if isinstance(result, dict) else {} + cached_resolved_app_name = result.get("resolved_app_name") + use_setup_model_version = execution is not None or ( + isinstance(cached_resolved_app_name, str) and bool(cached_resolved_app_name) + ) + bundle = { + "model_version": ( + setup_options.model_version if use_setup_model_version else None + ), + "policyengine_version": ( + setup_options.policyengine_version if use_setup_model_version else None + ), + "data_version": setup_options.data_version, + "dataset": setup_options.dataset, + } + if isinstance(result.get("policyengine_bundle"), dict): + for key, value in result["policyengine_bundle"].items(): + if bundle.get(key) is None and value is not None: + bundle[key] = value + execution_bundle = ( + getattr(execution, "policyengine_bundle", None) + if execution is not None + else None + ) + if isinstance(execution_bundle, dict): + for key, value in execution_bundle.items(): + if value is not None: + bundle[key] = value + response = { + **result, + "policyengine_bundle": bundle, + } + resolved_app_name = None + if execution is not None: + maybe_resolved_app_name = getattr(execution, "resolved_app_name", None) + if isinstance(maybe_resolved_app_name, str) and maybe_resolved_app_name: + resolved_app_name = maybe_resolved_app_name + if resolved_app_name is None: + resolved_app_name = setup_options.runtime_app_name + if resolved_app_name: + response["resolved_app_name"] = resolved_app_name + return response + def _setup_region(self, country_id: str, region: str) -> str: """ Validate the region for the given country. @@ -537,13 +731,23 @@ def _setup_data( Determine the dataset to use based on the country and region. If the dataset is in PASSTHROUGH_DATASETS, it will be passed directly - to the simulation API. Otherwise, uses policyengine's get_default_dataset - to resolve the appropriate GCS path. + to the simulation API. If the dataset matches a configured dataset alias + for the country, resolve it to the published dataset URI. Otherwise, + uses policyengine's get_default_dataset to resolve the appropriate GCS + path. """ # If the dataset is a recognized passthrough keyword, use it directly if dataset in self.PASSTHROUGH_DATASETS: return dataset + if "://" in dataset: + return dataset + + # Resolve explicit dataset aliases exposed in metadata. + country_datasets = configured_datasets.get(country_id, {}) + if dataset in country_datasets: + return country_datasets[dataset] + try: return get_default_dataset(country_id, region) except ValueError as e: diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index ca44ea10c..0f41352f3 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -44,6 +44,45 @@ def get_all_reform_impacts( print(f"Error getting all reform impacts: {str(e)}") raise e + def get_all_reform_impacts_by_options_hash_prefix( + self, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + options_hash_prefix, + api_version, + ): + try: + query = ( + "SELECT reform_impact_json, status, message, start_time, execution_id, options_hash FROM " + "reform_impact WHERE country_id = ? AND reform_policy_id = ? AND " + "baseline_policy_id = ? AND region = ? AND time_period = ? AND " + "(options_hash = ? OR options_hash LIKE ? ESCAPE '\\') AND api_version = ? AND dataset = ? " + "ORDER BY CASE WHEN options_hash = ? THEN 0 ELSE 1 END, start_time DESC" + ) + return local_database.query( + query, + ( + country_id, + policy_id, + baseline_policy_id, + region, + time_period, + options_hash, + options_hash_prefix, + api_version, + dataset, + options_hash, + ), + ).fetchall() + except Exception as e: + print(f"Error getting reform impacts by prefix: {str(e)}") + raise e + def set_reform_impact( self, country_id, diff --git a/pyproject.toml b/pyproject.toml index 8abe8ce53..d80c55cdc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "policyengine_canada==0.96.3", "policyengine-ng==0.5.1", "policyengine-il==0.1.0", - "policyengine_uk==2.39.0", + "policyengine_uk==2.78.0", "policyengine_us==1.633.2", "policyengine_core>=3.16.6", "policyengine>=0.7.0", diff --git a/tests/fixtures/libs/simulation_api_modal.py b/tests/fixtures/libs/simulation_api_modal.py index 64ce139e7..fa47f8b2a 100644 --- a/tests/fixtures/libs/simulation_api_modal.py +++ b/tests/fixtures/libs/simulation_api_modal.py @@ -36,6 +36,13 @@ "budget_impact": {"baseline": 1000, "reform": 1200}, "inequality_impact": {"baseline": 0.45, "reform": 0.42}, } +MOCK_POLICYENGINE_BUNDLE = { + "model_version": "1.459.0", + "policyengine_version": "3.4.0", + "data_version": "1.77.0", + "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0", +} +MOCK_RESOLVED_APP_NAME = "policyengine-us-1-459-0" MOCK_SUBMIT_RESPONSE_SUCCESS = { "job_id": MOCK_MODAL_JOB_ID, @@ -43,6 +50,8 @@ "poll_url": f"/jobs/{MOCK_MODAL_JOB_ID}", "country": "us", "version": "1.459.0", + "policyengine_bundle": MOCK_POLICYENGINE_BUNDLE, + "resolved_app_name": MOCK_RESOLVED_APP_NAME, } MOCK_POLL_RESPONSE_RUNNING = { @@ -55,6 +64,8 @@ "status": MODAL_EXECUTION_STATUS_COMPLETE, "result": MOCK_SIMULATION_RESULT, "error": None, + "policyengine_bundle": MOCK_POLICYENGINE_BUNDLE, + "resolved_app_name": MOCK_RESOLVED_APP_NAME, } MOCK_POLL_RESPONSE_FAILED = { diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index 687a82a48..88f2d08b0 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -17,12 +17,32 @@ MOCK_TIME_PERIOD = "2025" MOCK_API_VERSION = "1.0" MOCK_OPTIONS = {"option1": "value1", "option2": "value2"} -MOCK_OPTIONS_HASH = "[option1=value1&option2=value2]" +MOCK_DATA_VERSION = "1.77.0" +MOCK_LOOKUP_OPTIONS_HASH = ( + "[option1=value1&option2=value2" + "&dataset=hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0" + "&model_version=1.2.3]" +) +MOCK_OPTIONS_HASH = ( + MOCK_LOOKUP_OPTIONS_HASH[:-1] + + "&data_version=1.77.0" + + "&runtime_app_name=policyengine-simulation-us1-2-3-uk2-7-8]" +) MOCK_MODAL_JOB_ID = "fc-test123xyz" MOCK_EXECUTION_ID = MOCK_MODAL_JOB_ID # Alias for test compatibility MOCK_PROCESS_ID = "job_20250626120000_1234" MOCK_MODEL_VERSION = "1.2.3" -MOCK_DATA_VERSION = None +MOCK_POLICYENGINE_VERSION = "3.4.0" +MOCK_RESOLVED_DATASET = ( + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0" +) +MOCK_RESOLVED_APP_NAME = "policyengine-simulation-us1-2-3-uk2-7-8" +MOCK_RUNTIME_BUNDLE = { + "model_version": MOCK_MODEL_VERSION, + "policyengine_version": MOCK_POLICYENGINE_VERSION, + "data_version": MOCK_DATA_VERSION, + "dataset": MOCK_RESOLVED_DATASET, +} MOCK_REFORM_POLICY_JSON = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) @@ -41,7 +61,7 @@ "region": MOCK_REGION, "time_period": MOCK_TIME_PERIOD, "scope": "macro", - "dataset": MOCK_DATASET, + "dataset": MOCK_RESOLVED_DATASET, "include_cliffs": False, "model_version": MOCK_MODEL_VERSION, "data_version": MOCK_DATA_VERSION, @@ -68,6 +88,16 @@ def mock_get_dataset_version(): yield mock +@pytest.fixture +def mock_get_policyengine_version(): + """Mock get_policyengine_version function.""" + with patch( + "policyengine_api.services.economy_service.get_policyengine_version", + return_value=MOCK_POLICYENGINE_VERSION, + ) as mock: + yield mock + + @pytest.fixture def mock_policy_service(): """Mock PolicyService with get_policy_json method.""" @@ -89,6 +119,7 @@ def mock_policy_service(): def mock_reform_impacts_service(): """Mock ReformImpactsService with all required methods.""" mock_service = MagicMock() + mock_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [] mock_service.get_all_reform_impacts.return_value = [] mock_service.set_reform_impact.return_value = None mock_service.set_complete_reform_impact.return_value = None @@ -109,6 +140,10 @@ def mock_simulation_api(): mock_api._setup_sim_options.return_value = MOCK_SIM_CONFIG mock_api.run.return_value = mock_execution + mock_api.resolve_app_name.side_effect = lambda country_id, version=None: ( + MOCK_RESOLVED_APP_NAME, + version or MOCK_MODEL_VERSION, + ) mock_api.get_execution_id.return_value = MOCK_MODAL_JOB_ID mock_api.get_execution_by_id.return_value = mock_execution mock_api.get_execution_status.return_value = MODAL_EXECUTION_STATUS_RUNNING @@ -147,21 +182,36 @@ def mock_numpy_random(): def create_mock_reform_impact( - status="ok", reform_impact_json=None, execution_id=MOCK_MODAL_JOB_ID + status="ok", + reform_impact_json=None, + execution_id=MOCK_MODAL_JOB_ID, + options_hash=MOCK_OPTIONS_HASH, ): """Helper function to create mock reform impact records.""" + default_reform_impact_json = json.dumps( + { + **MOCK_REFORM_IMPACT_DATA, + "resolved_app_name": MOCK_RESOLVED_APP_NAME, + "policyengine_bundle": { + "model_version": MOCK_MODEL_VERSION, + "policyengine_version": MOCK_POLICYENGINE_VERSION, + "data_version": MOCK_DATA_VERSION, + "dataset": MOCK_RESOLVED_DATASET, + }, + } + ) return { "id": 1, "country_id": MOCK_COUNTRY_ID, "policy_id": MOCK_POLICY_ID, "baseline_policy_id": MOCK_BASELINE_POLICY_ID, "region": MOCK_REGION, - "dataset": MOCK_DATASET, + "dataset": MOCK_RESOLVED_DATASET, "time_period": MOCK_TIME_PERIOD, - "options_hash": MOCK_OPTIONS_HASH, + "options_hash": options_hash, "status": status, "api_version": MOCK_API_VERSION, - "reform_impact_json": reform_impact_json or json.dumps(MOCK_REFORM_IMPACT_DATA), + "reform_impact_json": reform_impact_json or default_reform_impact_json, "execution_id": execution_id, "start_time": datetime.datetime(2025, 6, 26, 12, 0, 0), "end_time": ( @@ -175,6 +225,7 @@ def create_mock_modal_execution( status=MODAL_EXECUTION_STATUS_SUBMITTED, result=None, error=None, + policyengine_bundle=None, ): """ Helper function to create mock Modal execution objects. @@ -201,6 +252,8 @@ def create_mock_modal_execution( mock_execution.status = status mock_execution.result = result mock_execution.error = error + mock_execution.policyengine_bundle = policyengine_bundle or MOCK_RUNTIME_BUNDLE + mock_execution.resolved_app_name = MOCK_RESOLVED_APP_NAME return mock_execution @@ -211,6 +264,10 @@ def mock_simulation_api_modal(): mock_execution = create_mock_modal_execution() mock_api.run.return_value = mock_execution + mock_api.resolve_app_name.side_effect = lambda country_id, version=None: ( + MOCK_RESOLVED_APP_NAME, + version or MOCK_MODEL_VERSION, + ) mock_api.get_execution_id.return_value = MOCK_MODAL_JOB_ID mock_api.get_execution_by_id.return_value = mock_execution mock_api.get_execution_status.return_value = MODAL_EXECUTION_STATUS_RUNNING diff --git a/tests/to_refactor/python/test_yearly_var_removal.py b/tests/to_refactor/python/test_yearly_var_removal.py index b0d9211d9..875176fe4 100644 --- a/tests/to_refactor/python/test_yearly_var_removal.py +++ b/tests/to_refactor/python/test_yearly_var_removal.py @@ -1,5 +1,6 @@ import pytest import json +import uuid from policyengine_api.endpoints.household import get_household_under_policy from policyengine_api.services.metadata_service import MetadataService @@ -19,7 +20,10 @@ def client(): yield client -TEST_HOUSEHOLD_ID = "-100" +def make_test_household_id() -> str: + # Use a negative signed 32-bit-ish integer string to avoid colliding with + # normal autoincrement rows while remaining compatible with INT columns. + return str(-((uuid.uuid4().int % 2_000_000_000) or 1)) def create_test_household(household_id, country_id): @@ -109,26 +113,22 @@ def interface_test_household_under_policy( # Value to invalidated if any key is not present in household is_test_passing = True + test_household_id = make_test_household_id() + # Fetch live country metadata metadata = metadata_service.get_metadata(country_id) - # Create the test household on the local db instance - create_test_household(TEST_HOUSEHOLD_ID, country_id) - - # Remove the created household from the db - test_row = database.query( - f"SELECT * FROM household WHERE id = ? AND country_id = ?", - (TEST_HOUSEHOLD_ID, country_id), - ).fetchone() - - # Create a result object by simply calling the relevant function - result_object = get_household_under_policy( - country_id, TEST_HOUSEHOLD_ID, CURRENT_LAW - )["result"] + try: + # Create the test household on the local db instance + create_test_household(test_household_id, country_id) - # Remove the created test household - remove_test_household(TEST_HOUSEHOLD_ID, country_id) - remove_calculated_hup(TEST_HOUSEHOLD_ID, CURRENT_LAW, country_id) + # Create a result object by simply calling the relevant function + result_object = get_household_under_policy( + country_id, test_household_id, CURRENT_LAW + )["result"] + finally: + remove_test_household(test_household_id, country_id) + remove_calculated_hup(test_household_id, CURRENT_LAW, country_id) # Create a dict of entity singular and plural terms for testing entities_map = {} @@ -195,6 +195,19 @@ def interface_test_household_under_policy( return is_test_passing +def test_make_test_household_id_returns_negative_integer_string(): + test_household_id = make_test_household_id() + + assert test_household_id.startswith("-") + assert int(test_household_id) < 0 + + +def test_make_test_household_id_is_unique(): + generated_ids = {make_test_household_id() for _ in range(100)} + + assert len(generated_ids) == 100 + + def test_us_household_under_policy(): """ Test that a US household under current law is created correctly diff --git a/tests/unit/data/test_model_setup.py b/tests/unit/data/test_model_setup.py index 45c69a114..1bbabef28 100644 --- a/tests/unit/data/test_model_setup.py +++ b/tests/unit/data/test_model_setup.py @@ -1,31 +1,36 @@ -import pytest - -from policyengine_api.data.model_setup import get_dataset_version - - -class TestGetDatasetVersion: - """Tests for the get_dataset_version function.""" - - def test__given_us__returns_none(self): - result = get_dataset_version("us") - assert result is None - - def test__given_uk__returns_none(self): - result = get_dataset_version("uk") - assert result is None - - def test__given_invalid_country__raises_value_error(self): - with pytest.raises(ValueError) as exc_info: - get_dataset_version("invalid") - assert "Unknown country ID: invalid" in str(exc_info.value) - - def test__given_empty_string__raises_value_error(self): - with pytest.raises(ValueError) as exc_info: - get_dataset_version("") - assert "Unknown country ID:" in str(exc_info.value) - - def test__given_canada__raises_value_error(self): - # Canada is a valid country in the API but doesn't have dataset versioning - with pytest.raises(ValueError) as exc_info: - get_dataset_version("ca") - assert "Unknown country ID: ca" in str(exc_info.value) +from policyengine_api.data.model_setup import ( + CPS, + ENHANCED_CPS, + ENHANCED_FRS, + FRS, + POOLED_CPS, + datasets, +) + + +class TestDatasets: + def test__given_us_aliases__then_returns_versioned_public_hf_uris(self): + assert datasets["us"] == { + "enhanced_cps": ENHANCED_CPS, + "cps": CPS, + "pooled_cps": POOLED_CPS, + } + assert ENHANCED_CPS.endswith("@1.77.0") + assert CPS.endswith("@1.77.0") + assert POOLED_CPS.endswith("@1.77.0") + + def test__given_uk_aliases__then_returns_versioned_private_hf_uris(self): + assert datasets["uk"] == { + "enhanced_frs": ENHANCED_FRS, + "frs": FRS, + } + assert ENHANCED_FRS.startswith( + "hf://policyengine/policyengine-uk-data-private/" + ) + assert FRS.startswith("hf://policyengine/policyengine-uk-data-private/") + assert ENHANCED_FRS.endswith("@1.40.3") + assert FRS.endswith("@1.40.3") + + def test__given_unknown_country__then_has_no_dataset_aliases(self): + assert "ca" not in datasets + assert "invalid" not in datasets diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py index d44dde8cb..740bd37a2 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -6,7 +6,7 @@ """ import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch import httpx from policyengine_api.libs.simulation_api_modal import ( @@ -24,16 +24,18 @@ MOCK_MODAL_BASE_URL, MOCK_SIMULATION_PAYLOAD, MOCK_SIMULATION_RESULT, + MOCK_POLICYENGINE_BUNDLE, + MOCK_RESOLVED_APP_NAME, MOCK_SUBMIT_RESPONSE_SUCCESS, MOCK_POLL_RESPONSE_RUNNING, MOCK_POLL_RESPONSE_COMPLETE, MOCK_POLL_RESPONSE_FAILED, MOCK_HEALTH_RESPONSE, create_mock_httpx_response, - mock_httpx_client, - mock_modal_logger, ) +pytest_plugins = ("tests.fixtures.libs.simulation_api_modal",) + class TestModalSimulationExecution: """Tests for the ModalSimulationExecution dataclass.""" @@ -135,6 +137,8 @@ def test__given_valid_payload__then_returns_execution_with_job_id( # Then assert execution.job_id == MOCK_MODAL_JOB_ID assert execution.status == MODAL_EXECUTION_STATUS_SUBMITTED + assert execution.policyengine_bundle == MOCK_POLICYENGINE_BUNDLE + assert execution.resolved_app_name == MOCK_RESOLVED_APP_NAME mock_httpx_client.post.assert_called_once() def test__given_valid_payload__then_posts_to_correct_endpoint( @@ -187,6 +191,26 @@ def test__given_network_error__then_raises_exception( with pytest.raises(httpx.RequestError): api.run(MOCK_SIMULATION_PAYLOAD) + class TestResolveAppName: + def test__given_country_and_version__then_returns_registered_app( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=200, + json_data={ + "latest": "1.459.0", + "1.459.0": MOCK_RESOLVED_APP_NAME, + }, + ) + api = SimulationAPIModal() + + app_name, resolved_version = api.resolve_app_name("us", "1.459.0") + + assert app_name == MOCK_RESOLVED_APP_NAME + assert resolved_version == "1.459.0" + class TestGetExecutionById: def test__given_running_job__then_returns_running_status( self, @@ -226,6 +250,8 @@ def test__given_complete_job__then_returns_result( # Then assert execution.status == MODAL_EXECUTION_STATUS_COMPLETE assert execution.result == MOCK_SIMULATION_RESULT + assert execution.policyengine_bundle == MOCK_POLICYENGINE_BUNDLE + assert execution.resolved_app_name == MOCK_RESOLVED_APP_NAME def test__given_failed_job__then_returns_error( self, diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 162d30c20..9ad1aa1e7 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -12,28 +12,28 @@ ) from tests.fixtures.services.economy_service import ( MOCK_COUNTRY_ID, + MOCK_DATA_VERSION, MOCK_POLICY_ID, MOCK_BASELINE_POLICY_ID, MOCK_REGION, MOCK_DATASET, MOCK_TIME_PERIOD, MOCK_API_VERSION, + MOCK_MODEL_VERSION, + MOCK_POLICYENGINE_VERSION, MOCK_OPTIONS, + MOCK_LOOKUP_OPTIONS_HASH, MOCK_OPTIONS_HASH, MOCK_EXECUTION_ID, MOCK_PROCESS_ID, MOCK_REFORM_IMPACT_DATA, + MOCK_RESOLVED_DATASET, + MOCK_RESOLVED_APP_NAME, create_mock_reform_impact, - mock_country_package_versions, - mock_datetime, - mock_get_dataset_version, - mock_logger, - mock_numpy_random, - mock_policy_service, - mock_reform_impacts_service, - mock_simulation_api, ) +pytest_plugins = ("tests.fixtures.services.economy_service",) + class TestEconomyService: class TestGetEconomicImpact: @@ -61,6 +61,7 @@ def test__given_completed_impact__returns_completed_result( base_params, mock_country_package_versions, mock_get_dataset_version, + mock_get_policyengine_version, mock_policy_service, mock_reform_impacts_service, mock_simulation_api, @@ -69,23 +70,67 @@ def test__given_completed_impact__returns_completed_result( mock_numpy_random, ): completed_impact = create_mock_reform_impact(status="ok") - mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [ completed_impact ] result = economy_service.get_economic_impact(**base_params) assert result.status == ImpactStatus.OK - assert result.data == MOCK_REFORM_IMPACT_DATA - mock_reform_impacts_service.get_all_reform_impacts.assert_called_once() + assert ( + result.data["poverty_impact"] + == MOCK_REFORM_IMPACT_DATA["poverty_impact"] + ) + assert result.data["policyengine_bundle"] == { + "model_version": MOCK_MODEL_VERSION, + "policyengine_version": MOCK_POLICYENGINE_VERSION, + "data_version": MOCK_DATA_VERSION, + "dataset": MOCK_RESOLVED_DATASET, + } + ( + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.assert_called_once() + ) mock_simulation_api.run.assert_not_called() + def test__given_legacy_completed_impact__refreshes_cache( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_get_policyengine_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + completed_impact = create_mock_reform_impact( + status="ok", + reform_impact_json=json.dumps(MOCK_REFORM_IMPACT_DATA), + options_hash=MOCK_LOOKUP_OPTIONS_HASH, + ) + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [ + completed_impact + ] + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_simulation_api.resolve_app_name.assert_called_once_with( + MOCK_COUNTRY_ID, + MOCK_MODEL_VERSION, + ) + mock_simulation_api.run.assert_called_once() + def test__given_computing_impact_with_succeeded_execution__returns_completed_result( self, economy_service, base_params, mock_country_package_versions, mock_get_dataset_version, + mock_get_policyengine_version, mock_policy_service, mock_reform_impacts_service, mock_simulation_api, @@ -94,7 +139,7 @@ def test__given_computing_impact_with_succeeded_execution__returns_completed_res mock_numpy_random, ): computing_impact = create_mock_reform_impact(status="computing") - mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [ computing_impact ] mock_simulation_api.get_execution_status.return_value = "complete" @@ -105,7 +150,15 @@ def test__given_computing_impact_with_succeeded_execution__returns_completed_res result = economy_service.get_economic_impact(**base_params) assert result.status == ImpactStatus.OK - assert result.data == MOCK_REFORM_IMPACT_DATA + assert ( + result.data["budget_impact"] == MOCK_REFORM_IMPACT_DATA["budget_impact"] + ) + assert result.data["policyengine_bundle"] == { + "model_version": MOCK_MODEL_VERSION, + "policyengine_version": MOCK_POLICYENGINE_VERSION, + "data_version": MOCK_DATA_VERSION, + "dataset": MOCK_RESOLVED_DATASET, + } mock_simulation_api.get_execution_by_id.assert_called_once_with( MOCK_EXECUTION_ID ) @@ -117,6 +170,7 @@ def test__given_computing_impact_with_failed_execution__returns_error_result( base_params, mock_country_package_versions, mock_get_dataset_version, + mock_get_policyengine_version, mock_policy_service, mock_reform_impacts_service, mock_simulation_api, @@ -125,7 +179,7 @@ def test__given_computing_impact_with_failed_execution__returns_error_result( mock_numpy_random, ): computing_impact = create_mock_reform_impact(status="computing") - mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [ computing_impact ] mock_simulation_api.get_execution_status.return_value = "failed" @@ -142,6 +196,7 @@ def test__given_computing_impact_with_active_execution__returns_computing_result base_params, mock_country_package_versions, mock_get_dataset_version, + mock_get_policyengine_version, mock_policy_service, mock_reform_impacts_service, mock_simulation_api, @@ -150,7 +205,7 @@ def test__given_computing_impact_with_active_execution__returns_computing_result mock_numpy_random, ): computing_impact = create_mock_reform_impact(status="computing") - mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [ computing_impact ] mock_simulation_api.get_execution_status.return_value = "running" @@ -166,6 +221,7 @@ def test__given_no_previous_impact__creates_new_simulation( base_params, mock_country_package_versions, mock_get_dataset_version, + mock_get_policyengine_version, mock_policy_service, mock_reform_impacts_service, mock_simulation_api, @@ -173,7 +229,7 @@ def test__given_no_previous_impact__creates_new_simulation( mock_datetime, mock_numpy_random, ): - mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [] result = economy_service.get_economic_impact(**base_params) @@ -188,6 +244,7 @@ def test__given_no_previous_impact__includes_metadata_in_simulation_params( base_params, mock_country_package_versions, mock_get_dataset_version, + mock_get_policyengine_version, mock_policy_service, mock_reform_impacts_service, mock_simulation_api, @@ -196,7 +253,7 @@ def test__given_no_previous_impact__includes_metadata_in_simulation_params( mock_numpy_random, ): """Verify that _metadata with policy IDs is passed to simulation API.""" - mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [] economy_service.get_economic_impact(**base_params) @@ -211,6 +268,13 @@ def test__given_no_previous_impact__includes_metadata_in_simulation_params( sim_params["_metadata"]["baseline_policy_id"] == MOCK_BASELINE_POLICY_ID ) assert sim_params["_metadata"]["process_id"] == MOCK_PROCESS_ID + assert sim_params["_metadata"]["model_version"] == MOCK_MODEL_VERSION + assert sim_params["_metadata"]["policyengine_version"] is None + assert sim_params["_metadata"]["data_version"] == MOCK_DATA_VERSION + assert sim_params["_metadata"]["dataset"] == MOCK_RESOLVED_DATASET + assert ( + sim_params["_metadata"]["resolved_app_name"] == MOCK_RESOLVED_APP_NAME + ) def test__given_runtime_cache_version__uses_versioned_economy_cache_key( self, @@ -218,6 +282,7 @@ def test__given_runtime_cache_version__uses_versioned_economy_cache_key( base_params, mock_country_package_versions, mock_get_dataset_version, + mock_get_policyengine_version, mock_policy_service, mock_reform_impacts_service, mock_simulation_api, @@ -231,27 +296,248 @@ def test__given_runtime_cache_version__uses_versioned_economy_cache_key( "policyengine_api.services.economy_service.get_economy_impact_cache_version", lambda country_id, api_version=None: cache_version, ) - mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [] economy_service.get_economic_impact(**base_params) - mock_reform_impacts_service.get_all_reform_impacts.assert_called_once_with( + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.assert_called_once_with( MOCK_COUNTRY_ID, MOCK_POLICY_ID, MOCK_BASELINE_POLICY_ID, MOCK_REGION, - MOCK_DATASET, + MOCK_RESOLVED_DATASET, MOCK_TIME_PERIOD, - MOCK_OPTIONS_HASH, + MOCK_LOOKUP_OPTIONS_HASH, + economy_service._build_options_hash_lookup_pattern( + MOCK_LOOKUP_OPTIONS_HASH + ), cache_version, ) + def test__given_alias_dataset__queries_previous_impacts_with_resolved_bundle( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_get_policyengine_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [] + + economy_service.get_economic_impact(**base_params) + + call_args = mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.call_args.args + assert call_args[4] == MOCK_RESOLVED_DATASET + assert call_args[6] == MOCK_LOOKUP_OPTIONS_HASH + assert call_args[7] == economy_service._build_options_hash_lookup_pattern( + MOCK_LOOKUP_OPTIONS_HASH + ) + assert "data_version=" not in call_args[7] + assert "runtime_app_name" not in call_args[7] + + def test__given_completed_impact__uses_resolved_runtime_bundle_for_cache_lookup( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_get_policyengine_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + completed_impact = create_mock_reform_impact(status="ok") + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [ + completed_impact + ] + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.OK + mock_simulation_api.resolve_app_name.assert_called_once_with( + MOCK_COUNTRY_ID, + MOCK_MODEL_VERSION, + ) + + def test__given_cached_impact_and_runtime_lookup_fails__then_returns_cached_result( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_get_policyengine_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + completed_impact = create_mock_reform_impact(status="ok") + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [ + completed_impact + ] + mock_simulation_api.resolve_app_name.side_effect = RuntimeError( + "versions down" + ) + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.OK + assert ( + result.data["policyengine_bundle"]["dataset"] == MOCK_RESOLVED_DATASET + ) + mock_simulation_api.run.assert_not_called() + + def test__given_legacy_cached_impact_without_resolved_app_name__then_refreshes_cache( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_get_policyengine_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + completed_impact = create_mock_reform_impact( + status="ok", + reform_impact_json=json.dumps(MOCK_REFORM_IMPACT_DATA), + options_hash=MOCK_LOOKUP_OPTIONS_HASH, + ) + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [ + completed_impact + ] + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_simulation_api.resolve_app_name.assert_called_once_with( + MOCK_COUNTRY_ID, + MOCK_MODEL_VERSION, + ) + mock_simulation_api.run.assert_called_once() + + def test__given_legacy_and_refreshed_cached_impacts__then_reuses_refreshed_entry( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_get_policyengine_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + legacy_impact = create_mock_reform_impact( + status="ok", + reform_impact_json=json.dumps(MOCK_REFORM_IMPACT_DATA), + options_hash=MOCK_LOOKUP_OPTIONS_HASH, + ) + refreshed_impact = create_mock_reform_impact(status="ok") + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.side_effect = [ + [legacy_impact, refreshed_impact], + [refreshed_impact], + ] + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.OK + assert result.data["policyengine_bundle"] == { + "model_version": MOCK_MODEL_VERSION, + "policyengine_version": MOCK_POLICYENGINE_VERSION, + "data_version": MOCK_DATA_VERSION, + "dataset": MOCK_RESOLVED_DATASET, + } + assert ( + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.call_count + == 2 + ) + mock_simulation_api.run.assert_not_called() + + def test__given_legacy_cached_impact_and_runtime_lookup_fails__then_returns_cached_result( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_get_policyengine_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + completed_impact = create_mock_reform_impact( + status="ok", + reform_impact_json=json.dumps(MOCK_REFORM_IMPACT_DATA), + options_hash=MOCK_LOOKUP_OPTIONS_HASH, + ) + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [ + completed_impact + ] + mock_simulation_api.resolve_app_name.side_effect = RuntimeError( + "versions down" + ) + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.OK + assert result.data["policyengine_bundle"]["model_version"] is None + mock_simulation_api.run.assert_not_called() + + def test__given_legacy_computing_impact_without_resolved_app_name__then_reuses_execution( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_get_policyengine_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + computing_impact = create_mock_reform_impact( + status="computing", + reform_impact_json=json.dumps({}), + ) + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [ + computing_impact + ] + mock_simulation_api.get_execution_status.return_value = "running" + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_simulation_api.resolve_app_name.assert_not_called() + mock_simulation_api.run.assert_not_called() + def test__given_exception__raises_error( self, economy_service, base_params, mock_country_package_versions, mock_get_dataset_version, + mock_get_policyengine_version, mock_policy_service, mock_reform_impacts_service, mock_simulation_api, @@ -259,7 +545,7 @@ def test__given_exception__raises_error( mock_datetime, mock_numpy_random, ): - mock_reform_impacts_service.get_all_reform_impacts.side_effect = Exception( + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.side_effect = Exception( "Database error" ) @@ -267,6 +553,37 @@ def test__given_exception__raises_error( economy_service.get_economic_impact(**base_params) assert str(exc_info.value) == "Database error" + def test__given_uk_request__preserves_model_version_in_bundle( + self, + economy_service, + mock_country_package_versions, + mock_get_dataset_version, + mock_get_policyengine_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_country_package_versions["uk"] = "2.7.8" + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [] + + economy_service.get_economic_impact( + country_id="uk", + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region="uk", + dataset="default", + time_period=MOCK_TIME_PERIOD, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + ) + + sim_params = mock_simulation_api.run.call_args[0][0] + assert sim_params["_metadata"]["model_version"] == "2.7.8" + class TestGetPreviousImpacts: @pytest.fixture def economy_service(self): @@ -276,9 +593,7 @@ def test_given_valid_parameters_calls_service_correctly( self, economy_service, mock_reform_impacts_service ): expected_impacts = [create_mock_reform_impact()] - mock_reform_impacts_service.get_all_reform_impacts.return_value = ( - expected_impacts - ) + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = expected_impacts result = economy_service._get_previous_impacts( MOCK_COUNTRY_ID, @@ -292,7 +607,7 @@ def test_given_valid_parameters_calls_service_correctly( ) assert result == expected_impacts - mock_reform_impacts_service.get_all_reform_impacts.assert_called_once_with( + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.assert_called_once_with( MOCK_COUNTRY_ID, MOCK_POLICY_ID, MOCK_BASELINE_POLICY_ID, @@ -300,6 +615,7 @@ def test_given_valid_parameters_calls_service_correctly( MOCK_DATASET, MOCK_TIME_PERIOD, MOCK_OPTIONS_HASH, + economy_service._build_options_hash_lookup_pattern(MOCK_OPTIONS_HASH), MOCK_API_VERSION, ) @@ -316,11 +632,14 @@ def setup_options(self): reform_policy_id=MOCK_POLICY_ID, baseline_policy_id=MOCK_BASELINE_POLICY_ID, region=MOCK_REGION, - dataset=MOCK_DATASET, + dataset=MOCK_RESOLVED_DATASET, time_period=MOCK_TIME_PERIOD, options=MOCK_OPTIONS, api_version=MOCK_API_VERSION, target="general", + model_version=MOCK_MODEL_VERSION, + policyengine_version=MOCK_POLICYENGINE_VERSION, + data_version=MOCK_DATA_VERSION, options_hash=MOCK_OPTIONS_HASH, ) @@ -331,17 +650,30 @@ def test__given_existing_impacts__returns_first_impact( create_mock_reform_impact(), create_mock_reform_impact(), ] - mock_reform_impacts_service.get_all_reform_impacts.return_value = impacts + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = impacts result = economy_service._get_most_recent_impact(setup_options) assert result == impacts[0] + def test__given_exact_and_prefix_matches__prefers_exact_options_hash( + self, economy_service, setup_options, mock_reform_impacts_service + ): + impacts = [ + create_mock_reform_impact(options_hash=MOCK_LOOKUP_OPTIONS_HASH), + create_mock_reform_impact(options_hash=MOCK_OPTIONS_HASH), + ] + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = impacts + + result = economy_service._get_most_recent_impact(setup_options) + + assert result == impacts[1] + def test__given_no_impacts__returns_none( self, economy_service, setup_options, mock_reform_impacts_service ): # Arrange - mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [] # Act result = economy_service._get_most_recent_impact(setup_options) @@ -400,11 +732,14 @@ def setup_options(self): reform_policy_id=MOCK_POLICY_ID, baseline_policy_id=MOCK_BASELINE_POLICY_ID, region=MOCK_REGION, - dataset=MOCK_DATASET, + dataset=MOCK_RESOLVED_DATASET, time_period=MOCK_TIME_PERIOD, options=MOCK_OPTIONS, api_version=MOCK_API_VERSION, target="general", + model_version=MOCK_MODEL_VERSION, + policyengine_version=MOCK_POLICYENGINE_VERSION, + data_version=MOCK_DATA_VERSION, options_hash=MOCK_OPTIONS_HASH, ) @@ -427,7 +762,12 @@ def test__given_succeeded_state__returns_completed_result( ) assert result.status == ImpactStatus.OK - assert result.data == MOCK_REFORM_IMPACT_DATA + assert result.data["policyengine_bundle"] == { + "model_version": MOCK_MODEL_VERSION, + "policyengine_version": MOCK_POLICYENGINE_VERSION, + "data_version": MOCK_DATA_VERSION, + "dataset": MOCK_RESOLVED_DATASET, + } mock_reform_impacts_service.set_complete_reform_impact.assert_called_once() def test__given_failed_state__returns_error_result( @@ -493,7 +833,12 @@ def test__given_modal_complete_state__then_returns_completed_result( # Then assert result.status == ImpactStatus.OK - assert result.data == MOCK_REFORM_IMPACT_DATA + assert result.data["policyengine_bundle"] == { + "model_version": MOCK_MODEL_VERSION, + "policyengine_version": MOCK_POLICYENGINE_VERSION, + "data_version": MOCK_DATA_VERSION, + "dataset": MOCK_RESOLVED_DATASET, + } mock_reform_impacts_service.set_complete_reform_impact.assert_called_once() def test__given_modal_failed_state__then_returns_error_result( @@ -845,6 +1190,25 @@ def test__given_congressional_district__returns_correct_sim_options( assert sim_options["region"] == "congressional_district/CA-37" assert sim_options["data"] == "gs://policyengine-us-data/districts/CA-37.h5" + def test__given_explicit_dataset__returns_named_dataset(self): + service = EconomyService() + + sim_options_model = service._setup_sim_options( + self.test_country_id, + self.test_reform_policy, + self.test_current_law_baseline_policy, + self.test_region, + self.test_time_period, + self.test_scope, + dataset="enhanced_cps", + ) + + sim_options = sim_options_model.model_dump() + assert ( + sim_options["data"] + == "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0" + ) + class TestSetupRegion: """Tests for _setup_region method. @@ -1006,6 +1370,27 @@ def test__given_passthrough_test_dataset__returns_dataset_directly( ) assert result == "national-with-breakdowns-test" + def test__given_explicit_us_enhanced_cps__returns_named_dataset(self): + service = EconomyService() + result = service._setup_data("us", "us", dataset="enhanced_cps") + assert ( + result + == "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0" + ) + + def test__given_explicit_us_cps__returns_named_dataset(self): + service = EconomyService() + result = service._setup_data("us", "us", dataset="cps") + assert result == "hf://policyengine/policyengine-us-data/cps_2023.h5@1.77.0" + + def test__given_explicit_uk_enhanced_frs__returns_named_dataset(self): + service = EconomyService() + result = service._setup_data("uk", "uk", dataset="enhanced_frs") + assert ( + result + == "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3" + ) + def test__given_default_dataset__uses_get_default_dataset(self): # Test that "default" falls through to get_default_dataset service = EconomyService() diff --git a/tests/unit/test_country.py b/tests/unit/test_country.py index 55a1f7c70..358baca0c 100644 --- a/tests/unit/test_country.py +++ b/tests/unit/test_country.py @@ -1,8 +1,9 @@ import pytest import pandas as pd from pathlib import Path +from types import SimpleNamespace -from policyengine_api.country import COUNTRIES +from policyengine_api.country import COUNTRIES, PolicyEngineCountry class TestUKCountryMetadata: @@ -80,6 +81,12 @@ def test__uk_has_all_region_types(self, uk_regions): assert "constituency" in types assert "local_authority" in types + def test__uk_metadata_is_json_serializable(self, uk_country): + """Verify metadata does not leak filesystem paths or other non-JSON values.""" + import json + + json.dumps(uk_country.metadata) + class TestLocalAuthoritiesDataFile: """Tests for the local authorities CSV data file.""" @@ -142,3 +149,47 @@ def test__welsh_local_authorities_have_w_prefix(self, local_authorities_df): ] # Wales has 22 principal areas assert len(welsh_las) == 22 + + +class TestSimulationCompatibility: + def test__create_simulation_uses_legacy_tax_benefit_system_signature(self): + class LegacySimulation: + def __init__(self, *, tax_benefit_system, situation): + self.tax_benefit_system = tax_benefit_system + self.situation = situation + + country = PolicyEngineCountry.__new__(PolicyEngineCountry) + country.country_package = SimpleNamespace(Simulation=LegacySimulation) + country.tax_benefit_system = object() + + simulation, system = country._create_simulation( + {"households": {"household": {}}}, + None, + ) + + assert system is country.tax_benefit_system + assert simulation.tax_benefit_system is country.tax_benefit_system + assert simulation.situation == {"households": {"household": {}}} + + def test__create_simulation_uses_reform_signature_when_system_arg_is_unsupported( + self, + ): + class ModernSimulation: + def __init__(self, *, situation, reform=None): + self.situation = situation + self.reform = reform + self.tax_benefit_system = "modern-system" + + country = PolicyEngineCountry.__new__(PolicyEngineCountry) + country.country_package = SimpleNamespace(Simulation=ModernSimulation) + country.tax_benefit_system = object() + + simulation, system = country._create_simulation( + {"households": {"household": {}}}, + None, + ) + + assert system == "modern-system" + assert simulation.tax_benefit_system == "modern-system" + assert simulation.reform is None + assert simulation.situation == {"households": {"household": {}}}