diff --git a/changelog.d/budget-window-batch.fixed.md b/changelog.d/budget-window-batch.fixed.md new file mode 100644 index 000000000..9bd03006f --- /dev/null +++ b/changelog.d/budget-window-batch.fixed.md @@ -0,0 +1 @@ +Added a budget-window economy endpoint that batches yearly impact calculations with bounded server-side concurrency and returns aggregated progress plus totals. diff --git a/changelog.d/fix-silent-exception-swallowing.fixed.md b/changelog.d/fix-silent-exception-swallowing.fixed.md new file mode 100644 index 000000000..4b10062e5 --- /dev/null +++ b/changelog.d/fix-silent-exception-swallowing.fixed.md @@ -0,0 +1 @@ +Log exceptions instead of silently swallowing them during household calculations. diff --git a/policyengine_api/api.py b/policyengine_api/api.py index 112cce9ac..eb3eba9ee 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -4,6 +4,7 @@ import time import sys +import os start_time = time.time() @@ -157,8 +158,11 @@ def log_timing(message): app.register_blueprint(user_profile_bp) log_timing("User profile routes registered") -app.route("/simulations", methods=["GET"])(get_simulations) -log_timing("Simulations endpoint registered") +if os.environ.get("FLASK_DEBUG") == "1": + app.route("/simulations", methods=["GET"])(get_simulations) + log_timing("Simulations endpoint registered") +else: + log_timing("Simulations endpoint skipped outside debug mode") app.register_blueprint(tracer_analysis_bp) log_timing("Tracer analysis routes registered") diff --git a/policyengine_api/country.py b/policyengine_api/country.py index befa49851..0cc7f3806 100644 --- a/policyengine_api/country.py +++ b/policyengine_api/country.py @@ -1,4 +1,5 @@ import importlib +import logging from flask import Response import json from policyengine_core.taxbenefitsystems import TaxBenefitSystem @@ -445,11 +446,9 @@ def calculate( entity_result ) except Exception as e: - if "axes" in household: - pass - else: + logging.exception(f"Error computing {variable_name} for {entity_id}") + if "axes" not in household: household[entity_plural][entity_id][variable_name][period] = None - print(f"Error computing {variable_name} for {entity_id}: {e}") tracer_output = simulation.tracer.computation_log log_lines = tracer_output.lines(aggregate=False, max_depth=10) diff --git a/policyengine_api/data/__init__.py b/policyengine_api/data/__init__.py index 15673afdb..94703ee36 100644 --- a/policyengine_api/data/__init__.py +++ b/policyengine_api/data/__init__.py @@ -1 +1,6 @@ -from .data import PolicyEngineDatabase, database, local_database +from .data import ( + PolicyEngineDatabase, + database, + get_remote_database, + local_database, +) diff --git a/policyengine_api/data/data.py b/policyengine_api/data/data.py index 7dcb96c43..ad521e386 100644 --- a/policyengine_api/data/data.py +++ b/policyengine_api/data/data.py @@ -19,6 +19,7 @@ class _ResultProxy: Provides fetchone()/fetchall() with dict-like row access.""" def __init__(self, cursor_result): + self.rowcount = getattr(cursor_result, "rowcount", -1) try: # Use .mappings() so rows behave like dicts self._rows = list(cursor_result.mappings()) @@ -75,16 +76,20 @@ def _create_pool(self): with open(".dbpw") as f: db_pass = f.read().strip() db_name = "policyengine" - conn = self.connector.connect( - instance_connection_string=instance_connection_name, - driver="pymysql", - db=db_name, - user=db_user, - password=db_pass, - ) + + def get_connection(): + return self.connector.connect( + instance_connection_string=instance_connection_name, + driver="pymysql", + db=db_name, + user=db_user, + password=db_pass, + ) + self.pool = sqlalchemy.create_engine( "mysql+pymysql://", - creator=lambda: conn, + creator=get_connection, + pool_pre_ping=True, ) def _close_pool(self): @@ -194,3 +199,11 @@ def initialize(self): database = PolicyEngineDatabase(local=False, initialize=False) local_database = PolicyEngineDatabase(local=True, initialize=False) +remote_database = None + + +def get_remote_database() -> PolicyEngineDatabase: + global remote_database + if remote_database is None: + remote_database = PolicyEngineDatabase(local=False, initialize=False) + return remote_database diff --git a/policyengine_api/endpoints/simulation.py b/policyengine_api/endpoints/simulation.py index 132e5b2d6..be14e115f 100644 --- a/policyengine_api/endpoints/simulation.py +++ b/policyengine_api/endpoints/simulation.py @@ -1,4 +1,4 @@ -from policyengine_api.data import local_database +from policyengine_api.data import get_remote_database """ @@ -28,9 +28,13 @@ def get_simulations( desc_limit = f"DESC LIMIT {max_results}" if max_results is not None else "" - result = local_database.query( - f"SELECT * FROM reform_impact ORDER BY start_time {desc_limit}", - ).fetchall() + result = ( + get_remote_database() + .query( + f"SELECT * FROM reform_impact ORDER BY start_time {desc_limit}", + ) + .fetchall() + ) # Format into [{}] diff --git a/policyengine_api/openapi_spec.yaml b/policyengine_api/openapi_spec.yaml index a49268c8c..77daadc9e 100644 --- a/policyengine_api/openapi_spec.yaml +++ b/policyengine_api/openapi_spec.yaml @@ -660,6 +660,138 @@ paths: type: string message: type: string + /{country_id}/economy/{policy_id}/over/{baseline_policy_id}/budget-window: + get: + summary: Calculate budget-window economic impacts + operationId: get_budget_window_economic_impact + description: Calculate annual and total budget impacts for a policy over a multi-year budget window. + parameters: + - name: country_id + in: path + description: The country ID. + required: true + schema: + type: string + - name: policy_id + in: path + description: The reform policy ID. + required: true + schema: + type: string + - name: baseline_policy_id + in: path + description: The baseline policy ID. + required: true + schema: + type: string + - name: region + in: query + description: The sub-national region. + required: true + schema: + type: string + - name: start_year + in: query + description: First year in the budget window. + required: true + schema: + type: string + - name: window_size + in: query + description: Number of years to include in the budget window. + required: true + schema: + type: integer + - name: dataset + in: query + description: Dataset selection. + required: false + schema: + type: string + default: default + - name: version + in: query + description: API version number. + required: false + schema: + type: string + - name: include_district_breakdowns + in: query + description: Whether to include congressional district breakdowns for US national simulations. + required: false + schema: + type: boolean + default: false + - name: target + in: query + description: Impact target. Budget-window calculations only support general impacts. + required: false + schema: + type: string + default: general + responses: + 200: + description: Budget-window economic impact, progress, or error state. + content: + application/json: + schema: + type: object + properties: + status: + type: string + enum: + - ok + - computing + - error + message: + type: string + nullable: true + result: + type: object + nullable: true + progress: + type: integer + nullable: true + completed_years: + type: array + items: + type: string + computing_years: + type: array + items: + type: string + queued_years: + type: array + items: + type: string + error: + type: string + nullable: true + 400: + description: Invalid budget-window request. + content: + application/json: + schema: + type: object + properties: + status: + type: string + message: + type: string + result: + type: object + nullable: true + 404: + description: Invalid country ID. + content: + text/html: + schema: + type: object + properties: + status: + type: string + message: + type: string /{country_id}/analysis: post: summary: Get or trigger policy analysis diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index 4279a1b1b..9b71532da 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -2,6 +2,7 @@ from policyengine_api.services.economy_service import ( EconomyService, EconomicImpactResult, + BudgetWindowEconomicImpactResult, ) from policyengine_api.utils import get_current_law_policy_id from policyengine_api.utils.payload_validators import validate_country @@ -13,6 +14,25 @@ economy_service = EconomyService() +def _json_response(payload: dict, status: int = 200) -> Response: + return Response( + json.dumps(payload), + status=status, + mimetype="application/json", + ) + + +def _bad_request_response(message: str) -> Response: + return _json_response( + { + "status": "error", + "message": message, + "result": None, + }, + status=400, + ) + + @validate_country @economy_bp.route( "//economy//over/", @@ -57,14 +77,93 @@ def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int result_dict: dict[str, str | dict | None] = economic_impact_result.to_dict() - return Response( - json.dumps( - { - "status": result_dict["status"], - "message": None, - "result": result_dict["data"], - } - ), - status=200, - mimetype="application/json", + return _json_response( + { + "status": result_dict["status"], + "message": None, + "result": result_dict["data"], + } + ) + + +@validate_country +@economy_bp.route( + "//economy//over//budget-window", + methods=["GET"], +) +def get_budget_window_economic_impact( + country_id: str, policy_id: int, baseline_policy_id: int +): + policy_id = int(policy_id or get_current_law_policy_id(country_id)) + baseline_policy_id = int( + baseline_policy_id or get_current_law_policy_id(country_id) + ) + + query_parameters = request.args + options = dict(query_parameters) + options = json.loads(json.dumps(options)) + region = options.pop("region", None) + if not region: + return _bad_request_response("Missing required query parameter: region") + + dataset = options.pop("dataset", "default") + start_year = options.pop("start_year", None) + if not start_year: + return _bad_request_response("Missing required query parameter: start_year") + + window_size_raw = options.pop("window_size", None) + if window_size_raw is None: + return _bad_request_response("Missing required query parameter: window_size") + + try: + window_size = int(window_size_raw) + except (TypeError, ValueError): + return _bad_request_response("window_size must be an integer") + + include_district_breakdowns_raw = options.pop( + "include_district_breakdowns", "false" + ) + include_district_breakdowns = include_district_breakdowns_raw.lower() == "true" + if include_district_breakdowns and country_id == "us" and region == "us": + dataset = "national-with-breakdowns" + + target: Literal["general", "cliff"] = options.pop("target", "general") + if target != "general": + return _bad_request_response( + "Budget-window calculations only support target=general" + ) + + api_version = options.pop("version", COUNTRY_PACKAGE_VERSIONS.get(country_id)) + + try: + economic_impact_result: BudgetWindowEconomicImpactResult = ( + economy_service.get_budget_window_economic_impact( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + start_year=start_year, + window_size=window_size, + options=options, + api_version=api_version, + target=target, + ) + ) + except ValueError as error: + return _bad_request_response(str(error)) + + result_dict = economic_impact_result.to_dict() + + return _json_response( + { + "status": result_dict["status"], + "message": result_dict["message"], + "result": result_dict["data"], + "progress": result_dict["progress"], + "completed_years": result_dict["completed_years"], + "computing_years": result_dict["computing_years"], + "queued_years": result_dict["queued_years"], + "error": result_dict["error"], + } ) diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index 93256d778..1a38a394a 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -2,12 +2,100 @@ from werkzeug.exceptions import NotFound, BadRequest import json +from policyengine_api.constants import ( + CURRENT_YEAR, + get_economy_impact_cache_version, +) +from policyengine_api.services.reform_impacts_service import ReformImpactsService from policyengine_api.services.report_output_service import ReportOutputService -from policyengine_api.constants import CURRENT_YEAR +from policyengine_api.services.simulation_service import SimulationService from policyengine_api.utils.payload_validators import validate_country report_output_bp = Blueprint("report_output", __name__) report_output_service = ReportOutputService() +simulation_service = SimulationService() +reform_impacts_service = ReformImpactsService() + + +def _get_linked_simulation_or_raise(country_id: str, simulation_id: int) -> dict: + simulation = simulation_service.get_simulation(country_id, simulation_id) + if simulation is None: + raise BadRequest( + f"Report references simulation #{simulation_id}, but it could not be found for country {country_id}." + ) + return simulation + + +def _load_report_and_linked_simulations( + country_id: str, report_id: int +) -> tuple[dict, dict, dict | None]: + report_output = report_output_service.get_stored_report_output(report_id) + if report_output is None or report_output["country_id"] != country_id: + raise NotFound(f"Report #{report_id} not found.") + + simulation_1 = _get_linked_simulation_or_raise( + country_id=country_id, + simulation_id=report_output["simulation_1_id"], + ) + + simulation_2 = None + if report_output["simulation_2_id"] is not None: + simulation_2 = _get_linked_simulation_or_raise( + country_id=country_id, + simulation_id=report_output["simulation_2_id"], + ) + + if ( + simulation_2 is not None + and simulation_1["population_type"] != simulation_2["population_type"] + ): + raise BadRequest( + f"Report #{report_id} links simulations with mismatched population types." + ) + + return report_output, simulation_1, simulation_2 + + +def _reset_linked_simulations(country_id: str, *simulations: dict | None) -> list[int]: + reset_simulation_ids: list[int] = [] + seen_ids: set[int] = set() + + for simulation in simulations: + if simulation is None or simulation["id"] in seen_ids: + continue + simulation_service.reset_simulation(country_id, simulation["id"]) + seen_ids.add(simulation["id"]) + reset_simulation_ids.append(simulation["id"]) + + return reset_simulation_ids + + +def _delete_economy_cache_for_legacy_report_path( + country_id: str, + report_output: dict, + simulation_1: dict, + simulation_2: dict | None, +) -> int | None: + """ + Delete reform_impact rows using the current legacy app path assumptions: + dataset is always "default", options_hash is always "[]", and the report + year maps directly to the economy time period. This is correct for the + current app-generated legacy report flow, not arbitrary historical callers. + """ + return reform_impacts_service.delete_reform_impacts( + country_id=country_id, + policy_id=( + simulation_2["policy_id"] + if simulation_2 is not None + else simulation_1["policy_id"] + ), + baseline_policy_id=simulation_1["policy_id"], + region=simulation_1["population_id"], + dataset="default", + time_period=report_output["year"], + options_hash="[]", + api_version=get_economy_impact_cache_version(country_id), + ) @report_output_bp.route("//report", methods=["POST"]) @@ -197,3 +285,52 @@ def update_report_output(country_id: str) -> Response: except Exception as e: print(f"Error updating report output: {str(e)}") raise BadRequest(f"Failed to update report output: {str(e)}") + + +@report_output_bp.route("//report//rerun", methods=["POST"]) +@validate_country +def rerun_report_output(country_id: str, report_id: int) -> Response: + """ + Reset a legacy report output so the current app can recompute it. + + For economy reports this also purges reform_impact rows using the current + app-path assumptions about dataset/options provenance. + """ + print(f"Rerunning report output {report_id} for country {country_id}") + + report_output, simulation_1, simulation_2 = _load_report_and_linked_simulations( + country_id=country_id, + report_id=report_id, + ) + + report_output_service.reset_report_output(country_id, report_id) + reset_simulation_ids = _reset_linked_simulations( + country_id, simulation_1, simulation_2 + ) + + economy_cache_rows_deleted = 0 + if simulation_1["population_type"] == "geography": + deleted_rows = _delete_economy_cache_for_legacy_report_path( + country_id=country_id, + report_output=report_output, + simulation_1=simulation_1, + simulation_2=simulation_2, + ) + economy_cache_rows_deleted = deleted_rows or 0 + + response_body = dict( + status="ok", + message="Report rerun reset successfully", + result=dict( + report_id=report_id, + report_type=simulation_1["population_type"], + simulation_ids=reset_simulation_ids, + economy_cache_rows_deleted=economy_cache_rows_deleted, + ), + ) + + return Response( + json.dumps(response_body), + status=200, + mimetype="application/json", + ) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 95eae9838..165e3f03d 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -4,7 +4,6 @@ ) from policyengine_api.constants import ( COUNTRY_PACKAGE_VERSIONS, - REGION_PREFIXES, EXECUTION_STATUSES_SUCCESS, EXECUTION_STATUSES_FAILURE, EXECUTION_STATUSES_PENDING, @@ -25,9 +24,10 @@ import datetime from typing import Literal, Any, Optional, Annotated, Union from dotenv import load_dotenv -from pydantic import BaseModel +from pydantic import BaseModel, Field import numpy as np from enum import Enum +from concurrent.futures import ThreadPoolExecutor load_dotenv() @@ -44,6 +44,7 @@ class ImpactAction(Enum): COMPLETED = "completed" COMPUTING = "computing" CREATE = "create" + ERROR = "error" class ImpactStatus(Enum): @@ -58,6 +59,13 @@ class ImpactStatus(Enum): COMPLETE_STATUSES = [ImpactStatus.OK.value, ImpactStatus.ERROR.value] COMPUTING_STATUS = ImpactStatus.COMPUTING.value +BUDGET_WINDOW_MAX_ACTIVE_YEARS = 3 +BUDGET_WINDOW_MAX_YEARS = 20 +PENDING_EXECUTION_ID_PREFIX = "pending:" +PROVISIONAL_CLAIM_TTL_SECONDS = 90 +STALE_PROVISIONAL_IMPACT_MESSAGE = ( + "Simulation claim expired before job submission completed" +) class EconomicImpactSetupOptions(BaseModel): @@ -84,6 +92,7 @@ class EconomicImpactResult(BaseModel): status: ImpactStatus data: Optional[dict] = None + message: Optional[str] = None model_config = {"frozen": True} # Make model immutable @@ -116,7 +125,80 @@ def error(cls, message: str) -> "EconomicImpactResult": Create an EconomicImpactResult for an error in the impact calculation. """ logger.log_struct({"message": message}, severity="ERROR") - return cls(status=ImpactStatus.ERROR, data=None) + return cls(status=ImpactStatus.ERROR, data=None, message=message) + + +class BudgetWindowEconomicImpactResult(BaseModel): + """ + Model for a batch budget-window economic impact response. + """ + + status: ImpactStatus + data: Optional[dict] = None + progress: Optional[int] = None + completed_years: list[str] = Field(default_factory=list) + computing_years: list[str] = Field(default_factory=list) + queued_years: list[str] = Field(default_factory=list) + message: Optional[str] = None + error: Optional[str] = None + + model_config = {"frozen": True} + + def to_dict(self) -> dict[str, Any]: + return { + "status": self.status.value, + "data": self.data, + "progress": self.progress, + "completed_years": self.completed_years, + "computing_years": self.computing_years, + "queued_years": self.queued_years, + "message": self.message, + "error": self.error, + } + + @classmethod + def completed(cls, data: dict) -> "BudgetWindowEconomicImpactResult": + return cls(status=ImpactStatus.OK, data=data, progress=100) + + @classmethod + def computing( + cls, + *, + progress: int, + completed_years: list[str], + computing_years: list[str], + queued_years: list[str], + message: str, + ) -> "BudgetWindowEconomicImpactResult": + return cls( + status=ImpactStatus.COMPUTING, + data=None, + progress=progress, + completed_years=completed_years, + computing_years=computing_years, + queued_years=queued_years, + message=message, + ) + + @classmethod + def failed( + cls, + message: str, + *, + completed_years: Optional[list[str]] = None, + computing_years: Optional[list[str]] = None, + queued_years: Optional[list[str]] = None, + ) -> "BudgetWindowEconomicImpactResult": + logger.log_struct({"message": message}, severity="ERROR") + return cls( + status=ImpactStatus.ERROR, + data=None, + completed_years=completed_years or [], + computing_years=computing_years or [], + queued_years=queued_years or [], + message=message, + error=message, + ) class EconomyService: @@ -152,98 +234,497 @@ def get_economic_impact( # regions that don't contain a region prefix. if country_id == "us": region = normalize_us_region(region) + economic_impact_setup_options = self._build_economic_impact_setup_options( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=time_period, + options=options, + api_version=api_version, + target=target, + ) + + return self._get_or_create_economic_impact( + setup_options=economic_impact_setup_options + ) + + except Exception as e: + print(f"Error getting economic impact: {str(e)}") + raise e + + def get_budget_window_economic_impact( + self, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + start_year: str, + window_size: int, + options: dict, + api_version: str, + target: Literal["general", "cliff"] = "general", + max_active_years: int = BUDGET_WINDOW_MAX_ACTIVE_YEARS, + ) -> BudgetWindowEconomicImpactResult: + try: + if country_id == "us": + region = normalize_us_region(region) + + if target != "general": + raise ValueError( + "Budget-window calculations only support target='general'" + ) + + start_year_int = int(start_year) + if not 1 <= window_size <= BUDGET_WINDOW_MAX_YEARS: + raise ValueError( + f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}" + ) + + years = [str(start_year_int + index) for index in range(window_size)] + setup_options_by_year = { + year: self._build_economic_impact_setup_options( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=year, + options=dict(options), + api_version=api_version, + target=target, + ) + for year in years + } - # Set up logging - process_id: str = self._create_process_id() + completed_impacts: dict[str, dict] = {} + computing_years: list[str] = [] + queued_years: list[str] = [] - options_hash = ( - "[" + "&".join([f"{k}={v}" for k, v in options.items()]) + "]" + for year in years: + result = self._get_existing_economic_impact( + setup_options=setup_options_by_year[year] + ) + + if result is None: + queued_years.append(year) + continue + + if result.status == ImpactStatus.OK: + completed_impacts[year] = self._extract_budget_window_annual_impact( + year=year, impact_data=result.data or {} + ) + continue + + if result.status == ImpactStatus.COMPUTING: + computing_years.append(year) + continue + + completed_years = [ + completed_year + for completed_year in years + if completed_year in completed_impacts + ] + return BudgetWindowEconomicImpactResult.failed( + self._get_economic_impact_error_message( + result=result, + year=year, + ), + completed_years=completed_years, + computing_years=computing_years, + queued_years=queued_years, + ) + + available_slots = max(0, max_active_years - len(computing_years)) + years_to_start = queued_years[:available_slots] + remaining_queued_years = queued_years[available_slots:] + + if years_to_start: + max_workers = min(len(years_to_start), max_active_years) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_year_pairs = [ + ( + year, + executor.submit( + self._get_or_create_economic_impact, + setup_options_by_year[year], + ), + ) + for year in years_to_start + ] + + for year, future in future_year_pairs: + result = future.result() + + if result.status == ImpactStatus.OK: + completed_impacts[year] = ( + self._extract_budget_window_annual_impact( + year=year, impact_data=result.data or {} + ) + ) + elif result.status == ImpactStatus.COMPUTING: + computing_years.append(year) + else: + completed_years = [ + completed_year + for completed_year in years + if completed_year in completed_impacts + ] + return BudgetWindowEconomicImpactResult.failed( + self._get_economic_impact_error_message( + result=result, + year=year, + ), + completed_years=completed_years, + computing_years=computing_years, + queued_years=remaining_queued_years, + ) + + completed_years = [ + completed_year + for completed_year in years + if completed_year in completed_impacts + ] + + if len(completed_years) == len(years): + ordered_annual_impacts = [ + completed_impacts[year] + for year in years + if year in completed_impacts + ] + return BudgetWindowEconomicImpactResult.completed( + self._build_budget_window_output( + start_year=start_year, + window_size=window_size, + annual_impacts=ordered_annual_impacts, + ) + ) + + progress = round((len(completed_years) / len(years)) * 100) + return BudgetWindowEconomicImpactResult.computing( + progress=progress, + completed_years=completed_years, + computing_years=computing_years, + queued_years=remaining_queued_years, + message=self._build_budget_window_progress_message( + completed_years=completed_years, + total_years=len(years), + computing_years=computing_years, + queued_years=remaining_queued_years, + ), ) + except Exception as e: + print(f"Error getting budget-window economic impact: {str(e)}") + raise e - country_package_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) + def _build_economic_impact_setup_options( + self, + *, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + time_period: str, + options: dict, + api_version: str, + target: Literal["general", "cliff"] = "general", + ) -> EconomicImpactSetupOptions: + process_id: str = self._create_process_id() + options_hash = "[" + "&".join([f"{k}={v}" for k, v in options.items()]) + "]" + cache_version = get_economy_impact_cache_version(country_id, api_version) + + country_package_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) + if country_id == "uk": + country_package_version = None + + return EconomicImpactSetupOptions.model_validate( + { + "process_id": process_id, + "country_id": country_id, + "reform_policy_id": policy_id, + "baseline_policy_id": baseline_policy_id, + "region": region, + "dataset": dataset, + "time_period": time_period, + "options": options, + "api_version": cache_version, + "target": target, + "model_version": country_package_version, + "data_version": get_dataset_version(country_id), + "options_hash": options_hash, + } + ) + + def _get_or_create_economic_impact( + self, setup_options: EconomicImpactSetupOptions + ) -> EconomicImpactResult: + logger.log_struct( + { + "message": "Received request for economic impact; checking if already in reform_impacts table", + **setup_options.model_dump(), + }, + severity="INFO", + ) - if country_id == "uk": - country_package_version = None + most_recent_impact: dict | None = self._get_most_recent_impact( + setup_options=setup_options + ) - cache_version = get_economy_impact_cache_version(country_id, api_version) + impact_action: ImpactAction = self._determine_impact_action( + most_recent_impact=most_recent_impact + ) - economic_impact_setup_options = EconomicImpactSetupOptions.model_validate( + if impact_action == ImpactAction.COMPLETED: + logger.log_struct( { - "process_id": process_id, - "country_id": country_id, - "reform_policy_id": policy_id, - "baseline_policy_id": baseline_policy_id, - "region": region, - "dataset": dataset, - "time_period": time_period, - "options": options, - "api_version": cache_version, - "target": target, - "model_version": country_package_version, - "data_version": get_dataset_version(country_id), - "options_hash": options_hash, - } + "message": "Found completed economic impact in db; returning result", + **setup_options.model_dump(), + }, + severity="INFO", ) + return self._handle_completed_impact(most_recent_impact=most_recent_impact) - # Logging that we've received a request + if impact_action == ImpactAction.COMPUTING: logger.log_struct( { - "message": "Received request for economic impact; checking if already in reform_impacts table", - **economic_impact_setup_options.model_dump(), + "message": "Found computing economic impact record in db; confirming this is still computing", + **setup_options.model_dump(), }, severity="INFO", ) - - most_recent_impact: dict | None = self._get_most_recent_impact( - setup_options=economic_impact_setup_options, + return self._handle_computing_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, ) - impact_action: ImpactAction = self._determine_impact_action( + if impact_action == ImpactAction.ERROR: + logger.log_struct( + { + "message": "Found failed economic impact in db; returning error", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_error_impact( + setup_options=setup_options, most_recent_impact=most_recent_impact, ) - if impact_action == ImpactAction.COMPLETED: + if impact_action == ImpactAction.CREATE: + try: + with reform_impacts_service.claim_lock( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + api_version=setup_options.api_version, + ): + most_recent_impact = self._get_most_recent_impact( + setup_options=setup_options + ) + impact_action = self._determine_impact_action( + most_recent_impact=most_recent_impact + ) + + if impact_action == ImpactAction.COMPLETED: + logger.log_struct( + { + "message": "Found completed economic impact in db after locking; returning result", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_completed_impact( + most_recent_impact=most_recent_impact + ) + + if impact_action == ImpactAction.COMPUTING: + logger.log_struct( + { + "message": "Found computing economic impact in db after locking; returning progress", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_computing_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + if impact_action == ImpactAction.ERROR: + logger.log_struct( + { + "message": "Found failed economic impact in db after locking; returning error", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_error_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + stale_provisional_execution_id = None + if self._is_stale_provisional_impact(most_recent_impact): + stale_provisional_execution_id = most_recent_impact.get( + "execution_id" + ) + + provisional_execution_id = self._build_provisional_execution_id( + setup_options.process_id + ) + self._set_reform_impact_computing( + setup_options=setup_options, + execution_id=provisional_execution_id, + ) + if stale_provisional_execution_id: + self._expire_stale_provisional_impact( + setup_options=setup_options, + execution_id=stale_provisional_execution_id, + ) + except TimeoutError: logger.log_struct( { - "message": "Found completed economic impact in db; returning result", - **economic_impact_setup_options.model_dump(), + "message": "Timed out waiting for economic impact claim lock; re-checking existing claim", + **setup_options.model_dump(), }, - severity="INFO", + severity="WARNING", ) - return self._handle_completed_impact( - most_recent_impact=most_recent_impact, + existing_impact = self._get_existing_economic_impact( + setup_options=setup_options ) + if existing_impact is not None: + return existing_impact + return EconomicImpactResult.computing() - if impact_action == ImpactAction.COMPUTING: - logger.log_struct( - { - "message": "Found computing economic impact record in db; confirming this is still computing", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_computing_impact( - setup_options=economic_impact_setup_options, - most_recent_impact=most_recent_impact, - ) + logger.log_struct( + { + "message": "No previous economic impact record found in db; creating new simulation run", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_create_impact( + setup_options=setup_options, + provisional_execution_id=provisional_execution_id, + ) - if impact_action == ImpactAction.CREATE: - logger.log_struct( - { - "message": "No previous economic impact record found in db; creating new simulation run", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_create_impact( - setup_options=economic_impact_setup_options, - ) + raise ValueError(f"Unexpected impact action: {impact_action}") - raise ValueError(f"Unexpected impact action: {impact_action}") + def _get_existing_economic_impact( + self, setup_options: EconomicImpactSetupOptions + ) -> Optional[EconomicImpactResult]: + most_recent_impact = self._get_most_recent_impact(setup_options=setup_options) + if not most_recent_impact: + return None - except Exception as e: - print(f"Error getting economic impact: {str(e)}") - raise e + status = most_recent_impact.get("status") + if status == ImpactStatus.ERROR.value: + return self._handle_error_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + if status == ImpactStatus.OK.value: + return self._handle_completed_impact(most_recent_impact=most_recent_impact) + + if status == ImpactStatus.COMPUTING.value: + if self._is_stale_provisional_impact(most_recent_impact): + return None + return self._handle_computing_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + raise ValueError(f"Unknown impact status: {status}") + + def _get_economic_impact_error_message( + self, result: EconomicImpactResult, year: str + ) -> str: + if result.message: + return result.message + + if isinstance(result.data, dict): + data_message = result.data.get("message") + if isinstance(data_message, str) and data_message: + return data_message + + return f"Budget-window calculation failed for {year}" + + def _extract_budget_window_annual_impact( + self, year: str, impact_data: dict + ) -> dict[str, Union[str, int, float]]: + budget = impact_data.get("budget", {}) + state_tax_revenue_impact = budget.get("state_tax_revenue_impact", 0) + tax_revenue_impact = budget.get("tax_revenue_impact", 0) + + return { + "year": year, + "taxRevenueImpact": tax_revenue_impact, + "federalTaxRevenueImpact": tax_revenue_impact - state_tax_revenue_impact, + "stateTaxRevenueImpact": state_tax_revenue_impact, + "benefitSpendingImpact": budget.get("benefit_spending_impact", 0), + "budgetaryImpact": budget.get("budgetary_impact", 0), + } + + def _sum_budget_window_annual_impacts(self, annual_impacts: list[dict]) -> dict: + totals = { + "year": "Total", + "taxRevenueImpact": 0, + "federalTaxRevenueImpact": 0, + "stateTaxRevenueImpact": 0, + "benefitSpendingImpact": 0, + "budgetaryImpact": 0, + } + + for annual_impact in annual_impacts: + totals["taxRevenueImpact"] += annual_impact["taxRevenueImpact"] + totals["federalTaxRevenueImpact"] += annual_impact[ + "federalTaxRevenueImpact" + ] + totals["stateTaxRevenueImpact"] += annual_impact["stateTaxRevenueImpact"] + totals["benefitSpendingImpact"] += annual_impact["benefitSpendingImpact"] + totals["budgetaryImpact"] += annual_impact["budgetaryImpact"] + + return totals + + def _build_budget_window_output( + self, *, start_year: str, window_size: int, annual_impacts: list[dict] + ) -> dict: + return { + "kind": "budgetWindow", + "startYear": start_year, + "endYear": str(int(start_year) + window_size - 1), + "windowSize": window_size, + "annualImpacts": annual_impacts, + "totals": self._sum_budget_window_annual_impacts(annual_impacts), + } + + def _build_budget_window_progress_message( + self, + *, + completed_years: list[str], + total_years: int, + computing_years: list[str], + queued_years: list[str], + ) -> str: + completed_count = len(completed_years) + if computing_years: + active_years = ", ".join(computing_years[:2]) + if len(computing_years) > 2: + active_years = f"{active_years} + {len(computing_years) - 2} more" + return f"Scoring {active_years} ({completed_count} of {total_years} complete)..." + + if queued_years: + return f"Queued {queued_years[0]} ({completed_count} of {total_years} complete)..." + + return f"Scoring budget window ({completed_count} of {total_years} complete)..." def _get_previous_impacts( self, @@ -297,6 +778,62 @@ def _get_most_recent_impact( return None + def _build_provisional_execution_id(self, process_id: str) -> str: + return f"{PENDING_EXECUTION_ID_PREFIX}{process_id}" + + def _is_provisional_execution_id(self, execution_id: Any) -> bool: + return isinstance(execution_id, str) and execution_id.startswith( + PENDING_EXECUTION_ID_PREFIX + ) + + def _coerce_impact_start_time(self, start_time: Any) -> Optional[datetime.datetime]: + if start_time is None: + return None + + if isinstance(start_time, str): + parsed_start_time = datetime.datetime.fromisoformat(start_time) + elif hasattr(start_time, "tzinfo") and hasattr(start_time, "isoformat"): + parsed_start_time = start_time + else: + return None + + if parsed_start_time.tzinfo is None: + return parsed_start_time.replace(tzinfo=datetime.timezone.utc) + + return parsed_start_time.astimezone(datetime.timezone.utc) + + def _is_stale_provisional_impact(self, impact: dict | None) -> bool: + if not impact: + return False + + if not self._is_provisional_execution_id(impact.get("execution_id")): + return False + + start_time = self._coerce_impact_start_time(impact.get("start_time")) + if start_time is None: + return False + + current_time = datetime.datetime.now(datetime.timezone.utc) + if current_time.tzinfo is None: + current_time = current_time.replace(tzinfo=datetime.timezone.utc) + + claim_age = current_time - start_time + return claim_age.total_seconds() > PROVISIONAL_CLAIM_TTL_SECONDS + + def _expire_stale_provisional_impact( + self, + setup_options: EconomicImpactSetupOptions, + execution_id: str, + ) -> None: + if not self._is_provisional_execution_id(execution_id): + return + + self._set_reform_impact_error( + setup_options=setup_options, + message=STALE_PROVISIONAL_IMPACT_MESSAGE, + execution_id=execution_id, + ) + def _determine_impact_action( self, most_recent_impact: dict | None, @@ -306,9 +843,13 @@ def _determine_impact_action( return ImpactAction.CREATE status = most_recent_impact.get("status") - if status in [ImpactStatus.OK.value, ImpactStatus.ERROR.value]: + if status == ImpactStatus.OK.value: return ImpactAction.COMPLETED + elif status == ImpactStatus.ERROR.value: + return ImpactAction.ERROR elif status == ImpactStatus.COMPUTING.value: + if self._is_stale_provisional_impact(most_recent_impact): + return ImpactAction.CREATE return ImpactAction.COMPUTING else: raise ValueError(f"Unknown impact status: {status}") @@ -379,15 +920,30 @@ def _handle_completed_impact( data=json.loads(most_recent_impact["reform_impact_json"]) ) + def _handle_error_impact( + self, + setup_options: EconomicImpactSetupOptions, + most_recent_impact: dict, + ) -> EconomicImpactResult: + error_message = most_recent_impact.get("message") or ( + f"Economic impact failed for {setup_options.time_period}" + ) + return EconomicImpactResult( + status=ImpactStatus.ERROR, + data=None, + message=error_message, + ) + def _handle_computing_impact( self, setup_options: EconomicImpactSetupOptions, most_recent_impact: dict, ) -> EconomicImpactResult: + execution_id = most_recent_impact["execution_id"] + if self._is_provisional_execution_id(execution_id): + return EconomicImpactResult.computing() - execution = simulation_api.get_execution_by_id( - most_recent_impact["execution_id"] - ) + execution = simulation_api.get_execution_by_id(execution_id) execution_state = simulation_api.get_execution_status(execution) return self._handle_execution_state( execution_state=execution_state, @@ -399,47 +955,57 @@ def _handle_computing_impact( def _handle_create_impact( self, setup_options: EconomicImpactSetupOptions, + provisional_execution_id: str, ) -> EconomicImpactResult: - baseline_policy = policy_service.get_policy_json( - setup_options.country_id, setup_options.baseline_policy_id - ) - reform_policy = policy_service.get_policy_json( - setup_options.country_id, setup_options.reform_policy_id - ) + try: + baseline_policy = policy_service.get_policy_json( + setup_options.country_id, setup_options.baseline_policy_id + ) + reform_policy = policy_service.get_policy_json( + setup_options.country_id, setup_options.reform_policy_id + ) - sim_config: SimulationOptions = self._setup_sim_options( - country_id=setup_options.country_id, - reform_policy=reform_policy, - baseline_policy=baseline_policy, - region=setup_options.region, - time_period=setup_options.time_period, - dataset=setup_options.dataset, - scope="macro", - include_cliffs=setup_options.target == "cliff", - model_version=setup_options.model_version, - data_version=setup_options.data_version, - ) + sim_config: SimulationOptions = self._setup_sim_options( + country_id=setup_options.country_id, + reform_policy=reform_policy, + baseline_policy=baseline_policy, + region=setup_options.region, + time_period=setup_options.time_period, + dataset=setup_options.dataset, + scope="macro", + include_cliffs=setup_options.target == "cliff", + model_version=setup_options.model_version, + data_version=setup_options.data_version, + ) - logger.log_struct( - { - "message": "Setting up sim API job", - **setup_options.model_dump(), - } - ) + logger.log_struct( + { + "message": "Setting up sim API job", + **setup_options.model_dump(), + } + ) - # Build params with metadata for Logfire tracing in the simulation API. - # The _metadata field will be captured by the Logfire span before - # SimulationOptions validation (which silently ignores extra fields). - sim_params = sim_config.model_dump() - sim_params["_metadata"] = { - "reform_policy_id": setup_options.reform_policy_id, - "baseline_policy_id": setup_options.baseline_policy_id, - "process_id": setup_options.process_id, - } + # Build params with metadata for Logfire tracing in the simulation API. + # The _metadata field will be captured by the Logfire span before + # SimulationOptions validation (which silently ignores extra fields). + sim_params = sim_config.model_dump() + sim_params["_metadata"] = { + "reform_policy_id": setup_options.reform_policy_id, + "baseline_policy_id": setup_options.baseline_policy_id, + "process_id": setup_options.process_id, + } - sim_api_execution = simulation_api.run(sim_params) - execution_id = simulation_api.get_execution_id(sim_api_execution) + sim_api_execution = simulation_api.run(sim_params) + execution_id = simulation_api.get_execution_id(sim_api_execution) + except Exception as error: + error_message = f"Failed to start simulation API job: {str(error)}" + self._set_reform_impact_error( + setup_options=setup_options, + message=error_message, + execution_id=provisional_execution_id, + ) + return EconomicImpactResult.error(message=error_message) progress_log = { **setup_options.model_dump(), @@ -448,13 +1014,117 @@ def _handle_create_impact( } logger.log_struct(progress_log, severity="INFO") - self._set_reform_impact_computing( - setup_options=setup_options, - execution_id=execution_id, - ) + try: + updated_rows = self._update_reform_impact_execution_id( + setup_options=setup_options, + current_execution_id=provisional_execution_id, + new_execution_id=execution_id, + ) + except Exception as error: + logger.log_struct( + { + "message": "Failed to promote provisional reform impact row; inserting replacement tracking row", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "error": str(error), + }, + severity="WARNING", + ) + updated_rows = 0 + + if updated_rows != 1: + self._recover_failed_execution_id_promotion( + setup_options=setup_options, + provisional_execution_id=provisional_execution_id, + execution_id=execution_id, + updated_rows=updated_rows, + ) return EconomicImpactResult.computing() + def _recover_failed_execution_id_promotion( + self, + *, + setup_options: EconomicImpactSetupOptions, + provisional_execution_id: str, + execution_id: str, + updated_rows: int | None, + ) -> None: + logger.log_struct( + { + "message": "Provisional reform impact row was not updated; checking whether tracking has already been superseded", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "updated_rows": updated_rows, + }, + severity="WARNING", + ) + + try: + with reform_impacts_service.claim_lock( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + api_version=setup_options.api_version, + ): + most_recent_impact = self._get_most_recent_impact( + setup_options=setup_options + ) + if most_recent_impact is not None: + impact_status = most_recent_impact.get("status") + tracked_execution_id = most_recent_impact.get("execution_id") + if tracked_execution_id == execution_id: + return + + if ( + impact_status == ImpactStatus.COMPUTING.value + and tracked_execution_id == provisional_execution_id + ): + retry_updated_rows = self._update_reform_impact_execution_id( + setup_options=setup_options, + current_execution_id=provisional_execution_id, + new_execution_id=execution_id, + ) + if retry_updated_rows == 1: + return + elif impact_status in ( + ImpactStatus.OK.value, + ImpactStatus.COMPUTING.value, + ): + logger.log_struct( + { + "message": "Skipping replacement tracking row because another claim is already authoritative", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "tracked_execution_id": tracked_execution_id, + "tracked_status": impact_status, + }, + severity="WARNING", + ) + return + + self._set_reform_impact_computing( + setup_options=setup_options, + execution_id=execution_id, + ) + except TimeoutError: + logger.log_struct( + { + "message": "Timed out while recovering failed provisional promotion; leaving the newer claim authoritative", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + }, + severity="WARNING", + ) + def _setup_sim_options( self, country_id: str, @@ -568,6 +1238,9 @@ def _set_reform_impact_computing( In the reform_impact table, set the status of the impact to "computing". """ try: + start_time = datetime.datetime.now(datetime.timezone.utc).replace( + tzinfo=None + ) reform_impacts_service.set_reform_impact( country_id=setup_options.country_id, policy_id=setup_options.reform_policy_id, @@ -580,7 +1253,7 @@ def _set_reform_impact_computing( status=ImpactStatus.COMPUTING.value, api_version=setup_options.api_version, reform_impact_json=json.dumps({}), - start_time=datetime.datetime.now(), + start_time=start_time, execution_id=execution_id, ) except Exception as e: @@ -592,6 +1265,33 @@ def _set_reform_impact_computing( ) raise e + def _update_reform_impact_execution_id( + self, + setup_options: EconomicImpactSetupOptions, + current_execution_id: str, + new_execution_id: str, + ) -> int | None: + try: + return reform_impacts_service.update_reform_impact_execution_id( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + current_execution_id=current_execution_id, + new_execution_id=new_execution_id, + ) + except Exception as e: + logger.log_struct( + { + "message": f"Error updating reform impact execution id: {str(e)}", + **setup_options.model_dump(), + } + ) + raise e + def _set_reform_impact_complete( self, setup_options: EconomicImpactSetupOptions, diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index ca44ea10c..111e989aa 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -1,14 +1,130 @@ -from policyengine_api.data import local_database +from contextlib import contextmanager +import hashlib +from threading import Lock +from policyengine_api.data import database import datetime +LOCAL_REFORM_IMPACT_LOCK = Lock() +REFORM_IMPACT_SCHEMA_LOCK = Lock() +REFORM_IMPACT_LOCK_TIMEOUT_SECONDS = 5 + + class ReformImpactsService: """ Service for storing and retrieving economy-wide reform impacts; - this is connected to the locally-stored reform_impact table - and no existing route + this is connected to the shared reform_impact table. """ + def __init__(self): + self._schema_checked = False + + def _ensure_remote_schema(self) -> None: + if database.local or self._schema_checked: + return + + with REFORM_IMPACT_SCHEMA_LOCK: + if self._schema_checked: + return + + existing_columns = { + row["Field"] + for row in database.query("SHOW COLUMNS FROM reform_impact").fetchall() + } + required_columns = { + "dataset": ( + "ALTER TABLE reform_impact " + "ADD COLUMN dataset VARCHAR(255) NOT NULL DEFAULT 'default'" + ), + "execution_id": ( + "ALTER TABLE reform_impact " + "ADD COLUMN execution_id VARCHAR(255) NULL" + ), + "end_time": ( + "ALTER TABLE reform_impact ADD COLUMN end_time DATETIME NULL" + ), + } + + for column_name, alter_query in required_columns.items(): + if column_name in existing_columns: + continue + try: + database.query(alter_query) + except Exception as error: + if "Duplicate column name" not in str(error): + raise + + self._schema_checked = True + + def _build_lock_name( + self, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, + ) -> str: + raw_key = ( + f"{country_id}:{policy_id}:{baseline_policy_id}:{region}:{dataset}:" + f"{time_period}:{options_hash}:{api_version}" + ) + digest = hashlib.sha256(raw_key.encode("utf-8")).hexdigest() + return f"ri:{digest[:61]}" + + @contextmanager + def claim_lock( + self, + *, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, + timeout_seconds: int = REFORM_IMPACT_LOCK_TIMEOUT_SECONDS, + ): + if database.local: + with LOCAL_REFORM_IMPACT_LOCK: + yield + return + + lock_name = self._build_lock_name( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=time_period, + options_hash=options_hash, + api_version=api_version, + ) + with database.pool.connect() as conn: + acquired = ( + conn.exec_driver_sql( + "SELECT GET_LOCK(%s, %s) AS acquired", + (lock_name, timeout_seconds), + ) + .mappings() + .first() + ) + if acquired is None or acquired["acquired"] != 1: + raise TimeoutError( + f"Could not acquire reform impact lock for {country_id}/{policy_id}/{time_period}" + ) + + try: + yield + finally: + conn.exec_driver_sql( + "SELECT RELEASE_LOCK(%s) AS released", (lock_name,) + ) + conn.commit() + def get_all_reform_impacts( self, country_id, @@ -21,13 +137,15 @@ def get_all_reform_impacts( api_version, ): try: + self._ensure_remote_schema() query = ( "SELECT reform_impact_json, status, message, start_time, execution_id FROM " "reform_impact WHERE country_id = ? AND reform_policy_id = ? AND " "baseline_policy_id = ? AND region = ? AND time_period = ? AND " - "options_hash = ? AND api_version = ? AND dataset = ?" + "options_hash = ? AND api_version = ? AND dataset = ? " + "ORDER BY start_time DESC, reform_impact_id DESC" ) - return local_database.query( + return database.query( query, ( country_id, @@ -61,12 +179,13 @@ def set_reform_impact( execution_id: str, ): try: + self._ensure_remote_schema() query = ( "INSERT INTO reform_impact (country_id, reform_policy_id, baseline_policy_id, " "region, dataset, time_period, options_json, options_hash, status, api_version, " "reform_impact_json, start_time, execution_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" ) - local_database.query( + database.query( query, ( country_id, @@ -88,7 +207,7 @@ def set_reform_impact( print(f"Error setting reform impact: {str(e)}") raise e - def delete_reform_impact( + def update_reform_impact_execution_id( self, country_id, policy_id, @@ -97,18 +216,21 @@ def delete_reform_impact( dataset, time_period, options_hash, + current_execution_id, + new_execution_id, ): try: + self._ensure_remote_schema() query = ( - "DELETE FROM reform_impact WHERE country_id = ? AND " - "reform_policy_id = ? AND baseline_policy_id = ? AND " - "region = ? AND time_period = ? AND options_hash = ? AND " - "dataset = ? AND status = 'computing'" + "UPDATE reform_impact SET execution_id = ? WHERE country_id = ? AND " + "reform_policy_id = ? AND baseline_policy_id = ? AND region = ? AND " + "time_period = ? AND options_hash = ? AND dataset = ? AND " + "execution_id = ? AND status = 'computing'" ) - - local_database.query( + result = database.query( query, ( + new_execution_id, country_id, policy_id, baseline_policy_id, @@ -116,10 +238,78 @@ def delete_reform_impact( time_period, options_hash, dataset, + current_execution_id, ), ) + return getattr(result, "rowcount", None) + except Exception as e: + print(f"Error updating reform impact execution id: {str(e)}") + raise e + + def delete_reform_impact( + self, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + ): + return self.delete_reform_impacts( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=time_period, + options_hash=options_hash, + statuses=("computing",), + ) + + def delete_reform_impacts( + self, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version=None, + statuses=None, + ): + try: + self._ensure_remote_schema() + query = [ + "DELETE FROM reform_impact WHERE country_id = ? AND " + "reform_policy_id = ? AND baseline_policy_id = ? AND " + "region = ? AND time_period = ? AND options_hash = ? AND " + "dataset = ?" + ] + params = [ + country_id, + policy_id, + baseline_policy_id, + region, + time_period, + options_hash, + dataset, + ] + + if api_version is not None: + query.append(" AND api_version = ?") + params.append(api_version) + + if statuses: + placeholders = ", ".join(["?"] * len(statuses)) + query.append(f" AND status IN ({placeholders})") + params.extend(statuses) + + result = database.query("".join(query), tuple(params)) + return getattr(result, "rowcount", None) except Exception as e: - print(f"Error deleting reform impact: {str(e)}") + print(f"Error deleting reform impacts: {str(e)}") raise e def set_error_reform_impact( @@ -135,13 +325,14 @@ def set_error_reform_impact( execution_id: str, ): try: + self._ensure_remote_schema() query = ( "UPDATE reform_impact SET status = ?, message = ?, end_time = ? WHERE " "country_id = ? AND reform_policy_id = ? AND baseline_policy_id = ? AND " "region = ? AND time_period = ? AND options_hash = ? AND dataset = ? AND " "execution_id = ?" ) - local_database.query( + database.query( query, ( "error", @@ -179,13 +370,14 @@ def set_complete_reform_impact( execution_id, ): try: + self._ensure_remote_schema() query = ( "UPDATE reform_impact SET status = ?, message = ?, end_time = ?, " "reform_impact_json = ? WHERE country_id = ? AND reform_policy_id = ? AND " "baseline_policy_id = ? AND region = ? AND time_period = ? AND " "options_hash = ? AND dataset = ? AND execution_id = ?" ) - local_database.query( + database.query( query, ( "ok", diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 3200ec6e8..7d22d9a73 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -262,3 +262,34 @@ def update_report_output( except Exception as e: print(f"Error updating report output #{report_id}. Details: {str(e)}") raise e + + def reset_report_output(self, country_id: str, report_id: int) -> bool: + """ + Reset a stored report output row back to a pending state. + + This is intentionally separate from update_report_output so rerun paths + can clear persisted output and errors without changing PATCH semantics. + """ + print(f"Resetting report output {report_id}") + + try: + requested_report = self._get_report_output_row(report_id) + if requested_report is None: + raise Exception(f"Report output #{report_id} not found") + + if requested_report["country_id"] != country_id: + raise Exception( + f"Report output #{report_id} does not belong to country {country_id}" + ) + + database.query( + "UPDATE report_outputs SET status = ?, output = NULL, error_message = NULL WHERE id = ?", + ("pending", requested_report["id"]), + ) + + print(f"Successfully reset report output #{report_id}") + return True + + except Exception as e: + print(f"Error resetting report output #{report_id}. Details: {str(e)}") + raise e diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index 7b83689e5..606c979fc 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -193,3 +193,30 @@ def update_simulation( except Exception as e: print(f"Error updating simulation #{simulation_id}. Details: {str(e)}") raise e + + def reset_simulation(self, country_id: str, simulation_id: int) -> bool: + """ + Reset a simulation row back to a pending state and clear persisted + output and errors. + """ + print(f"Resetting simulation {simulation_id}") + api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) + + try: + simulation = self.get_simulation( + country_id=country_id, simulation_id=simulation_id + ) + if simulation is None: + raise Exception(f"Simulation #{simulation_id} not found") + + database.query( + "UPDATE simulations SET status = ?, output = NULL, error_message = NULL, api_version = ? WHERE id = ?", + ("pending", api_version, simulation_id), + ) + + print(f"Successfully reset simulation #{simulation_id}") + return True + + except Exception as e: + print(f"Error resetting simulation #{simulation_id}. Details: {str(e)}") + raise e diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index 687a82a48..14c566772 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch, MagicMock import json import datetime +from contextlib import nullcontext from policyengine_api.constants import ( MODAL_EXECUTION_STATUS_SUBMITTED, @@ -91,8 +92,10 @@ def mock_reform_impacts_service(): mock_service = MagicMock() mock_service.get_all_reform_impacts.return_value = [] mock_service.set_reform_impact.return_value = None + mock_service.update_reform_impact_execution_id.return_value = 1 mock_service.set_complete_reform_impact.return_value = None mock_service.set_error_reform_impact.return_value = None + mock_service.claim_lock.side_effect = lambda **kwargs: nullcontext() with patch( "policyengine_api.services.economy_service.reform_impacts_service", @@ -147,7 +150,10 @@ def mock_numpy_random(): def create_mock_reform_impact( - status="ok", reform_impact_json=None, execution_id=MOCK_MODAL_JOB_ID + status="ok", + reform_impact_json=None, + execution_id=MOCK_MODAL_JOB_ID, + start_time=None, ): """Helper function to create mock reform impact records.""" return { @@ -163,7 +169,7 @@ def create_mock_reform_impact( "api_version": MOCK_API_VERSION, "reform_impact_json": reform_impact_json or json.dumps(MOCK_REFORM_IMPACT_DATA), "execution_id": execution_id, - "start_time": datetime.datetime(2025, 6, 26, 12, 0, 0), + "start_time": start_time or datetime.datetime(2025, 6, 26, 12, 0, 0), "end_time": ( datetime.datetime(2025, 6, 26, 12, 5, 0) if status == "ok" else None ), diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py new file mode 100644 index 000000000..fca938948 --- /dev/null +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -0,0 +1,120 @@ +import json +from unittest.mock import Mock, patch + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_rejects_cliff_target( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=10&target=cliff" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "target=general" in data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_window_size( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window?region=us&start_year=2026" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size" in data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_integer_window_size( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=abc" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size must be an integer" == data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +def test_budget_window_route_rejects_oversized_window(rest_client): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=999" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size must be between 1 and" in data["message"] + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_passes_version_to_service( + mock_get_budget_window_economic_impact, rest_client +): + mock_result = Mock() + mock_result.to_dict.return_value = { + "status": "ok", + "message": None, + "data": { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2027", + "windowSize": 2, + "annualImpacts": [], + "totals": {}, + }, + "progress": 100, + "completed_years": ["2026", "2027"], + "computing_years": [], + "queued_years": [], + "error": None, + } + mock_get_budget_window_economic_impact.return_value = mock_result + + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=2&version=1.2.3" + ) + + data = json.loads(response.data) + + assert response.status_code == 200 + assert data["status"] == "ok" + mock_get_budget_window_economic_impact.assert_called_once_with( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="default", + start_year="2026", + window_size=2, + options={}, + api_version="1.2.3", + target="general", + ) diff --git a/tests/to_refactor/python/test_us_policy_macro.py b/tests/to_refactor/python/test_us_policy_macro.py index f4228523c..2b3499993 100644 --- a/tests/to_refactor/python/test_us_policy_macro.py +++ b/tests/to_refactor/python/test_us_policy_macro.py @@ -47,7 +47,11 @@ def utah_reform_runner(rest_client, region: str = "us"): policy_id = policy_create.json["result"]["policy_id"] assert policy_id is not None - query = f"/us/economy/{policy_id}/over/{default_policy}?region={region}&time_period={test_year}" + cache_buster = int(time.time() * 1000) + query = ( + f"/us/economy/{policy_id}/over/{default_policy}" + f"?region={region}&time_period={test_year}&test_run={cache_buster}" + ) economy_response = rest_client.get(query) assert economy_response.status_code == 200 assert economy_response.json["status"] == "computing", ( diff --git a/tests/unit/data/test_sqlalchemy_v2.py b/tests/unit/data/test_sqlalchemy_v2.py index 3882bb0f7..2ea63f0f0 100644 --- a/tests/unit/data/test_sqlalchemy_v2.py +++ b/tests/unit/data/test_sqlalchemy_v2.py @@ -12,6 +12,7 @@ import pytest import sqlalchemy +from unittest.mock import MagicMock from policyengine_api.data.data import _ResultProxy, PolicyEngineDatabase @@ -180,3 +181,34 @@ def test_remote_delete(self): db._execute_remote(["DELETE FROM test_table WHERE id = ?", (1,)]) result = db._execute_remote(["SELECT * FROM test_table WHERE id = ?", (1,)]) assert result.fetchone() is None + + +class TestRemotePoolCreation: + def test_create_pool_uses_fresh_connection_creator(self, monkeypatch): + first_connection = MagicMock(name="first_connection") + second_connection = MagicMock(name="second_connection") + mock_connector = MagicMock() + mock_connector.connect.side_effect = [first_connection, second_connection] + + captured_kwargs = {} + + def fake_create_engine(url, **kwargs): + captured_kwargs.update(kwargs) + return MagicMock() + + monkeypatch.setenv("POLICYENGINE_DB_PASSWORD", "test-password") + monkeypatch.setattr( + "policyengine_api.data.data.Connector", lambda: mock_connector + ) + monkeypatch.setattr( + "policyengine_api.data.data.sqlalchemy.create_engine", + fake_create_engine, + ) + + db = PolicyEngineDatabase.__new__(PolicyEngineDatabase) + db._create_pool() + + creator = captured_kwargs["creator"] + assert creator() is first_connection + assert creator() is second_connection + assert captured_kwargs["pool_pre_ping"] is True diff --git a/tests/unit/endpoints/test_simulation.py b/tests/unit/endpoints/test_simulation.py new file mode 100644 index 000000000..a9013a056 --- /dev/null +++ b/tests/unit/endpoints/test_simulation.py @@ -0,0 +1,19 @@ +from unittest.mock import MagicMock, patch + +from policyengine_api.endpoints.simulation import get_simulations + + +def test_get_simulations_reads_from_remote_database(): + mock_database = MagicMock() + mock_database.query.return_value.fetchall.return_value = [{"id": 1}] + + with patch( + "policyengine_api.endpoints.simulation.get_remote_database", + return_value=mock_database, + ): + result = get_simulations() + + mock_database.query.assert_called_once_with( + "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT 100", + ) + assert result == {"result": [{"id": 1}]} diff --git a/tests/unit/routes/test_report_output_routes.py b/tests/unit/routes/test_report_output_routes.py new file mode 100644 index 000000000..c194373f8 --- /dev/null +++ b/tests/unit/routes/test_report_output_routes.py @@ -0,0 +1,427 @@ +import pytest +from flask import Flask + +from policyengine_api.constants import ( + get_economy_impact_cache_version, + get_report_output_cache_version, +) +from policyengine_api.routes.error_routes import error_bp +from policyengine_api.routes.report_output_routes import report_output_bp + + +@pytest.fixture +def client(): + app = Flask(__name__) + app.config["TESTING"] = True + app.register_blueprint(error_bp) + app.register_blueprint(report_output_bp) + + with app.test_client() as test_client: + yield test_client + + +def insert_simulation( + test_db, + *, + country_id="us", + api_version="0.0.0", + population_id="household_1", + population_type="household", + policy_id=1, + status="complete", + output='{"result": true}', + error_message="old error", +): + test_db.query( + """INSERT INTO simulations + (country_id, api_version, population_id, population_type, policy_id, status, output, error_message) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + country_id, + api_version, + population_id, + population_type, + policy_id, + status, + output, + error_message, + ), + ) + return test_db.query("SELECT * FROM simulations ORDER BY id DESC LIMIT 1").fetchone() + + +def insert_report_output( + test_db, + *, + country_id="us", + simulation_1_id, + simulation_2_id=None, + status="complete", + output='{"report": true}', + error_message="old error", + year="2025", +): + test_db.query( + """INSERT INTO report_outputs + (country_id, simulation_1_id, simulation_2_id, api_version, status, output, error_message, year) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + country_id, + simulation_1_id, + simulation_2_id, + get_report_output_cache_version(country_id), + status, + output, + error_message, + year, + ), + ) + return test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + +def insert_reform_impact( + test_db, + *, + baseline_policy_id, + reform_policy_id, + country_id="us", + region="us", + dataset="default", + time_period="2025", + options_json="[]", + options_hash="[]", + api_version=None, + reform_impact_json='{"impact": 1}', + status="ok", + message="Completed", + execution_id="exec-1", +): + if api_version is None: + api_version = get_economy_impact_cache_version(country_id) + + test_db.query( + """INSERT INTO reform_impact + (baseline_policy_id, reform_policy_id, country_id, region, dataset, time_period, + options_json, options_hash, api_version, reform_impact_json, status, message, start_time, + execution_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, ?)""", + ( + baseline_policy_id, + reform_policy_id, + country_id, + region, + dataset, + time_period, + options_json, + options_hash, + api_version, + reform_impact_json, + status, + message, + execution_id, + ), + ) + + +def test_rerun_report_output_resets_household_report_and_simulation(client, test_db): + simulation = insert_simulation(test_db) + report_output = insert_report_output(test_db, simulation_1_id=simulation["id"]) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["status"] == "ok" + assert payload["result"] == { + "report_id": report_output["id"], + "report_type": "household", + "simulation_ids": [simulation["id"]], + "economy_cache_rows_deleted": 0, + } + + reset_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", (report_output["id"],) + ).fetchone() + assert reset_report["status"] == "pending" + assert reset_report["output"] is None + assert reset_report["error_message"] is None + + reset_simulation = test_db.query( + "SELECT * FROM simulations WHERE id = ?", (simulation["id"],) + ).fetchone() + assert reset_simulation["status"] == "pending" + assert reset_simulation["output"] is None + assert reset_simulation["error_message"] is None + + +def test_rerun_report_output_resets_household_comparison_report_and_both_simulations( + client, test_db +): + baseline_simulation = insert_simulation( + test_db, + population_id="household_baseline", + policy_id=20, + ) + reform_simulation = insert_simulation( + test_db, + population_id="household_reform", + policy_id=21, + output='{"result": "comparison"}', + ) + report_output = insert_report_output( + test_db, + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + ) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["status"] == "ok" + assert payload["result"] == { + "report_id": report_output["id"], + "report_type": "household", + "simulation_ids": [baseline_simulation["id"], reform_simulation["id"]], + "economy_cache_rows_deleted": 0, + } + + for simulation_id in (baseline_simulation["id"], reform_simulation["id"]): + reset_simulation = test_db.query( + "SELECT * FROM simulations WHERE id = ?", + (simulation_id,), + ).fetchone() + assert reset_simulation["status"] == "pending" + assert reset_simulation["output"] is None + assert reset_simulation["error_message"] is None + + +def test_rerun_report_output_resets_economy_report_and_purges_cache(client, test_db): + baseline_simulation = insert_simulation( + test_db, + population_id="state/ca", + population_type="geography", + policy_id=10, + ) + reform_simulation = insert_simulation( + test_db, + population_id="state/ca", + population_type="geography", + policy_id=11, + output='{"result": "reform"}', + ) + report_output = insert_report_output( + test_db, + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + ) + + current_version = get_economy_impact_cache_version("us") + insert_reform_impact( + test_db, + baseline_policy_id=10, + reform_policy_id=11, + region="state/ca", + api_version=current_version, + execution_id="exec-current", + ) + insert_reform_impact( + test_db, + baseline_policy_id=10, + reform_policy_id=11, + region="state/ca", + api_version="e1stale01", + execution_id="exec-stale", + ) + insert_reform_impact( + test_db, + baseline_policy_id=10, + reform_policy_id=11, + region="state/ca", + dataset="enhanced_cps", + api_version=current_version, + execution_id="exec-other-dataset", + ) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["status"] == "ok" + assert payload["result"] == { + "report_id": report_output["id"], + "report_type": "geography", + "simulation_ids": [baseline_simulation["id"], reform_simulation["id"]], + "economy_cache_rows_deleted": 1, + } + + remaining_reform_impacts = test_db.query( + "SELECT execution_id FROM reform_impact ORDER BY execution_id" + ).fetchall() + assert [row["execution_id"] for row in remaining_reform_impacts] == [ + "exec-other-dataset", + "exec-stale", + ] + + +def test_rerun_report_output_single_simulation_economy_uses_baseline_policy_for_cache_key( + client, test_db +): + simulation = insert_simulation( + test_db, + population_id="state/ny", + population_type="geography", + policy_id=30, + ) + report_output = insert_report_output(test_db, simulation_1_id=simulation["id"]) + + current_version = get_economy_impact_cache_version("us") + insert_reform_impact( + test_db, + baseline_policy_id=30, + reform_policy_id=30, + region="state/ny", + api_version=current_version, + execution_id="exec-matching", + ) + insert_reform_impact( + test_db, + baseline_policy_id=30, + reform_policy_id=31, + region="state/ny", + api_version=current_version, + execution_id="exec-other-policy", + ) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["status"] == "ok" + assert payload["result"] == { + "report_id": report_output["id"], + "report_type": "geography", + "simulation_ids": [simulation["id"]], + "economy_cache_rows_deleted": 1, + } + + remaining_reform_impacts = test_db.query( + "SELECT execution_id FROM reform_impact ORDER BY execution_id" + ).fetchall() + assert [row["execution_id"] for row in remaining_reform_impacts] == [ + "exec-other-policy" + ] + + +def test_rerun_report_output_missing_report_returns_404(client): + response = client.post("/us/report/999/rerun") + + assert response.status_code == 404 + payload = response.get_json() + assert payload["status"] == "error" + assert payload["result"] is None + assert "Report #999 not found." in payload["message"] + + +def test_rerun_report_output_missing_linked_simulation_returns_400(client, test_db): + report_output = insert_report_output(test_db, simulation_1_id=999) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 400 + payload = response.get_json() + assert payload["status"] == "error" + assert payload["result"] is None + assert "references simulation #999" in payload["message"] + + unchanged_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", (report_output["id"],) + ).fetchone() + assert unchanged_report["status"] == "complete" + assert unchanged_report["output"] == '{"report": true}' + + +def test_rerun_report_output_missing_secondary_simulation_does_not_partially_reset( + client, test_db +): + baseline_simulation = insert_simulation( + test_db, + population_id="household_baseline", + policy_id=40, + ) + report_output = insert_report_output( + test_db, + simulation_1_id=baseline_simulation["id"], + simulation_2_id=999, + ) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 400 + payload = response.get_json() + assert payload["status"] == "error" + assert "references simulation #999" in payload["message"] + + unchanged_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", (report_output["id"],) + ).fetchone() + assert unchanged_report["status"] == "complete" + assert unchanged_report["output"] == '{"report": true}' + assert unchanged_report["error_message"] == "old error" + + unchanged_simulation = test_db.query( + "SELECT * FROM simulations WHERE id = ?", (baseline_simulation["id"],) + ).fetchone() + assert unchanged_simulation["status"] == "complete" + assert unchanged_simulation["output"] == '{"result": true}' + assert unchanged_simulation["error_message"] == "old error" + + +def test_rerun_report_output_mismatched_population_types_returns_controlled_error( + client, test_db +): + geography_simulation = insert_simulation( + test_db, + population_id="state/tx", + population_type="geography", + policy_id=50, + ) + household_simulation = insert_simulation( + test_db, + population_id="household_mismatch", + population_type="household", + policy_id=51, + output='{"result": "mismatch"}', + ) + report_output = insert_report_output( + test_db, + simulation_1_id=geography_simulation["id"], + simulation_2_id=household_simulation["id"], + ) + + response = client.post(f"/us/report/{report_output['id']}/rerun") + + assert response.status_code == 400 + payload = response.get_json() + assert payload["status"] == "error" + assert "mismatched population types" in payload["message"] + + unchanged_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", (report_output["id"],) + ).fetchone() + assert unchanged_report["status"] == "complete" + assert unchanged_report["output"] == '{"report": true}' + assert unchanged_report["error_message"] == "old error" + + for simulation_id, expected_output in ( + (geography_simulation["id"], '{"result": true}'), + (household_simulation["id"], '{"result": "mismatch"}'), + ): + unchanged_simulation = test_db.query( + "SELECT * FROM simulations WHERE id = ?", + (simulation_id,), + ).fetchone() + assert unchanged_simulation["status"] == "complete" + assert unchanged_simulation["output"] == expected_output diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 162d30c20..dbe25ffdf 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -1,14 +1,19 @@ +import datetime import json import pytest from unittest.mock import patch, MagicMock from typing import Literal from policyengine_api.services.economy_service import ( + BUDGET_WINDOW_MAX_YEARS, EconomyService, EconomicImpactResult, EconomicImpactSetupOptions, ImpactAction, ImpactStatus, + PENDING_EXECUTION_ID_PREFIX, + PROVISIONAL_CLAIM_TTL_SECONDS, + STALE_PROVISIONAL_IMPACT_MESSAGE, ) from tests.fixtures.services.economy_service import ( MOCK_COUNTRY_ID, @@ -35,6 +40,23 @@ ) +def make_mock_budget_impact_data( + *, + tax_revenue_impact: int, + state_tax_revenue_impact: int, + benefit_spending_impact: int, + budgetary_impact: int, +): + return { + "budget": { + "tax_revenue_impact": tax_revenue_impact, + "state_tax_revenue_impact": state_tax_revenue_impact, + "benefit_spending_impact": benefit_spending_impact, + "budgetary_impact": budgetary_impact, + } + } + + class TestEconomyService: class TestGetEconomicImpact: @pytest.fixture @@ -80,6 +102,36 @@ def test__given_completed_impact__returns_completed_result( mock_reform_impacts_service.get_all_reform_impacts.assert_called_once() mock_simulation_api.run.assert_not_called() + def test__given_error_impact__returns_error_result( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + error_impact = create_mock_reform_impact( + status="error", + reform_impact_json=json.dumps({}), + ) + error_impact["message"] = "Failed to start simulation API job" + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + error_impact + ] + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert result.data is None + assert result.message == "Failed to start simulation API job" + mock_reform_impacts_service.get_all_reform_impacts.assert_called_once() + mock_simulation_api.run.assert_not_called() + def test__given_computing_impact_with_succeeded_execution__returns_completed_result( self, economy_service, @@ -181,6 +233,21 @@ def test__given_no_previous_impact__creates_new_simulation( assert result.data is None mock_simulation_api.run.assert_called_once() mock_reform_impacts_service.set_reform_impact.assert_called_once() + assert any( + call.args == (datetime.timezone.utc,) + for call in mock_datetime.now.call_args_list + ) + mock_reform_impacts_service.update_reform_impact_execution_id.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + current_execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", + new_execution_id=MOCK_EXECUTION_ID, + ) def test__given_no_previous_impact__includes_metadata_in_simulation_params( self, @@ -212,6 +279,226 @@ def test__given_no_previous_impact__includes_metadata_in_simulation_params( ) assert sim_params["_metadata"]["process_id"] == MOCK_PROCESS_ID + def test__given_simulation_api_submission_failure__marks_provisional_claim_error( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_simulation_api.run.side_effect = RuntimeError("gateway unavailable") + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert ( + result.message + == "Failed to start simulation API job: gateway unavailable" + ) + mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + message="Failed to start simulation API job: gateway unavailable", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", + ) + mock_reform_impacts_service.update_reform_impact_execution_id.assert_not_called() + + def test__given_simulation_setup_failure__marks_provisional_claim_error( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + with patch.object( + economy_service, + "_setup_sim_options", + side_effect=ValueError("Invalid US state: 'zz'"), + ): + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert ( + result.message + == "Failed to start simulation API job: Invalid US state: 'zz'" + ) + mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + message="Failed to start simulation API job: Invalid US state: 'zz'", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", + ) + mock_simulation_api.run.assert_not_called() + mock_reform_impacts_service.update_reform_impact_execution_id.assert_not_called() + + def test__given_claim_lock_timeout_and_existing_provisional_claim__returns_computing( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_numpy_random, + ): + provisional_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_other", + start_time=datetime.datetime.now(datetime.timezone.utc), + ) + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [], + [provisional_impact], + ] + mock_reform_impacts_service.claim_lock.side_effect = TimeoutError( + "lock busy" + ) + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_simulation_api.run.assert_not_called() + + def test__given_stale_provisional_claim__expires_and_recreates_simulation( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + ): + stale_start_time = datetime.datetime.now( + datetime.timezone.utc + ) - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1) + stale_provisional_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + start_time=stale_start_time, + ) + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [stale_provisional_impact], + [stale_provisional_impact], + ] + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + message=STALE_PROVISIONAL_IMPACT_MESSAGE, + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + ) + mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_simulation_api.run.assert_called_once() + + def test__given_provisional_promotion_updates_zero_rows__inserts_replacement_tracking_row( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_reform_impacts_service.update_reform_impact_execution_id.return_value = 0 + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert mock_reform_impacts_service.set_reform_impact.call_count == 2 + first_insert = mock_reform_impacts_service.set_reform_impact.call_args_list[ + 0 + ] + second_insert = ( + mock_reform_impacts_service.set_reform_impact.call_args_list[1] + ) + assert ( + first_insert.kwargs["execution_id"] + == f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}" + ) + assert second_insert.kwargs["execution_id"] == MOCK_EXECUTION_ID + + def test__given_provisional_promotion_updates_zero_rows_but_newer_claim_exists__does_not_insert_fallback( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + replacement_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_replacement", + start_time=datetime.datetime.now(datetime.timezone.utc), + ) + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [], + [], + [replacement_impact], + ] + mock_reform_impacts_service.update_reform_impact_execution_id.return_value = 0 + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert mock_reform_impacts_service.set_reform_impact.call_count == 1 + inserted_execution_id = ( + mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ + "execution_id" + ] + ) + assert ( + inserted_execution_id + == f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}" + ) + def test__given_runtime_cache_version__uses_versioned_economy_cache_key( self, economy_service, @@ -235,7 +522,7 @@ def test__given_runtime_cache_version__uses_versioned_economy_cache_key( economy_service.get_economic_impact(**base_params) - mock_reform_impacts_service.get_all_reform_impacts.assert_called_once_with( + expected_call = ( MOCK_COUNTRY_ID, MOCK_POLICY_ID, MOCK_BASELINE_POLICY_ID, @@ -245,6 +532,11 @@ def test__given_runtime_cache_version__uses_versioned_economy_cache_key( MOCK_OPTIONS_HASH, cache_version, ) + assert mock_reform_impacts_service.get_all_reform_impacts.call_count == 2 + for ( + call_args + ) in mock_reform_impacts_service.get_all_reform_impacts.call_args_list: + assert call_args.args == expected_call def test__given_exception__raises_error( self, @@ -267,6 +559,367 @@ def test__given_exception__raises_error( economy_service.get_economic_impact(**base_params) assert str(exc_info.value) == "Database error" + class TestGetBudgetWindowEconomicImpact: + @pytest.fixture + def economy_service(self): + return EconomyService() + + @pytest.fixture + def base_params(self): + return { + "country_id": MOCK_COUNTRY_ID, + "policy_id": MOCK_POLICY_ID, + "baseline_policy_id": MOCK_BASELINE_POLICY_ID, + "region": MOCK_REGION, + "dataset": MOCK_DATASET, + "start_year": "2026", + "window_size": 3, + "options": MOCK_OPTIONS, + "api_version": MOCK_API_VERSION, + "target": "general", + } + + def test__given_all_years_completed__returns_aggregated_budget_window_result( + self, economy_service, base_params + ): + def make_setup(*, time_period, **_kwargs): + return EconomicImpactSetupOptions( + process_id=MOCK_PROCESS_ID, + country_id=MOCK_COUNTRY_ID, + reform_policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=time_period, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + options_hash=MOCK_OPTIONS_HASH, + ) + + yearly_results = { + "2026": EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=100, + state_tax_revenue_impact=20, + benefit_spending_impact=-10, + budgetary_impact=90, + ) + ), + "2027": EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=120, + state_tax_revenue_impact=30, + benefit_spending_impact=-20, + budgetary_impact=100, + ) + ), + "2028": EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=140, + state_tax_revenue_impact=40, + benefit_spending_impact=-30, + budgetary_impact=110, + ) + ), + } + + with ( + patch.object( + economy_service, + "_build_economic_impact_setup_options", + side_effect=make_setup, + ), + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=lambda setup_options: yearly_results[ + setup_options.time_period + ], + ) as mock_get_existing, + patch.object( + economy_service, "_get_or_create_economic_impact" + ) as mock_get_economic_impact, + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.OK + assert result.progress == 100 + assert result.data["annualImpacts"] == [ + { + "year": "2026", + "taxRevenueImpact": 100, + "federalTaxRevenueImpact": 80, + "stateTaxRevenueImpact": 20, + "benefitSpendingImpact": -10, + "budgetaryImpact": 90, + }, + { + "year": "2027", + "taxRevenueImpact": 120, + "federalTaxRevenueImpact": 90, + "stateTaxRevenueImpact": 30, + "benefitSpendingImpact": -20, + "budgetaryImpact": 100, + }, + { + "year": "2028", + "taxRevenueImpact": 140, + "federalTaxRevenueImpact": 100, + "stateTaxRevenueImpact": 40, + "benefitSpendingImpact": -30, + "budgetaryImpact": 110, + }, + ] + assert result.data["totals"] == { + "year": "Total", + "taxRevenueImpact": 360, + "federalTaxRevenueImpact": 270, + "stateTaxRevenueImpact": 90, + "benefitSpendingImpact": -60, + "budgetaryImpact": 300, + } + assert mock_get_existing.call_count == 3 + mock_get_economic_impact.assert_not_called() + + def test__given_missing_years__starts_only_up_to_remaining_active_slots( + self, economy_service, base_params + ): + def make_setup(*, time_period, **_kwargs): + return EconomicImpactSetupOptions( + process_id=MOCK_PROCESS_ID, + country_id=MOCK_COUNTRY_ID, + reform_policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=time_period, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + options_hash=MOCK_OPTIONS_HASH, + ) + + base_params["window_size"] = 5 + + existing_results = { + "2026": EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=100, + state_tax_revenue_impact=20, + benefit_spending_impact=-10, + budgetary_impact=90, + ) + ), + "2027": EconomicImpactResult.computing(), + "2028": None, + "2029": None, + "2030": None, + } + + with ( + patch.object( + economy_service, + "_build_economic_impact_setup_options", + side_effect=make_setup, + ), + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=lambda setup_options: existing_results[ + setup_options.time_period + ], + ), + patch.object( + economy_service, + "_get_or_create_economic_impact", + return_value=EconomicImpactResult.computing(), + ) as mock_get_economic_impact, + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.COMPUTING + assert result.progress == 20 + assert result.completed_years == ["2026"] + assert result.computing_years == ["2027", "2028", "2029"] + assert result.queued_years == ["2030"] + assert "1 of 5 complete" in result.message + assert mock_get_economic_impact.call_count == 2 + started_years = sorted( + call.args[0].time_period + for call in mock_get_economic_impact.call_args_list + ) + assert started_years == ["2028", "2029"] + + def test__given_year_error__returns_budget_window_error( + self, economy_service, base_params, mock_logger + ): + def make_setup(*, time_period, **_kwargs): + return EconomicImpactSetupOptions( + process_id=MOCK_PROCESS_ID, + country_id=MOCK_COUNTRY_ID, + reform_policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=time_period, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + options_hash=MOCK_OPTIONS_HASH, + ) + + with ( + patch.object( + economy_service, + "_build_economic_impact_setup_options", + side_effect=make_setup, + ), + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=[ + EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=100, + state_tax_revenue_impact=20, + benefit_spending_impact=-10, + budgetary_impact=90, + ) + ), + EconomicImpactResult( + status=ImpactStatus.ERROR, + data={"message": "Calculation failed for 2027"}, + ), + None, + ], + ), + patch.object( + economy_service, "_get_or_create_economic_impact" + ) as mock_get_economic_impact, + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.ERROR + assert result.error == "Calculation failed for 2027" + assert result.completed_years == ["2026"] + mock_get_economic_impact.assert_not_called() + + def test__given_cliff_target__raises_value_error( + self, economy_service, base_params + ): + base_params["target"] = "cliff" + + with pytest.raises( + ValueError, + match="Budget-window calculations only support target='general'", + ): + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_oversized_window__raises_value_error( + self, economy_service, base_params + ): + base_params["window_size"] = BUDGET_WINDOW_MAX_YEARS + 1 + + with pytest.raises( + ValueError, + match=(f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}"), + ): + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_started_year_error__returns_specific_budget_window_error( + self, economy_service, base_params, mock_logger + ): + with ( + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=[None, None, None], + ), + patch.object( + economy_service, + "_get_or_create_economic_impact", + side_effect=[ + EconomicImpactResult.error("Calculation failed for 2026"), + EconomicImpactResult.computing(), + EconomicImpactResult.computing(), + ], + ), + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.ERROR + assert result.error == "Calculation failed for 2026" + assert result.completed_years == [] + + def test__given_runtime_cache_version__uses_versioned_cache_key_for_budget_window( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_logger, + mock_datetime, + mock_numpy_random, + monkeypatch, + ): + cache_version = "e1cache01" + seen_existing_calls = [] + seen_create_calls = [] + + monkeypatch.setattr( + "policyengine_api.services.economy_service.get_economy_impact_cache_version", + lambda country_id, api_version=None: cache_version, + ) + + def fake_get_existing(setup_options): + seen_existing_calls.append( + (setup_options.time_period, setup_options.api_version) + ) + return None + + def fake_get_or_create(setup_options): + seen_create_calls.append( + (setup_options.time_period, setup_options.api_version) + ) + return EconomicImpactResult.computing() + + with ( + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=fake_get_existing, + ), + patch.object( + economy_service, + "_get_or_create_economic_impact", + side_effect=fake_get_or_create, + ), + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.COMPUTING + assert seen_existing_calls == [ + ("2026", cache_version), + ("2027", cache_version), + ("2028", cache_version), + ] + assert seen_create_calls == [ + ("2026", cache_version), + ("2027", cache_version), + ("2028", cache_version), + ] + class TestGetPreviousImpacts: @pytest.fixture def economy_service(self): @@ -349,6 +1002,47 @@ def test__given_no_impacts__returns_none( # Assert assert result is None + class TestGetExistingEconomicImpact: + @pytest.fixture + def economy_service(self): + return EconomyService() + + @pytest.fixture + def setup_options(self): + return EconomicImpactSetupOptions( + process_id=MOCK_PROCESS_ID, + country_id=MOCK_COUNTRY_ID, + reform_policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + options_hash=MOCK_OPTIONS_HASH, + ) + + def test__given_stale_provisional_impact__returns_none( + self, + economy_service, + setup_options, + mock_reform_impacts_service, + ): + stale_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + start_time=datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1), + ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + stale_impact + ] + + result = economy_service._get_existing_economic_impact(setup_options) + + assert result is None + class TestDetermineImpactAction: @pytest.fixture def economy_service(self): @@ -366,12 +1060,12 @@ def test__given_ok_status__returns_completed(self, economy_service): assert result == ImpactAction.COMPLETED - def test__given_error_status__returns_completed(self, economy_service): + def test__given_error_status__returns_error(self, economy_service): impact = create_mock_reform_impact(status="error") result = economy_service._determine_impact_action(impact) - assert result == ImpactAction.COMPLETED + assert result == ImpactAction.ERROR def test__given_computing_status__returns_computing(self, economy_service): impact = create_mock_reform_impact(status="computing") @@ -380,6 +1074,20 @@ def test__given_computing_status__returns_computing(self, economy_service): assert result == ImpactAction.COMPUTING + def test__given_stale_provisional_computing_status__returns_create( + self, economy_service + ): + impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + start_time=datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1), + ) + + result = economy_service._determine_impact_action(impact) + + assert result == ImpactAction.CREATE + def test__given_unknown_status__raises_error(self, economy_service): impact = create_mock_reform_impact(status="unknown") @@ -445,6 +1153,7 @@ def test__given_failed_state__returns_error_result( assert result.status == ImpactStatus.ERROR assert result.data is None + assert result.message == "Simulation API execution failed" mock_reform_impacts_service.set_error_reform_impact.assert_called_once() def test__given_active_state__returns_computing_result( @@ -459,6 +1168,21 @@ def test__given_active_state__returns_computing_result( assert result.status == ImpactStatus.COMPUTING assert result.data is None + def test__given_provisional_claim__returns_computing_without_polling( + self, economy_service, setup_options, mock_simulation_api, mock_logger + ): + reform_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_pending", + ) + + result = economy_service._handle_computing_impact( + setup_options, reform_impact + ) + + assert result.status == ImpactStatus.COMPUTING + mock_simulation_api.get_execution_by_id.assert_not_called() + def test__given_unknown_state__raises_error( self, economy_service, setup_options ): @@ -516,6 +1240,7 @@ def test__given_modal_failed_state__then_returns_error_result( # Then assert result.status == ImpactStatus.ERROR assert result.data is None + assert result.message == "Simulation API execution failed" mock_reform_impacts_service.set_error_reform_impact.assert_called_once() def test__given_modal_failed_state_with_error_message__then_includes_error_in_message( @@ -537,6 +1262,10 @@ def test__given_modal_failed_state_with_error_message__then_includes_error_in_me # Then assert result.status == ImpactStatus.ERROR + assert ( + result.message + == "Simulation API execution failed: Simulation timed out" + ) # Verify the error message was passed to the service call_args = mock_reform_impacts_service.set_error_reform_impact.call_args assert "Simulation timed out" in call_args[1]["message"] @@ -634,6 +1363,7 @@ def test__given_error__creates_correct_instance_and_logs(self): assert result.status == ImpactStatus.ERROR assert result.data is None + assert result.message == "Test error message" mock_logger.log_struct.assert_called_once() diff --git a/tests/unit/services/test_reform_impacts_service.py b/tests/unit/services/test_reform_impacts_service.py new file mode 100644 index 000000000..2e0473b16 --- /dev/null +++ b/tests/unit/services/test_reform_impacts_service.py @@ -0,0 +1,326 @@ +import datetime +from unittest.mock import MagicMock + +import pytest + +from policyengine_api.constants import get_economy_impact_cache_version +from policyengine_api.services.reform_impacts_service import ReformImpactsService + + +def insert_reform_impact( + test_db, + *, + baseline_policy_id=1, + reform_policy_id=2, + country_id="us", + region="us", + dataset="default", + time_period="2025", + options_json="[]", + options_hash="[]", + api_version=None, + reform_impact_json='{"result": 1}', + status="ok", + message="Completed", + execution_id="exec-1", +): + if api_version is None: + api_version = get_economy_impact_cache_version(country_id) + + test_db.query( + """INSERT INTO reform_impact + (baseline_policy_id, reform_policy_id, country_id, region, dataset, time_period, + options_json, options_hash, api_version, reform_impact_json, status, message, start_time, + execution_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + baseline_policy_id, + reform_policy_id, + country_id, + region, + dataset, + time_period, + options_json, + options_hash, + api_version, + reform_impact_json, + status, + message, + datetime.datetime(2026, 1, 1, 12, 0, 0), + execution_id, + ), + ) + + +class TestReformImpactsService: + def test__given_remote_database_missing_columns__ensure_remote_schema_adds_them( + self, monkeypatch + ): + service = ReformImpactsService() + + show_columns_result = MagicMock() + show_columns_result.fetchall.return_value = [ + {"Field": "reform_impact_id"}, + {"Field": "status"}, + {"Field": "start_time"}, + ] + alter_dataset_result = MagicMock() + alter_execution_result = MagicMock() + alter_end_time_result = MagicMock() + + mock_database = MagicMock() + mock_database.local = False + mock_database.query.side_effect = [ + show_columns_result, + alter_dataset_result, + alter_execution_result, + alter_end_time_result, + ] + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + service._ensure_remote_schema() + + assert mock_database.query.call_args_list[0].args == ( + "SHOW COLUMNS FROM reform_impact", + ) + assert mock_database.query.call_args_list[1].args == ( + "ALTER TABLE reform_impact ADD COLUMN dataset VARCHAR(255) NOT NULL DEFAULT 'default'", + ) + assert mock_database.query.call_args_list[2].args == ( + "ALTER TABLE reform_impact ADD COLUMN execution_id VARCHAR(255) NULL", + ) + assert mock_database.query.call_args_list[3].args == ( + "ALTER TABLE reform_impact ADD COLUMN end_time DATETIME NULL", + ) + + def test__given_remote_database_existing_columns__ensure_remote_schema_skips_alter( + self, monkeypatch + ): + service = ReformImpactsService() + + show_columns_result = MagicMock() + show_columns_result.fetchall.return_value = [ + {"Field": "reform_impact_id"}, + {"Field": "status"}, + {"Field": "start_time"}, + {"Field": "dataset"}, + {"Field": "execution_id"}, + {"Field": "end_time"}, + ] + + mock_database = MagicMock() + mock_database.local = False + mock_database.query.return_value = show_columns_result + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + service._ensure_remote_schema() + + mock_database.query.assert_called_once_with("SHOW COLUMNS FROM reform_impact") + + def test__given_remote_database__claim_lock_uses_advisory_lock(self, monkeypatch): + service = ReformImpactsService() + + acquired_result = MagicMock() + acquired_result.mappings.return_value.first.return_value = {"acquired": 1} + release_result = MagicMock() + + mock_connection = MagicMock() + mock_connection.exec_driver_sql.side_effect = [ + acquired_result, + release_result, + ] + + mock_connection_context = MagicMock() + mock_connection_context.__enter__.return_value = mock_connection + mock_connection_context.__exit__.return_value = False + + mock_pool = MagicMock() + mock_pool.connect.return_value = mock_connection_context + + mock_database = MagicMock() + mock_database.local = False + mock_database.pool = mock_pool + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + with service.claim_lock( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + api_version="e1cache01", + ): + pass + + assert mock_connection.exec_driver_sql.call_count == 2 + + acquire_call = mock_connection.exec_driver_sql.call_args_list[0] + assert acquire_call.args == ( + "SELECT GET_LOCK(%s, %s) AS acquired", + ( + service._build_lock_name( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + api_version="e1cache01", + ), + 5, + ), + ) + assert len(acquire_call.args[1][0]) <= 64 + + release_call = mock_connection.exec_driver_sql.call_args_list[1] + assert release_call.args == ( + "SELECT RELEASE_LOCK(%s) AS released", + (acquire_call.args[1][0],), + ) + mock_connection.commit.assert_called_once() + + def test__given_remote_database_lock_timeout__claim_lock_raises(self, monkeypatch): + service = ReformImpactsService() + + acquired_result = MagicMock() + acquired_result.mappings.return_value.first.return_value = {"acquired": 0} + + mock_connection = MagicMock() + mock_connection.exec_driver_sql.return_value = acquired_result + + mock_connection_context = MagicMock() + mock_connection_context.__enter__.return_value = mock_connection + mock_connection_context.__exit__.return_value = False + + mock_pool = MagicMock() + mock_pool.connect.return_value = mock_connection_context + + mock_database = MagicMock() + mock_database.local = False + mock_database.pool = mock_pool + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + with pytest.raises( + TimeoutError, + match="Could not acquire reform impact lock", + ): + with service.claim_lock( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + api_version="e1cache01", + ): + pass + + def test_delete_reform_impacts_deletes_completed_rows_for_exact_cache_key( + self, test_db + ): + service = ReformImpactsService() + current_version = get_economy_impact_cache_version("us") + + insert_reform_impact( + test_db, + api_version=current_version, + status="ok", + execution_id="exec-ok", + ) + insert_reform_impact( + test_db, + api_version=current_version, + status="error", + execution_id="exec-error", + ) + insert_reform_impact( + test_db, + api_version=current_version, + status="computing", + execution_id="exec-computing", + ) + insert_reform_impact( + test_db, + api_version="e1stale01", + status="ok", + execution_id="exec-stale", + ) + insert_reform_impact( + test_db, + dataset="enhanced_cps", + api_version=current_version, + status="ok", + execution_id="exec-other-dataset", + ) + + deleted_rows = service.delete_reform_impacts( + country_id="us", + policy_id=2, + baseline_policy_id=1, + region="us", + dataset="default", + time_period="2025", + options_hash="[]", + api_version=current_version, + ) + + assert deleted_rows == 3 + + remaining_rows = test_db.query( + "SELECT execution_id, dataset, api_version, status FROM reform_impact ORDER BY execution_id" + ).fetchall() + assert [row["execution_id"] for row in remaining_rows] == [ + "exec-other-dataset", + "exec-stale", + ] + + def test_delete_reform_impact_keeps_completed_rows(self, test_db): + service = ReformImpactsService() + + insert_reform_impact( + test_db, + status="ok", + execution_id="exec-ok", + ) + insert_reform_impact( + test_db, + status="computing", + execution_id="exec-computing", + ) + + deleted_rows = service.delete_reform_impact( + country_id="us", + policy_id=2, + baseline_policy_id=1, + region="us", + dataset="default", + time_period="2025", + options_hash="[]", + ) + + assert deleted_rows == 1 + + remaining_rows = test_db.query( + "SELECT execution_id, status FROM reform_impact ORDER BY execution_id" + ).fetchall() + assert remaining_rows == [{"execution_id": "exec-ok", "status": "ok"}] diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index e3b63cbd3..55708f6a2 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -563,3 +563,59 @@ def test_update_report_output_stale_id_keeps_stale_output_quarantined( assert rows[0]["api_version"] == stale_version assert rows[0]["status"] == "complete" assert rows[0]["output"] == output_json + + +class TestResetReportOutput: + def test_reset_report_output_clears_output_and_error(self, test_db): + output_json = json.dumps({"result": "complete"}) + error_message = "old error" + + test_db.query( + """INSERT INTO report_outputs + (country_id, simulation_1_id, simulation_2_id, status, output, error_message, api_version, year) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + "us", + 11, + None, + "complete", + output_json, + error_message, + get_report_output_cache_version("us"), + "2025", + ), + ) + + report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + success = service.reset_report_output( + country_id="us", + report_id=report["id"], + ) + + assert success is True + + reset_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + assert reset_report["status"] == "pending" + assert reset_report["output"] is None + assert reset_report["error_message"] is None + + def test_reset_report_output_rejects_wrong_country(self, test_db): + test_db.query( + """INSERT INTO report_outputs + (country_id, simulation_1_id, simulation_2_id, status, api_version, year) + VALUES (?, ?, ?, ?, ?, ?)""", + ("us", 12, None, "complete", get_report_output_cache_version("us"), "2025"), + ) + + report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + with pytest.raises(Exception, match="does not belong to country uk"): + service.reset_report_output(country_id="uk", report_id=report["id"]) diff --git a/tests/unit/services/test_simulation_service.py b/tests/unit/services/test_simulation_service.py index ac1fbccf6..cbb1e2774 100644 --- a/tests/unit/services/test_simulation_service.py +++ b/tests/unit/services/test_simulation_service.py @@ -1,5 +1,8 @@ +import json + import pytest +from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS from policyengine_api.services.simulation_service import SimulationService from tests.fixtures.services.simulation_fixtures import ( @@ -231,3 +234,59 @@ def test_duplicate_simulation_returns_existing(self, test_db): assert first_simulation["country_id"] == second_simulation["country_id"] assert first_simulation["population_id"] == second_simulation["population_id"] assert first_simulation["policy_id"] == second_simulation["policy_id"] + + +class TestResetSimulation: + def test_reset_simulation_clears_output_and_error(self, test_db): + output_json = json.dumps({"household": {"income": 100}}) + + test_db.query( + """INSERT INTO simulations + (country_id, api_version, population_id, population_type, policy_id, status, output, error_message) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + "us", + "oldvers1", + "household_reset", + "household", + 42, + "complete", + output_json, + "old error", + ), + ) + + simulation = test_db.query( + "SELECT * FROM simulations ORDER BY id DESC LIMIT 1" + ).fetchone() + + success = service.reset_simulation( + country_id="us", + simulation_id=simulation["id"], + ) + + assert success is True + + reset_simulation = test_db.query( + "SELECT * FROM simulations WHERE id = ?", + (simulation["id"],), + ).fetchone() + assert reset_simulation["status"] == "pending" + assert reset_simulation["output"] is None + assert reset_simulation["error_message"] is None + assert reset_simulation["api_version"] == COUNTRY_PACKAGE_VERSIONS["us"] + + def test_reset_simulation_requires_matching_country(self, test_db): + test_db.query( + """INSERT INTO simulations + (country_id, api_version, population_id, population_type, policy_id, status) + VALUES (?, ?, ?, ?, ?, ?)""", + ("us", "oldvers1", "household_reset", "household", 43, "complete"), + ) + + simulation = test_db.query( + "SELECT * FROM simulations ORDER BY id DESC LIMIT 1" + ).fetchone() + + with pytest.raises(Exception, match="Simulation #.* not found"): + service.reset_simulation(country_id="uk", simulation_id=simulation["id"])