From 4123fe9739c1c4bccebaa149985d0415a4272ef1 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 07:23:00 +0200 Subject: [PATCH 01/54] Anchor consumption grid lower bound to consumption_floor parameter Consumption is now declared as `IrregSpacedGrid(n_points=N)` (no fixed points). Callers inject log-spaced gridpoints from `consumption_floor` to $300k via `aca_model.consumption_grid. inject_consumption_points(params=..., model=...)` before solving. This means the lowest consumption choice equals the per-iteration floor, removing a degree of freedom from the grid and eliminating the previous mismatch where c < floor was a legal grid choice. Requires pylcm support for runtime-supplied points on continuous action grids (PR OpenSourceEconomics/pylcm#338). aca-model CI now installs pylcm from the matching `feature/runtime-action-grids` branch. Other changes: - `consumption_grid.py`: new module with `compute_consumption_points` and `inject_consumption_points` helpers. - `benchmark.get_benchmark_params(*, model=None)`: when `model` is given, returns params with consumption points injected. - `benchmark.get_benchmark_initial_conditions`: switch from `.start` / `.stop` to `to_jax().min()` / `.max()` so it works on both `LinSpacedGrid` and `PiecewiseLinSpacedGrid` (the AIME grid is now piecewise; this was a pre-existing bug surfacing as `AttributeError`). Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 4 +- src/aca_model/baseline/regimes/_common.py | 17 +---- src/aca_model/benchmark.py | 20 +++++- src/aca_model/consumption_grid.py | 76 +++++++++++++++++++++++ tests/test_benchmark.py | 2 +- 5 files changed, 97 insertions(+), 22 deletions(-) create mode 100644 src/aca_model/consumption_grid.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 67c82fa..a0247b8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,10 +26,10 @@ jobs: - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - - name: Install pylcm + - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@main" + git+https://github.com/OpenSourceEconomics/pylcm.git@feature/runtime-action-grids" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 688a504..efd3b73 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -194,14 +194,6 @@ class Grids: # bend points (0 → kink_0 → kink_1 → kink_2). Total = 32. _AIME_PIECE_N_POINTS: tuple[int, int, int] = (10, 11, 11) -# Consumption grid: log-spaced from the lower bound of the -# `consumption_floor` parameter (BOUNDS in task_estimate_parameters) -# up to a high value that brackets the unconstrained optimum for the -# richest agents in the state space. Mirrors the struct-ret design -# (concentrate gridpoints where CRRA curvature is highest). -_CONSUMPTION_GRID_START: float = 100.0 -_CONSUMPTION_GRID_STOP: float = 300_000.0 - def build_grids( grid_config: GridConfig = GRID_CONFIG, @@ -273,14 +265,7 @@ def build_grids( ), aime=_build_aime_grid(grid_config=grid_config, fixed_params=fixed_params), consumption=IrregSpacedGrid( - points=tuple( - float(c) - for c in np.geomspace( - _CONSUMPTION_GRID_START, - _CONSUMPTION_GRID_STOP, - num=grid_config.n_consumption_gridpoints, - ) - ), + n_points=grid_config.n_consumption_gridpoints, ), wage_res=wage_res, hcc_persistent=hcc_persistent, diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 5a822c9..a9d7128 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -44,6 +44,7 @@ from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.model import create_model from aca_model.config import BENCHMARK_GRID_CONFIG +from aca_model.consumption_grid import inject_consumption_points _PARAMS_FILE = ( Path(__file__).resolve().parent / "_benchmark_data" / "benchmark_params.pkl" @@ -96,17 +97,26 @@ def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Mod ) -def get_benchmark_params() -> tuple[dict[str, Any], dict[str, Any]]: +def get_benchmark_params( + *, model: Model | None = None +) -> tuple[dict[str, Any], dict[str, Any]]: """Load the frozen `(fixed_params, params)` snapshot. Pref-type-indexed `pd.Series` in `params` are truncated to `_N_BENCHMARK_PREF_TYPES` rows so they line up with `BenchmarkPrefType`'s categories. + + When `model` is provided, consumption gridpoints are injected into + `params` for each regime that declares `consumption` as an + `IrregSpacedGrid` with runtime-supplied points. The lower bound is + read from `params["consumption_floor"]`. """ with _PARAMS_FILE.open("rb") as fh: data = cloudpickle.load(fh) fixed_params = data["fixed_params"] params = _truncate_pref_type_indexed(data["params"]) + if model is not None: + params = inject_consumption_points(params=params, model=model) return fixed_params, params @@ -143,10 +153,14 @@ def get_benchmark_initial_conditions( regime = rng.choice(regime_ids, size=n_subjects).astype(np.int32) # Grid ranges come from any of the five regimes (shared structure). + # Use to_jax() so the helper handles both LinSpacedGrid and + # PiecewiseLinSpacedGrid (the latter has no `.start` / `.stop`). ref_regime = model.regimes[_INITIAL_REGIMES[0]] grids = ref_regime.states - assets_lo, assets_hi = grids["assets"].start, grids["assets"].stop - aime_lo, aime_hi = grids["aime"].start, grids["aime"].stop + assets_pts = np.asarray(grids["assets"].to_jax()) + aime_pts = np.asarray(grids["aime"].to_jax()) + assets_lo, assets_hi = float(assets_pts.min()), float(assets_pts.max()) + aime_lo, aime_hi = float(aime_pts.min()), float(aime_pts.max()) hcc_p_pts = np.asarray(grids["hcc_persistent"].to_jax()) hcc_t_pts = np.asarray(grids["hcc_transitory"].to_jax()) wage_res_pts = np.asarray(grids["log_ft_wage_res"].to_jax()) diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_grid.py new file mode 100644 index 0000000..d670d3c --- /dev/null +++ b/src/aca_model/consumption_grid.py @@ -0,0 +1,76 @@ +"""Runtime-supplied gridpoints for the consumption action. + +Consumption is declared as `IrregSpacedGrid(n_points=N)` in +`baseline.regimes._common.build_grids` so the lower bound can track +the per-iteration `consumption_floor` parameter. Callers must inject +the actual gridpoints into `params` via `inject_consumption_points` +before calling `model.solve()` / `model.simulate()`. +""" + +from collections.abc import Mapping +from typing import Any + +import jax.numpy as jnp +from jax import Array +from lcm import IrregSpacedGrid, Model + +MAX_CONSUMPTION: float = 300_000.0 +"""Upper bound of the consumption grid in $/year. Brackets the unconstrained +CRRA optimum for the highest-asset, highest-income agents in the state space.""" + + +def compute_consumption_points( + *, consumption_floor: float, n_points: int +) -> Array: + """Return log-spaced consumption gridpoints from the floor to `MAX_CONSUMPTION`. + + Args: + consumption_floor: Lowest gridpoint, equal to the `consumption_floor` + parameter so the agent cannot pick `c < floor` even when saving + from a transfer top-up. + n_points: Total number of gridpoints. + + Returns: + 1-D float array of length `n_points`. + """ + return jnp.geomspace(consumption_floor, MAX_CONSUMPTION, num=n_points) + + +def inject_consumption_points( + *, + params: Mapping[str, Any], + model: Model, + consumption_floor: float | None = None, +) -> dict[str, Any]: + """Inject consumption gridpoints into per-regime params. + + Walks `model.regimes`, finds those with `consumption` declared as + `IrregSpacedGrid` with runtime-supplied points, and writes + `params[regime_name]["consumption"] = {"points": }`. + + Args: + params: Existing params mapping. Returned as a new dict; the input is + not mutated. + model: Model whose regime specs determine which regimes need points. + consumption_floor: Lowest gridpoint. When `None`, taken from + `params["consumption_floor"]`. + + Returns: + New params dict with consumption points injected. + """ + if consumption_floor is None: + consumption_floor = float(params["consumption_floor"]) + out: dict[str, Any] = dict(params) + for regime_name, regime in model.regimes.items(): + grid = regime.actions.get("consumption") + if not ( + isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime + ): + continue + points = compute_consumption_points( + consumption_floor=consumption_floor, n_points=grid.n_points + ) + regime_entry = dict(out.get(regime_name, {})) + regime_entry["consumption"] = {"points": points} + out[regime_name] = regime_entry + return out diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 72fb473..c1e48e0 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -13,7 +13,7 @@ def test_benchmark_model_simulates_end_to_end() -> None: n_subjects = 20 model = create_benchmark_model() - _, params = get_benchmark_params() + _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 ) From 134286108b7445f3e17e8824bcdd1739a98b6089 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 29 Apr 2026 18:31:50 +0200 Subject: [PATCH 02/54] Refactor utility_scale_factor to take pref_type, return scalar MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `utility_scale_factor` was registered as a regime function returning a (n_pref_types,) array, then re-indexed by `pref_type` inside `bequest` and `utility`. pylcm broadcasts function outputs to per-cell scalars before consumption, so that `[pref_type]` indexing produced silent NaN in the dead regime's V — surfaced as the all-NaN failure on the ASV benchmark. Mirror the `discount_factor` pattern: take the state as input, return a per-cell scalar. Drop the `[pref_type]` indexing on `utility_scale_factor` from `utility` and `bequest` (those still index the params-Series `consumption_weight` and `coefficient_rra`, which is the supported pattern — only DAG function outputs are pre-broadcast). The matching pylcm validator (PR #338) now raises a clear `RegimeInitializationError` when a function output is consumed via state-indexing in a downstream consumer; this aca-model change is the fix that lets the dead regime construct under that validator. Tests in `test_preferences.py` and `test_model_components.py` updated to pass scalar `utility_scale_factor` and supply the new `pref_type` arg to `utility_scale_factor`. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/agent/preferences.py | 42 ++++++++++++++++++------------ src/aca_model/consumption_grid.py | 8 ++---- tests/test_model_components.py | 8 +++--- tests/test_preferences.py | 11 ++++++-- 4 files changed, 40 insertions(+), 29 deletions(-) diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 3b0bb5e..28a8367 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -133,8 +133,11 @@ def utility( """Within-period utility: CES aggregator over consumption and leisure. u = utility_scale_factor * ((c/eq_scale)^α * l^(1-α))^(1-γ) / (1-γ) - with log case for γ=1. `consumption_weight`, `coefficient_rra`, and - `utility_scale_factor` are indexed by `pref_type`. + with log case for γ=1. `consumption_weight` and `coefficient_rra` are + pref-type-indexed Series sourced directly from params; `utility_scale_factor` + is a regime-function output (already a per-cell scalar — must NOT be + re-indexed by pref_type, see `aca_model.agent.preferences.utility_scale_factor` + for why). """ alpha = consumption_weight[pref_type] gamma = coefficient_rra[pref_type] @@ -147,7 +150,7 @@ def utility( jnp.log(composite), composite**one_minus_gamma / one_minus_gamma, ) - return u * utility_scale_factor[pref_type] + return u * utility_scale_factor def discount_factor( @@ -164,6 +167,7 @@ def discount_factor( def utility_scale_factor( + pref_type: DiscreteState, average_consumption: float, consumption_weight: FloatND, coefficient_rra: FloatND, @@ -174,26 +178,29 @@ def utility_scale_factor( reference_age: int, scale_reference_age: int, ) -> FloatND: - """Compute scale factor so utility is approximately 1 at typical values. - - Uses leisure at `scale_reference_age` when working `scale_reference_hours` - (after fixed costs) and average consumption. Returns one scale per - preference type, indexed by pref_type. + """Compute the scale factor so utility is approximately 1 at typical values. + + Returns the scalar for the cell's `pref_type`. Mirrors the `discount_factor` + pattern: take the state as input, return a per-cell scalar. Registering this + as a regime function and then doing `utility_scale_factor[pref_type]` in a + downstream consumer is invalid — pylcm broadcasts function outputs to + per-cell scalars before consumption, and the validator in + `lcm.regime_building.validation` raises on that clash. """ + alpha = consumption_weight[pref_type] + gamma = coefficient_rra[pref_type] age_offset = scale_reference_age - reference_age average_leisure = ( time_endowment - scale_reference_hours - (fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * age_offset) ) - u_cons = average_consumption**consumption_weight - u_leisure = average_leisure ** (1.0 - consumption_weight) + u_cons = average_consumption**alpha + u_leisure = average_leisure ** (1.0 - alpha) - one_minus_gamma = jnp.where( - jnp.isclose(coefficient_rra, 1.0), 1.0, 1.0 - coefficient_rra - ) + one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) raw = jnp.where( - jnp.isclose(coefficient_rra, 1.0), + jnp.isclose(gamma, 1.0), jnp.log(u_cons * u_leisure), (u_cons * u_leisure) ** one_minus_gamma / one_minus_gamma, ) @@ -237,8 +244,9 @@ def bequest( """Bequest function for terminal/dead states. bequest = scale * bwt * (max(0,a) + shifter)^(α*(1-γ)) / (1-γ) - `consumption_weight`, `coefficient_rra`, and `utility_scale_factor` - are indexed by `pref_type`. + `consumption_weight` and `coefficient_rra` are pref-type-indexed Series + from params; `utility_scale_factor` is a regime-function output (already a + per-cell scalar — must NOT be re-indexed by pref_type). """ alpha = consumption_weight[pref_type] gamma = coefficient_rra[pref_type] @@ -250,4 +258,4 @@ def bequest( jnp.log(assets_shifted), assets_shifted ** (one_minus_gamma * alpha) / one_minus_gamma, ) - return val * scaled_bequest_weight * utility_scale_factor[pref_type] + return val * scaled_bequest_weight * utility_scale_factor diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_grid.py index d670d3c..8ba8bc4 100644 --- a/src/aca_model/consumption_grid.py +++ b/src/aca_model/consumption_grid.py @@ -19,9 +19,7 @@ CRRA optimum for the highest-asset, highest-income agents in the state space.""" -def compute_consumption_points( - *, consumption_floor: float, n_points: int -) -> Array: +def compute_consumption_points(*, consumption_floor: float, n_points: int) -> Array: """Return log-spaced consumption gridpoints from the floor to `MAX_CONSUMPTION`. Args: @@ -63,9 +61,7 @@ def inject_consumption_points( out: dict[str, Any] = dict(params) for regime_name, regime in model.regimes.items(): grid = regime.actions.get("consumption") - if not ( - isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime - ): + if not (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime): continue points = compute_consumption_points( consumption_floor=consumption_floor, n_points=grid.n_points diff --git a/tests/test_model_components.py b/tests/test_model_components.py index cbb2f72..9876ac8 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -84,7 +84,7 @@ def test_utility_positive_leisure() -> None: consumption_weight=jnp.array([0.4, 0.4, 0.4]), coefficient_rra=jnp.array([2.0, 2.0, 2.0]), equivalence_scale=jnp.array(1.0), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) @@ -97,7 +97,7 @@ def test_utility_log_case() -> None: consumption_weight=jnp.array([0.4, 0.4, 0.4]), coefficient_rra=jnp.array([1.0, 1.0, 1.0]), equivalence_scale=jnp.array(1.0), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + utility_scale_factor=jnp.array(1.0), ) composite = 10000.0**0.4 * 3000.0**0.6 expected = jnp.log(composite) @@ -112,7 +112,7 @@ def test_bequest_positive_assets() -> None: scaled_bequest_weight=0.5, consumption_weight=jnp.array([0.4, 0.4, 0.4]), coefficient_rra=jnp.array([2.0, 2.0, 2.0]), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) @@ -125,7 +125,7 @@ def test_bequest_zero_assets() -> None: scaled_bequest_weight=0.5, consumption_weight=jnp.array([0.4, 0.4, 0.4]), coefficient_rra=jnp.array([2.0, 2.0, 2.0]), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) assert result < 0 # CRRA with γ>1 gives negative values diff --git a/tests/test_preferences.py b/tests/test_preferences.py index 8c1921c..4ff2266 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -33,6 +33,7 @@ def test_utility_scale_factor_crra() -> None: result = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, @@ -43,11 +44,12 @@ def test_utility_scale_factor_crra() -> None: reference_age=REFERENCE_AGE, scale_reference_age=SCALE_REFERENCE_AGE, ) - assert jnp.isclose(result[0], 9_233_279_397_806_166.0, rtol=1e-6) + assert jnp.isclose(result, 9_233_279_397_806_166.0, rtol=1e-6) def test_utility_scale_factor_log() -> None: result = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_1_BY_TYPE, @@ -58,7 +60,7 @@ def test_utility_scale_factor_log() -> None: reference_age=REFERENCE_AGE, scale_reference_age=SCALE_REFERENCE_AGE, ) - assert jnp.isclose(result[0], 0.113_073_257_794_546_72, rtol=1e-6) + assert jnp.isclose(result, 0.113_073_257_794_546_72, rtol=1e-6) # --- scaled_bequest_weight --- @@ -105,6 +107,7 @@ def test_scaled_bequest_weight_zero() -> None: def test_utility_log_regression() -> None: scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_1_BY_TYPE, @@ -129,6 +132,7 @@ def test_utility_log_regression() -> None: def test_utility_crra_regression() -> None: scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, @@ -154,6 +158,7 @@ def test_utility_crra_regression() -> None: def test_utility_married_equivalence() -> None: """Married with equiv-scaled consumption should equal single utility.""" scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, @@ -190,6 +195,7 @@ def test_utility_married_equivalence() -> None: def test_bequest_log_regression() -> None: scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_1_BY_TYPE, @@ -222,6 +228,7 @@ def test_bequest_log_regression() -> None: def test_bequest_crra_regression() -> None: scale = preferences.utility_scale_factor( + pref_type=jnp.array(0), average_consumption=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, From 8cd8e37d3ada3eb8f9f91d76cd678d280d10e926 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 30 Apr 2026 14:59:49 +0200 Subject: [PATCH 03/54] Halve aca-model production assets batch size to fit V100 16GB Reduce n_assets_batch_size from 2 to 1 in MODEL_CONFIG so the assets state axis is streamed one slice at a time, lowering peak GPU memory during solve on the V100-PCIE-16GB. Benchmark grid config is unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 1bae45f..2904ca4 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -36,7 +36,7 @@ class GridConfig: # `batch_size` on the assets grid: chunked vmap stride for the # outer state loop. Useful at prod sizes for memory reasons; set # to 0 in BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. - n_assets_batch_size: int = 2 + n_assets_batch_size: int = 1 MODEL_CONFIG = ModelConfig() From 84d484a5153c241ab318b785749fe8b103a5ca0d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 30 Apr 2026 17:05:11 +0200 Subject: [PATCH 04/54] =?UTF-8?q?Revert=20assets=20batch=5Fsize=20halving?= =?UTF-8?q?=20=E2=80=94=20V100=20OOM=20was=20elsewhere?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The OOM at production grid sizes came from pylcm's deferred diagnostics flush in solve_brute (`_emit_deferred_diagnostics` materialising a fused per-period reduction graph at end-of-solve), not from per-period peak. Halving the assets batch did not address that; reverting so the production loop runs at its previous throughput. Workaround for the diagnostics OOM lives in aca-estimation's simulate tasks (log_level="off"). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 2904ca4..1bae45f 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -36,7 +36,7 @@ class GridConfig: # `batch_size` on the assets grid: chunked vmap stride for the # outer state loop. Useful at prod sizes for memory reasons; set # to 0 in BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. - n_assets_batch_size: int = 1 + n_assets_batch_size: int = 2 MODEL_CONFIG = ModelConfig() From f0892efbbe891f198e7a42ba18147e322e7b71e7 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 10:35:14 +0200 Subject: [PATCH 05/54] Re-halve production assets batch size: V100 still OOMs per-period MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #339's per-period `block_until_ready` made the OOM surface inside the loop instead of at the post-loop diagnostic flush, but the 7.26 GiB allocation request was the same — it isn't the diagnostic accumulator, it's a real per-period `max_Q_over_a` working set at production grid sizes (`n_consumption=70`, `n_assets=24`, `n_aime=12`, plus the per-target next-V gather across reachable regimes). Cutting the assets-axis chunk back to 1 reduces the per-kernel peak. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 1bae45f..2904ca4 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -36,7 +36,7 @@ class GridConfig: # `batch_size` on the assets grid: chunked vmap stride for the # outer state loop. Useful at prod sizes for memory reasons; set # to 0 in BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. - n_assets_batch_size: int = 2 + n_assets_batch_size: int = 1 MODEL_CONFIG = ModelConfig() From 08e42cb1e669f6e43582539bf4afae3cfbedafcd Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 12:45:59 +0200 Subject: [PATCH 06/54] config: add n_aime_batch_size to splay AIME outer-loop on V100 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Production solve allocates a per-period Q intermediate of shape `(non-assets-states × actions)` per assets-batch slot. With `n_assets_batch_size=1` we already chunk that axis to the minimum; the remaining outer-state product (aime × wage_res × hcc × pref_type × health × ...) times the action grid still pushes past the V100 16 GB once `pref_type` is split off into its own partition lift, which removes a free factor that previously thinned the kernel. Add a sibling `n_aime_batch_size` knob (default 1, 0 in `BENCHMARK_GRID_CONFIG`) and thread it through both AIME grid types in `_build_aime_grid`. AIME has 12 prod gridpoints in the LinSpaced fallback and 32 in the PiecewiseLinSpaced production path, so a unit batch shrinks the live Q intermediate by roughly that factor — enough headroom to land back inside V100 memory. Pairs with the pylcm-side fix that stops `_DiagnosticRow` pinning per-period V templates in device memory (lazy-solve-diagnostics branch). The diagnostic leak masked the underlying batching gap; once it's gone, the Q intermediate is the next thing to size for the device. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/baseline/regimes/_common.py | 9 +++++++-- src/aca_model/benchmark.py | 3 ++- src/aca_model/config.py | 9 ++++++--- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index efd3b73..327aa15 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -287,7 +287,10 @@ def _build_aime_grid( """ if fixed_params is None or "pia_aime_grid" not in fixed_params: return LinSpacedGrid( - start=0.0, stop=8_000.0, n_points=grid_config.n_aime_gridpoints + start=0.0, + stop=8_000.0, + n_points=grid_config.n_aime_gridpoints, + batch_size=grid_config.n_aime_batch_size, ) kinks = [float(k) for k in np.asarray(fixed_params["pia_aime_grid"])] pieces = ( @@ -295,7 +298,9 @@ def _build_aime_grid( Piece(interval=f"[{kinks[1]}, {kinks[2]})", n_points=_AIME_PIECE_N_POINTS[1]), Piece(interval=f"[{kinks[2]}, {kinks[3]}]", n_points=_AIME_PIECE_N_POINTS[2]), ) - return PiecewiseLinSpacedGrid(pieces=pieces) + return PiecewiseLinSpacedGrid( + pieces=pieces, batch_size=grid_config.n_aime_batch_size + ) def _has_required_wage_keys(*, wage_params: Mapping[str, Any]) -> bool: diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index a9d7128..3c06670 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -76,7 +76,8 @@ def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Mod The benchmark uses a 2-type `BenchmarkPrefType`. No `batch_size != 0` on any grid (continuous grids inherit - `BENCHMARK_GRID_CONFIG.n_assets_batch_size = 0`). + `BENCHMARK_GRID_CONFIG.n_assets_batch_size = 0` and + `n_aime_batch_size = 0`). Args: pref_type_grid: Override for the pref_type grid. Default is a plain diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 2904ca4..37fc0c8 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -33,10 +33,12 @@ class GridConfig: n_wage_res_gridpoints: int = 5 n_hcc_persistent_gridpoints: int = 3 n_hcc_transitory_gridpoints: int = 5 - # `batch_size` on the assets grid: chunked vmap stride for the - # outer state loop. Useful at prod sizes for memory reasons; set - # to 0 in BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. + # `batch_size` on the assets / AIME grids: chunked vmap stride for the + # outer state loop. Both partition the per-period Q intermediate so it + # fits in V100 16 GB once we splay across `pref_type`. Set to 0 in + # BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. n_assets_batch_size: int = 1 + n_aime_batch_size: int = 1 MODEL_CONFIG = ModelConfig() @@ -50,4 +52,5 @@ class GridConfig: n_hcc_persistent_gridpoints=3, n_hcc_transitory_gridpoints=3, n_assets_batch_size=0, + n_aime_batch_size=0, ) From e08fc19705549c5e59f38e5a704b993412c491be Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 1 May 2026 12:57:29 +0200 Subject: [PATCH 07/54] consumption_grid: read upper bound from `max_consumption` fixed param MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The grid floor already tracks the per-iteration `consumption_floor` parameter; the ceiling was a hardcoded 300k constant. Surface it as a fixed param via a marker function (`consumption_grid_upper_bound`) so callers can declare the bracket per model creation, and read it back at inject time from each regime's `resolved_fixed_params`. The marker function's output is intentionally unused — its only job is to put `max_consumption` in the regime params template so pylcm's fixed-param machinery captures it. dags.tree pruning drops the call at solve / simulate. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/baseline/regimes/_common.py | 7 +++ src/aca_model/consumption_grid.py | 75 ++++++++++++++--------- 2 files changed, 52 insertions(+), 30 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 327aa15..d46f921 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -37,6 +37,7 @@ from aca_model.baseline import health_insurance from aca_model.baseline.health_insurance import BuyPrivate from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.consumption_grid import consumption_grid_upper_bound from aca_model.environment import social_security, taxes from aca_model.environment.social_security import ClaimedSS @@ -537,6 +538,12 @@ def build_common_functions(spec: dict[str, str]) -> dict: functions["cash_on_hand"] = assets_and_income.cash_on_hand functions["transfers"] = assets_and_income.transfers + # Marker: surfaces `max_consumption` in the params template so it + # can be supplied via fixed_params and read back at inject time + # by `inject_consumption_points`. Output unused; pruned at + # solve / simulate. + functions["consumption_grid_upper_bound"] = consumption_grid_upper_bound + return functions diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_grid.py index 8ba8bc4..6238328 100644 --- a/src/aca_model/consumption_grid.py +++ b/src/aca_model/consumption_grid.py @@ -1,10 +1,12 @@ """Runtime-supplied gridpoints for the consumption action. Consumption is declared as `IrregSpacedGrid(n_points=N)` in -`baseline.regimes._common.build_grids` so the lower bound can track -the per-iteration `consumption_floor` parameter. Callers must inject -the actual gridpoints into `params` via `inject_consumption_points` -before calling `model.solve()` / `model.simulate()`. +`baseline.regimes._common.build_grids` so the bounds can track +runtime parameters: the lower bound from the per-iteration +`consumption_floor` parameter, the upper bound from the per-creation-time +`max_consumption` fixed param. Callers must inject the actual +gridpoints into `params` via `inject_consumption_points` before +calling `model.solve()` / `model.simulate()`. """ from collections.abc import Mapping @@ -14,31 +16,11 @@ from jax import Array from lcm import IrregSpacedGrid, Model -MAX_CONSUMPTION: float = 300_000.0 -"""Upper bound of the consumption grid in $/year. Brackets the unconstrained -CRRA optimum for the highest-asset, highest-income agents in the state space.""" - - -def compute_consumption_points(*, consumption_floor: float, n_points: int) -> Array: - """Return log-spaced consumption gridpoints from the floor to `MAX_CONSUMPTION`. - - Args: - consumption_floor: Lowest gridpoint, equal to the `consumption_floor` - parameter so the agent cannot pick `c < floor` even when saving - from a transfer top-up. - n_points: Total number of gridpoints. - - Returns: - 1-D float array of length `n_points`. - """ - return jnp.geomspace(consumption_floor, MAX_CONSUMPTION, num=n_points) - def inject_consumption_points( *, params: Mapping[str, Any], model: Model, - consumption_floor: float | None = None, ) -> dict[str, Any]: """Inject consumption gridpoints into per-regime params. @@ -46,27 +28,60 @@ def inject_consumption_points( `IrregSpacedGrid` with runtime-supplied points, and writes `params[regime_name]["consumption"] = {"points": }`. + Lower bound: `params["consumption_floor"]` (varies per iteration). + Upper bound: `max_consumption` from the regime's resolved + fixed-params (set once at model creation). + Args: params: Existing params mapping. Returned as a new dict; the input is not mutated. model: Model whose regime specs determine which regimes need points. - consumption_floor: Lowest gridpoint. When `None`, taken from - `params["consumption_floor"]`. Returns: New params dict with consumption points injected. """ - if consumption_floor is None: - consumption_floor = float(params["consumption_floor"]) + consumption_floor = float(params["consumption_floor"]) out: dict[str, Any] = dict(params) for regime_name, regime in model.regimes.items(): grid = regime.actions.get("consumption") if not (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime): continue - points = compute_consumption_points( - consumption_floor=consumption_floor, n_points=grid.n_points + # Runtime-points grids always have `n_points` set (the constructor + # rejects the (points=None, n_points=None) combo); narrow for ty. + assert grid.n_points is not None + max_consumption = float( + model.internal_regimes[regime_name].resolved_fixed_params["max_consumption"] + ) + points = _compute_consumption_points( + consumption_floor=consumption_floor, + max_consumption=max_consumption, + n_points=grid.n_points, ) regime_entry = dict(out.get(regime_name, {})) regime_entry["consumption"] = {"points": points} out[regime_name] = regime_entry return out + + +def consumption_grid_upper_bound(max_consumption: float) -> float: + """Surface `max_consumption` in the regime params template. + + pylcm builds the params template from each regime function's + signature. `max_consumption` is read at runtime by + `inject_consumption_points` from `resolved_fixed_params`; for + that to work via pylcm's fixed-params machinery, the key must + appear in some function's signature. This marker function is + the entry point — its output is intentionally unused, and + dags.tree pruning drops the call at solve / simulate time. + """ + return max_consumption + + +def _compute_consumption_points( + *, + consumption_floor: float, + max_consumption: float, + n_points: int, +) -> Array: + """Return log-spaced consumption gridpoints from floor to max.""" + return jnp.geomspace(consumption_floor, max_consumption, num=n_points) From c1ffb2a793ea4c5488b3e99a87883b80860c6b1d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 08:19:20 +0200 Subject: [PATCH 08/54] create_model: default `max_consumption` into fixed_params The runtime-upper-bound change requires every caller to supply `max_consumption` via `fixed_params`; estimation tasks (e.g. `task_simulate_aca`) hit a `KeyError` mid-pipeline because they construct the model from data-derived `fixed_params` that have no reason to mention a grid bracket. Centralise the default in both `baseline.model.create_model` and `aca.model.create_model` so existing callers keep working with the prior 300k bracket and only opt-in callers need to override. --- src/aca_model/aca/model.py | 4 +++- src/aca_model/baseline/model.py | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 8b4507a..0e02b0c 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -11,6 +11,7 @@ from aca_model.aca import PolicyVariant from aca_model.aca.regimes import build_all_regimes +from aca_model.baseline.model import _with_max_consumption_default from aca_model.baseline.regimes import RegimeId from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig @@ -51,6 +52,7 @@ def create_model( stop=MODEL_CONFIG.end_age - 1, step="Y", ) + fixed_params = _with_max_consumption_default(fixed_params) regimes = build_all_regimes( policy=policy, grid_config=grid_config, @@ -63,6 +65,6 @@ def create_model( ages=ages, regime_id_class=RegimeId, description=f"Structural retirement model ({policy.name})", - fixed_params=fixed_params or {}, + fixed_params=fixed_params, derived_categoricals=derived_categoricals, ) diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index a886495..2843fa0 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -18,6 +18,11 @@ from aca_model.baseline.regimes import RegimeId, build_all_regimes from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +_DEFAULT_MAX_CONSUMPTION: float = 300_000.0 +"""Upper bound of the consumption grid in $/year. Brackets the unconstrained +CRRA optimum for the highest-asset, highest-income agents in the state space. +Callers can override by passing `max_consumption` in `fixed_params`.""" + def create_model( *, @@ -59,6 +64,7 @@ def create_model( stop=MODEL_CONFIG.end_age - 1, step="Y", ) + fixed_params = _with_max_consumption_default(fixed_params) regimes = build_all_regimes( grid_config, fixed_params=fixed_params, @@ -71,6 +77,15 @@ def create_model( ages=ages, regime_id_class=RegimeId, description="Baseline structural retirement model (pre-ACA)", - fixed_params=fixed_params or {}, + fixed_params=fixed_params, derived_categoricals=derived_categoricals, ) + + +def _with_max_consumption_default( + fixed_params: Mapping[str, Any] | None, +) -> dict[str, Any]: + """Return a copy of `fixed_params` with `max_consumption` defaulted.""" + out = dict(fixed_params or {}) + out.setdefault("max_consumption", _DEFAULT_MAX_CONSUMPTION) + return out From a21768752b15f70da2fa82910e6f4cebc26689d2 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 13:20:38 +0200 Subject: [PATCH 09/54] create_model: forward n_subjects through baseline + aca + benchmark Lets callers opt in to pylcm's simulate-AOT path (`Model(n_subjects=...)`) without bypassing the aca-model factories. --- src/aca_model/aca/model.py | 2 ++ src/aca_model/baseline/model.py | 2 ++ src/aca_model/benchmark.py | 10 +++++++++- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 0e02b0c..9a118c6 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -24,6 +24,7 @@ def create_model( derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] | None = None, grid_config: GridConfig = GRID_CONFIG, + n_subjects: int | None = None, ) -> Model: """Create an ACA policy variant model. @@ -67,4 +68,5 @@ def create_model( description=f"Structural retirement model ({policy.name})", fixed_params=fixed_params, derived_categoricals=derived_categoricals, + n_subjects=n_subjects, ) diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index 2843fa0..9d6b04a 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -32,6 +32,7 @@ def create_model( | None = None, grid_config: GridConfig = GRID_CONFIG, pref_type_grid: DiscreteGrid | None = None, + n_subjects: int | None = None, ) -> Model: """Create the baseline structural retirement model. @@ -79,6 +80,7 @@ def create_model( description="Baseline structural retirement model (pre-ACA)", fixed_params=fixed_params, derived_categoricals=derived_categoricals, + n_subjects=n_subjects, ) diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 3c06670..35bf880 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -71,7 +71,11 @@ ) -def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Model: +def create_benchmark_model( + *, + pref_type_grid: DiscreteGrid | None = None, + n_subjects: int | None = None, +) -> Model: """Create the aca baseline with `BENCHMARK_GRID_CONFIG` and frozen fixed_params. The benchmark uses a 2-type `BenchmarkPrefType`. No `batch_size != 0` @@ -86,6 +90,9 @@ def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Mod (or `PARTITION_VMAP`) to get the partition-lifted kernel — the recommended production setting for aca-model at scale, but only supported on pylcm versions that expose `DispatchStrategy`. + n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. When set, the + first matching `simulate(...)` call AOT-compiles all simulate + functions for that batch shape. """ if pref_type_grid is None: pref_type_grid = DiscreteGrid(BenchmarkPrefType) @@ -95,6 +102,7 @@ def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Mod fixed_params=fixed_params, derived_categoricals=_DERIVED_CATEGORICALS, pref_type_grid=pref_type_grid, + n_subjects=n_subjects, ) From d1eb320f0439a73a169bc5916d4af0b15208b2b6 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 13:30:56 +0200 Subject: [PATCH 10/54] create_model: require n_subjects (no default) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The aca-model factories now require `n_subjects` as a kw-only int with no default — there's never a good reason for an aca-model caller to leave it unspecified, and silently letting it default to `None` (= no AOT, lazy-compile path) was exactly how the simulate-AOT benefit went unused on the prod estimation loop. Forcing each caller to make a deliberate choice catches that. Tests pass `n_subjects=1` for bare `get_params_template()` / shock-grid-inspection paths that never simulate. --- src/aca_model/aca/model.py | 2 +- src/aca_model/baseline/model.py | 2 +- src/aca_model/benchmark.py | 2 +- tests/test_benchmark.py | 2 +- tests/test_model_creation.py | 14 +++++++------- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 9a118c6..a9097b4 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -18,13 +18,13 @@ def create_model( *, + n_subjects: int, policy: PolicyVariant = PolicyVariant.ACA, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] | None = None, grid_config: GridConfig = GRID_CONFIG, - n_subjects: int | None = None, ) -> Model: """Create an ACA policy variant model. diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index 9d6b04a..78c90eb 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -26,13 +26,13 @@ def create_model( *, + n_subjects: int, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] | None = None, grid_config: GridConfig = GRID_CONFIG, pref_type_grid: DiscreteGrid | None = None, - n_subjects: int | None = None, ) -> Model: """Create the baseline structural retirement model. diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 35bf880..08f5ec6 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -73,8 +73,8 @@ def create_benchmark_model( *, + n_subjects: int, pref_type_grid: DiscreteGrid | None = None, - n_subjects: int | None = None, ) -> Model: """Create the aca baseline with `BENCHMARK_GRID_CONFIG` and frozen fixed_params. diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index c1e48e0..adafd66 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -12,7 +12,7 @@ @pytest.mark.long_running def test_benchmark_model_simulates_end_to_end() -> None: n_subjects = 20 - model = create_benchmark_model() + model = create_benchmark_model(n_subjects=n_subjects) _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index 841f2bb..d154e6b 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -21,24 +21,24 @@ def build_regime(name: str): def test_model_creates_successfully() -> None: - model = create_model() + model = create_model(n_subjects=1) assert len(model.regimes) == 19 assert model.n_periods == 45 def test_model_age_range() -> None: - model = create_model() + model = create_model(n_subjects=1) assert model.ages.values[0] == 51.0 assert model.ages.values[-1] == 95.0 def test_dead_regime_is_terminal() -> None: - model = create_model() + model = create_model(n_subjects=1) assert model.regimes["dead"].terminal def test_non_terminal_regimes_not_terminal() -> None: - model = create_model() + model = create_model(n_subjects=1) for name in REGIME_SPECS: assert not model.regimes[name].terminal @@ -170,7 +170,7 @@ def test_hcc_persistent_and_transitory_are_shock_grids() -> None: def test_aca_model_creates_successfully() -> None: - model = create_aca_model() + model = create_aca_model(n_subjects=1) assert len(model.regimes) == 19 assert model.n_periods == 45 @@ -211,7 +211,7 @@ def test_aca_other_regimes_have_no_aca_policy_keys() -> None: @pytest.mark.parametrize("policy", list(PolicyVariant)) def test_all_policy_variants_create(policy: PolicyVariant) -> None: """All policy variants create valid models.""" - model = create_aca_model(policy=policy) + model = create_aca_model(n_subjects=1, policy=policy) assert len(model.regimes) == 19 @@ -251,5 +251,5 @@ def test_aca_only_medicaid_expansion() -> None: def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" - model = create_model() + model = create_model(n_subjects=1) assert len(model.regimes) == 19 From 9e252051ad53683a8ad65e9dba68a910240103c0 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 13:41:55 +0200 Subject: [PATCH 11/54] ci: install pylcm from feat/simulate-aot-n-subjects (carries Model.n_subjects) --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a0247b8..bd40367 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,7 +29,7 @@ jobs: - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@feature/runtime-action-grids" + git+https://github.com/OpenSourceEconomics/pylcm.git@feat/simulate-aot-n-subjects" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest From cdd10169a13fbc74604f9e879276ddb4c17b53c4 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 2 May 2026 21:02:13 +0200 Subject: [PATCH 12/54] consumption_grid: max_consumption is a required factory arg, attached to Model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The marker-function-via-DAG pattern didn't survive pylcm's pruning: `consumption_grid_upper_bound`'s output is unused, so dags.tree drops it before its `max_consumption` parameter reaches the params template, and `broadcast_to_template` has nowhere to put the value. Result: `resolved_fixed_params["max_consumption"]` was always missing, `inject_consumption_points` raised KeyError. Sidestep pylcm's params machinery for this knob: - Drop the `consumption_grid_upper_bound` marker function and the `_with_max_consumption_default` helper. - Add `max_consumption: float` (kw-only, required, no default) to all three factories: `baseline.create_model`, `aca.create_model`, `create_benchmark_model`. - Each factory attaches the value directly to the returned `Model` instance (`model.max_consumption = ...`). - `inject_consumption_points` reads `model.max_consumption` directly. No defaults — every caller passes the bracket explicitly. --- src/aca_model/aca/model.py | 13 ++++++--- src/aca_model/baseline/model.py | 35 +++++++++++------------ src/aca_model/baseline/regimes/_common.py | 7 ----- src/aca_model/benchmark.py | 2 ++ src/aca_model/consumption_grid.py | 33 ++++++--------------- tests/test_benchmark.py | 2 +- tests/test_model_creation.py | 14 ++++----- 7 files changed, 45 insertions(+), 61 deletions(-) diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index a9097b4..942be22 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -11,7 +11,6 @@ from aca_model.aca import PolicyVariant from aca_model.aca.regimes import build_all_regimes -from aca_model.baseline.model import _with_max_consumption_default from aca_model.baseline.regimes import RegimeId from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig @@ -19,6 +18,7 @@ def create_model( *, n_subjects: int, + max_consumption: float, policy: PolicyVariant = PolicyVariant.ACA, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, @@ -29,6 +29,7 @@ def create_model( """Create an ACA policy variant model. Args: + n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. policy: Which ACA policy combination to apply. fixed_params: Parameters to fix at model creation time. These are partialled into compiled functions and removed from the params @@ -43,6 +44,9 @@ def create_model( contains `pd.Series` indexed by DAG function outputs. grid_config: Continuous-grid point counts. Defaults to production values. + max_consumption: Upper bound of the runtime consumption grid in + $/year. Attached to the returned Model and read back at inject + time by `inject_consumption_points`. Returns: pylcm Model with ACA-specific function overrides. @@ -53,7 +57,6 @@ def create_model( stop=MODEL_CONFIG.end_age - 1, step="Y", ) - fixed_params = _with_max_consumption_default(fixed_params) regimes = build_all_regimes( policy=policy, grid_config=grid_config, @@ -61,12 +64,14 @@ def create_model( wage_params=wage_params, ) - return Model( + model = Model( regimes=regimes, ages=ages, regime_id_class=RegimeId, description=f"Structural retirement model ({policy.name})", - fixed_params=fixed_params, + fixed_params=fixed_params or {}, derived_categoricals=derived_categoricals, n_subjects=n_subjects, ) + model.max_consumption = max_consumption + return model diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index 78c90eb..0ff7c47 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -5,7 +5,7 @@ Usage: from aca_model.baseline.model import create_model - model = create_model() + model = create_model(n_subjects=...) params = get_default_params() V = model.solve(params) """ @@ -18,15 +18,11 @@ from aca_model.baseline.regimes import RegimeId, build_all_regimes from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig -_DEFAULT_MAX_CONSUMPTION: float = 300_000.0 -"""Upper bound of the consumption grid in $/year. Brackets the unconstrained -CRRA optimum for the highest-asset, highest-income agents in the state space. -Callers can override by passing `max_consumption` in `fixed_params`.""" - def create_model( *, n_subjects: int, + max_consumption: float, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] @@ -37,6 +33,7 @@ def create_model( """Create the baseline structural retirement model. Args: + n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. fixed_params: Parameters to fix at model creation time. These are partialled into compiled functions and removed from the params template. Pass data-derived constants here; only estimation @@ -54,6 +51,9 @@ def create_model( pref_type_grid: Optional override for the `pref_type` `DiscreteGrid`. Defaults to `DiscreteGrid(PrefType)`. Used by the benchmark to substitute a 2-type variant with `DispatchStrategy.PARTITION_SCAN`. + max_consumption: Upper bound of the runtime consumption grid in + $/year. Attached to the returned Model and read back at inject + time by `inject_consumption_points`. Returns: A pylcm Model with 19 regimes (18 non-terminal + dead) spanning @@ -65,7 +65,6 @@ def create_model( stop=MODEL_CONFIG.end_age - 1, step="Y", ) - fixed_params = _with_max_consumption_default(fixed_params) regimes = build_all_regimes( grid_config, fixed_params=fixed_params, @@ -73,21 +72,21 @@ def create_model( pref_type_grid=pref_type_grid, ) - return Model( + model = Model( regimes=regimes, ages=ages, regime_id_class=RegimeId, description="Baseline structural retirement model (pre-ACA)", - fixed_params=fixed_params, + fixed_params=fixed_params or {}, derived_categoricals=derived_categoricals, n_subjects=n_subjects, ) - - -def _with_max_consumption_default( - fixed_params: Mapping[str, Any] | None, -) -> dict[str, Any]: - """Return a copy of `fixed_params` with `max_consumption` defaulted.""" - out = dict(fixed_params or {}) - out.setdefault("max_consumption", _DEFAULT_MAX_CONSUMPTION) - return out + # Attach the consumption-grid upper bound directly to the Model + # instance. Tried surfacing it via a marker function in the regime + # DAG first — pylcm's pruning drops unused-output functions before + # their parameters reach the params template, so the value never + # made it into `resolved_fixed_params`. Direct attachment sidesteps + # the templating machinery entirely; `inject_consumption_points` + # reads `model.max_consumption`. + model.max_consumption = max_consumption + return model diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index d46f921..327aa15 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -37,7 +37,6 @@ from aca_model.baseline import health_insurance from aca_model.baseline.health_insurance import BuyPrivate from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig -from aca_model.consumption_grid import consumption_grid_upper_bound from aca_model.environment import social_security, taxes from aca_model.environment.social_security import ClaimedSS @@ -538,12 +537,6 @@ def build_common_functions(spec: dict[str, str]) -> dict: functions["cash_on_hand"] = assets_and_income.cash_on_hand functions["transfers"] = assets_and_income.transfers - # Marker: surfaces `max_consumption` in the params template so it - # can be supplied via fixed_params and read back at inject time - # by `inject_consumption_points`. Output unused; pruned at - # solve / simulate. - functions["consumption_grid_upper_bound"] = consumption_grid_upper_bound - return functions diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 08f5ec6..3d7fbad 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -74,6 +74,7 @@ def create_benchmark_model( *, n_subjects: int, + max_consumption: float, pref_type_grid: DiscreteGrid | None = None, ) -> Model: """Create the aca baseline with `BENCHMARK_GRID_CONFIG` and frozen fixed_params. @@ -103,6 +104,7 @@ def create_benchmark_model( derived_categoricals=_DERIVED_CATEGORICALS, pref_type_grid=pref_type_grid, n_subjects=n_subjects, + max_consumption=max_consumption, ) diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_grid.py index 6238328..bd342ee 100644 --- a/src/aca_model/consumption_grid.py +++ b/src/aca_model/consumption_grid.py @@ -3,10 +3,11 @@ Consumption is declared as `IrregSpacedGrid(n_points=N)` in `baseline.regimes._common.build_grids` so the bounds can track runtime parameters: the lower bound from the per-iteration -`consumption_floor` parameter, the upper bound from the per-creation-time -`max_consumption` fixed param. Callers must inject the actual -gridpoints into `params` via `inject_consumption_points` before -calling `model.solve()` / `model.simulate()`. +`consumption_floor` parameter, the upper bound from a per-model +`max_consumption` knob attached to the `Model` instance by the +`create_model` factories. Callers must inject the actual gridpoints +into `params` via `inject_consumption_points` before calling +`model.solve()` / `model.simulate()`. """ from collections.abc import Mapping @@ -24,13 +25,13 @@ def inject_consumption_points( ) -> dict[str, Any]: """Inject consumption gridpoints into per-regime params. - Walks `model.regimes`, finds those with `consumption` declared as + Walks every regime, finds the action whose grid is an `IrregSpacedGrid` with runtime-supplied points, and writes `params[regime_name]["consumption"] = {"points": }`. Lower bound: `params["consumption_floor"]` (varies per iteration). - Upper bound: `max_consumption` from the regime's resolved - fixed-params (set once at model creation). + Upper bound: `model.max_consumption` (required attribute; set by + the `create_model` factory). Args: params: Existing params mapping. Returned as a new dict; the input is @@ -41,6 +42,7 @@ def inject_consumption_points( New params dict with consumption points injected. """ consumption_floor = float(params["consumption_floor"]) + max_consumption = float(model.max_consumption) out: dict[str, Any] = dict(params) for regime_name, regime in model.regimes.items(): grid = regime.actions.get("consumption") @@ -49,9 +51,6 @@ def inject_consumption_points( # Runtime-points grids always have `n_points` set (the constructor # rejects the (points=None, n_points=None) combo); narrow for ty. assert grid.n_points is not None - max_consumption = float( - model.internal_regimes[regime_name].resolved_fixed_params["max_consumption"] - ) points = _compute_consumption_points( consumption_floor=consumption_floor, max_consumption=max_consumption, @@ -63,20 +62,6 @@ def inject_consumption_points( return out -def consumption_grid_upper_bound(max_consumption: float) -> float: - """Surface `max_consumption` in the regime params template. - - pylcm builds the params template from each regime function's - signature. `max_consumption` is read at runtime by - `inject_consumption_points` from `resolved_fixed_params`; for - that to work via pylcm's fixed-params machinery, the key must - appear in some function's signature. This marker function is - the entry point — its output is intentionally unused, and - dags.tree pruning drops the call at solve / simulate time. - """ - return max_consumption - - def _compute_consumption_points( *, consumption_floor: float, diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index adafd66..6173318 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -12,7 +12,7 @@ @pytest.mark.long_running def test_benchmark_model_simulates_end_to_end() -> None: n_subjects = 20 - model = create_benchmark_model(n_subjects=n_subjects) + model = create_benchmark_model(n_subjects=n_subjects, max_consumption=300_000.0) _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index d154e6b..accde27 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -21,24 +21,24 @@ def build_regime(name: str): def test_model_creates_successfully() -> None: - model = create_model(n_subjects=1) + model = create_model(n_subjects=1, max_consumption=300_000.0) assert len(model.regimes) == 19 assert model.n_periods == 45 def test_model_age_range() -> None: - model = create_model(n_subjects=1) + model = create_model(n_subjects=1, max_consumption=300_000.0) assert model.ages.values[0] == 51.0 assert model.ages.values[-1] == 95.0 def test_dead_regime_is_terminal() -> None: - model = create_model(n_subjects=1) + model = create_model(n_subjects=1, max_consumption=300_000.0) assert model.regimes["dead"].terminal def test_non_terminal_regimes_not_terminal() -> None: - model = create_model(n_subjects=1) + model = create_model(n_subjects=1, max_consumption=300_000.0) for name in REGIME_SPECS: assert not model.regimes[name].terminal @@ -170,7 +170,7 @@ def test_hcc_persistent_and_transitory_are_shock_grids() -> None: def test_aca_model_creates_successfully() -> None: - model = create_aca_model(n_subjects=1) + model = create_aca_model(n_subjects=1, max_consumption=300_000.0) assert len(model.regimes) == 19 assert model.n_periods == 45 @@ -211,7 +211,7 @@ def test_aca_other_regimes_have_no_aca_policy_keys() -> None: @pytest.mark.parametrize("policy", list(PolicyVariant)) def test_all_policy_variants_create(policy: PolicyVariant) -> None: """All policy variants create valid models.""" - model = create_aca_model(n_subjects=1, policy=policy) + model = create_aca_model(n_subjects=1, max_consumption=300_000.0, policy=policy) assert len(model.regimes) == 19 @@ -251,5 +251,5 @@ def test_aca_only_medicaid_expansion() -> None: def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" - model = create_model(n_subjects=1) + model = create_model(n_subjects=1, max_consumption=300_000.0) assert len(model.regimes) == 19 From 31a0ad20e70c9e1859f4268cd954979070d8b17f Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 3 May 2026 17:29:58 +0200 Subject: [PATCH 13/54] Move max_consumption to canonical constant; drop kwarg threading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `MAX_CONSUMPTION = 300_000.0` to `baseline/regimes/_common.py` next to the other grid bounds (assets `stop=500_000.0`, AIME `stop=8_000.0`). The two `create_model` factories and `create_benchmark_model` no longer take `max_consumption` as a kwarg; each factory reads the constant directly and attaches it onto `model.max_consumption`. `inject_consumption_points` is unchanged — it still reads `model.max_consumption` (the legitimate consumer that combines it with the per-iteration `consumption_floor`). Routed via the Model attribute rather than `fixed_params` because pylcm validates fixed_params keys against the regime DAG and rejects entries no function consumes (`InvalidParamsError: Unknown keys: ['max_consumption']`). Also pins the pylcm CI ref to 6c610d1 — the squash-merge of pylcm #341 (int32 lock-in) into feat/simulate-aot-n-subjects — to make this build deterministic against pylcm drift. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 2 +- src/aca_model/aca/model.py | 7 ++----- src/aca_model/baseline/model.py | 16 ++++------------ src/aca_model/baseline/regimes/_common.py | 12 ++++++++++++ src/aca_model/benchmark.py | 2 -- src/aca_model/consumption_grid.py | 15 ++++++++------- tests/test_benchmark.py | 2 +- tests/test_model_creation.py | 21 +++++++++++++-------- 8 files changed, 41 insertions(+), 36 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index bd40367..110aafe 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,7 +29,7 @@ jobs: - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@feat/simulate-aot-n-subjects" + git+https://github.com/OpenSourceEconomics/pylcm.git@6c610d19644d3f524ad112ed16c0621ee2ecd326" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 942be22..ee8efc6 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -12,13 +12,13 @@ from aca_model.aca import PolicyVariant from aca_model.aca.regimes import build_all_regimes from aca_model.baseline.regimes import RegimeId +from aca_model.baseline.regimes._common import MAX_CONSUMPTION from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig def create_model( *, n_subjects: int, - max_consumption: float, policy: PolicyVariant = PolicyVariant.ACA, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, @@ -44,9 +44,6 @@ def create_model( contains `pd.Series` indexed by DAG function outputs. grid_config: Continuous-grid point counts. Defaults to production values. - max_consumption: Upper bound of the runtime consumption grid in - $/year. Attached to the returned Model and read back at inject - time by `inject_consumption_points`. Returns: pylcm Model with ACA-specific function overrides. @@ -73,5 +70,5 @@ def create_model( derived_categoricals=derived_categoricals, n_subjects=n_subjects, ) - model.max_consumption = max_consumption + model.max_consumption = MAX_CONSUMPTION return model diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index 0ff7c47..fe181eb 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -16,13 +16,13 @@ from lcm import AgeGrid, DiscreteGrid, Model from aca_model.baseline.regimes import RegimeId, build_all_regimes +from aca_model.baseline.regimes._common import MAX_CONSUMPTION from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig def create_model( *, n_subjects: int, - max_consumption: float, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] @@ -51,9 +51,6 @@ def create_model( pref_type_grid: Optional override for the `pref_type` `DiscreteGrid`. Defaults to `DiscreteGrid(PrefType)`. Used by the benchmark to substitute a 2-type variant with `DispatchStrategy.PARTITION_SCAN`. - max_consumption: Upper bound of the runtime consumption grid in - $/year. Attached to the returned Model and read back at inject - time by `inject_consumption_points`. Returns: A pylcm Model with 19 regimes (18 non-terminal + dead) spanning @@ -81,12 +78,7 @@ def create_model( derived_categoricals=derived_categoricals, n_subjects=n_subjects, ) - # Attach the consumption-grid upper bound directly to the Model - # instance. Tried surfacing it via a marker function in the regime - # DAG first — pylcm's pruning drops unused-output functions before - # their parameters reach the params template, so the value never - # made it into `resolved_fixed_params`. Direct attachment sidesteps - # the templating machinery entirely; `inject_consumption_points` - # reads `model.max_consumption`. - model.max_consumption = max_consumption + # See `MAX_CONSUMPTION` in `baseline.regimes._common` for why this + # rides on the Model instance instead of `fixed_params`. + model.max_consumption = MAX_CONSUMPTION return model diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 327aa15..30198aa 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -195,6 +195,18 @@ class Grids: _AIME_PIECE_N_POINTS: tuple[int, int, int] = (10, 11, 11) +MAX_CONSUMPTION: float = 300_000.0 +"""Upper bound of the runtime consumption grid in $/year. + +Lives here next to the other grid bounds (assets `stop=500_000.0`, +AIME `stop=8_000.0`). The `create_model` factories attach this onto +`model.max_consumption` so `inject_consumption_points` can read it +back at runtime. Routed via a Model attribute rather than +`fixed_params` because pylcm validates `fixed_params` keys against +the regime DAG and rejects entries no function consumes. +""" + + def build_grids( grid_config: GridConfig = GRID_CONFIG, *, diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 3d7fbad..08f5ec6 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -74,7 +74,6 @@ def create_benchmark_model( *, n_subjects: int, - max_consumption: float, pref_type_grid: DiscreteGrid | None = None, ) -> Model: """Create the aca baseline with `BENCHMARK_GRID_CONFIG` and frozen fixed_params. @@ -104,7 +103,6 @@ def create_benchmark_model( derived_categoricals=_DERIVED_CATEGORICALS, pref_type_grid=pref_type_grid, n_subjects=n_subjects, - max_consumption=max_consumption, ) diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_grid.py index bd342ee..7123c1f 100644 --- a/src/aca_model/consumption_grid.py +++ b/src/aca_model/consumption_grid.py @@ -3,11 +3,12 @@ Consumption is declared as `IrregSpacedGrid(n_points=N)` in `baseline.regimes._common.build_grids` so the bounds can track runtime parameters: the lower bound from the per-iteration -`consumption_floor` parameter, the upper bound from a per-model -`max_consumption` knob attached to the `Model` instance by the -`create_model` factories. Callers must inject the actual gridpoints -into `params` via `inject_consumption_points` before calling -`model.solve()` / `model.simulate()`. +`consumption_floor` parameter, the upper bound from +`MAX_CONSUMPTION` in `baseline.regimes._common`, which the +`create_model` factories attach to `model.max_consumption`. +Callers must inject the actual gridpoints into `params` via +`inject_consumption_points` before calling `model.solve()` / +`model.simulate()`. """ from collections.abc import Mapping @@ -30,8 +31,8 @@ def inject_consumption_points( `params[regime_name]["consumption"] = {"points": }`. Lower bound: `params["consumption_floor"]` (varies per iteration). - Upper bound: `model.max_consumption` (required attribute; set by - the `create_model` factory). + Upper bound: `model.max_consumption` (set by the `create_model` + factory from `MAX_CONSUMPTION` in `baseline.regimes._common`). Args: params: Existing params mapping. Returned as a new dict; the input is diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 6173318..adafd66 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -12,7 +12,7 @@ @pytest.mark.long_running def test_benchmark_model_simulates_end_to_end() -> None: n_subjects = 20 - model = create_benchmark_model(n_subjects=n_subjects, max_consumption=300_000.0) + model = create_benchmark_model(n_subjects=n_subjects) _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index accde27..75a87d9 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -11,7 +11,7 @@ from aca_model.baseline.model import create_model from aca_model.baseline.regimes import REGIME_SPECS, RegimeId from aca_model.baseline.regimes import build_regime as _build_regime -from aca_model.baseline.regimes._common import build_grids +from aca_model.baseline.regimes._common import MAX_CONSUMPTION, build_grids _GRIDS = build_grids() @@ -21,24 +21,24 @@ def build_regime(name: str): def test_model_creates_successfully() -> None: - model = create_model(n_subjects=1, max_consumption=300_000.0) + model = create_model(n_subjects=1) assert len(model.regimes) == 19 assert model.n_periods == 45 def test_model_age_range() -> None: - model = create_model(n_subjects=1, max_consumption=300_000.0) + model = create_model(n_subjects=1) assert model.ages.values[0] == 51.0 assert model.ages.values[-1] == 95.0 def test_dead_regime_is_terminal() -> None: - model = create_model(n_subjects=1, max_consumption=300_000.0) + model = create_model(n_subjects=1) assert model.regimes["dead"].terminal def test_non_terminal_regimes_not_terminal() -> None: - model = create_model(n_subjects=1, max_consumption=300_000.0) + model = create_model(n_subjects=1) for name in REGIME_SPECS: assert not model.regimes[name].terminal @@ -170,7 +170,7 @@ def test_hcc_persistent_and_transitory_are_shock_grids() -> None: def test_aca_model_creates_successfully() -> None: - model = create_aca_model(n_subjects=1, max_consumption=300_000.0) + model = create_aca_model(n_subjects=1) assert len(model.regimes) == 19 assert model.n_periods == 45 @@ -211,7 +211,7 @@ def test_aca_other_regimes_have_no_aca_policy_keys() -> None: @pytest.mark.parametrize("policy", list(PolicyVariant)) def test_all_policy_variants_create(policy: PolicyVariant) -> None: """All policy variants create valid models.""" - model = create_aca_model(n_subjects=1, max_consumption=300_000.0, policy=policy) + model = create_aca_model(n_subjects=1, policy=policy) assert len(model.regimes) == 19 @@ -251,5 +251,10 @@ def test_aca_only_medicaid_expansion() -> None: def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" - model = create_model(n_subjects=1, max_consumption=300_000.0) + model = create_model(n_subjects=1) assert len(model.regimes) == 19 + + +def test_max_consumption_attached_from_canonical_constant() -> None: + model = create_model(n_subjects=1) + assert model.max_consumption == MAX_CONSUMPTION From 714fee0496c63547da047670fae058acfae6bfa2 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 07:16:37 +0200 Subject: [PATCH 14/54] Assets grid: subtract MAX_CONSUMPTION margin from the floor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With consumption now declared as `IrregSpacedGrid(n_points=N)` and points filled at runtime from `geomspace(consumption_floor, max_consumption, N)`, the grid clusters densely just above `consumption_floor`. At the lowest-asset / highest-OOP-shock corner, those near-floor consumption choices push `next_assets = cash_on_hand - OOP - consumption` slightly below the assets grid's old lower bound (`0` for the bare model, `-max_annual_labor_income` when wage_params are available). Out-of-bounds interpolation of next-period V then injects NaN, which propagates back through E[V] and eventually fails `validate_V`. Symptom on the production solve: `Value function at age 93 in regime 'retiree_oamc_forced_forcedout': 7317 of 207360 values are NaN`, with the `[NOTE]` showing E[V] NaN concentrated at the lowest assets indices and the highest hcc_transitory shock. Subtract `MAX_CONSUMPTION` from the assets floor to give a worst-case single-period drain margin. With 24 linspace points spanning the wider range, the per-point density change is negligible; the dead state and the bare-model fallback get the margin too. The asymmetry fix is the cheapest one — no change to the consumption grid type, no change to per-iteration parameters, no new constraints. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/baseline/regimes/_common.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 30198aa..d32e2c1 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -262,10 +262,25 @@ def build_grids( sigma=1.0, ) - assets_start = 0.0 + # Assets-grid lower bound includes a one-period margin below the + # binding borrowing limit so that `next_assets = cash_on_hand - OOP - + # consumption` stays inside the grid even at the worst-shock × low- + # consumption corner. With the runtime log-spaced consumption grid + # `geomspace(consumption_floor, max_consumption, n_points)`, choices + # cluster densely just above `consumption_floor`, and at the lowest- + # asset/highest-OOP-shock corner those choices push `next_assets` + # slightly off the grid bottom — out-of-bounds interpolation of + # next-period V then injects NaN that propagates through E[V]. + # Subtracting `MAX_CONSUMPTION` gives a worst-case single-period + # drain margin; cheap at production grid sizes (24 linspace points + # over the wider range). + assets_start = -MAX_CONSUMPTION if wage_params is not None and _has_required_wage_keys(wage_params=wage_params): - assets_start = -_compute_max_annual_labor_income( - wage_params=wage_params, wage_res_grid=wage_res + assets_start = ( + -_compute_max_annual_labor_income( + wage_params=wage_params, wage_res_grid=wage_res + ) + - MAX_CONSUMPTION ) return Grids( From 63d2a3819d08cf33f95c0149b6d3531b5292e729 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 11:05:11 +0200 Subject: [PATCH 15/54] Revert "Assets grid: subtract MAX_CONSUMPTION margin from the floor" This reverts commit 714fee0496c63547da047670fae058acfae6bfa2. --- src/aca_model/baseline/regimes/_common.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index d32e2c1..30198aa 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -262,25 +262,10 @@ def build_grids( sigma=1.0, ) - # Assets-grid lower bound includes a one-period margin below the - # binding borrowing limit so that `next_assets = cash_on_hand - OOP - - # consumption` stays inside the grid even at the worst-shock × low- - # consumption corner. With the runtime log-spaced consumption grid - # `geomspace(consumption_floor, max_consumption, n_points)`, choices - # cluster densely just above `consumption_floor`, and at the lowest- - # asset/highest-OOP-shock corner those choices push `next_assets` - # slightly off the grid bottom — out-of-bounds interpolation of - # next-period V then injects NaN that propagates through E[V]. - # Subtracting `MAX_CONSUMPTION` gives a worst-case single-period - # drain margin; cheap at production grid sizes (24 linspace points - # over the wider range). - assets_start = -MAX_CONSUMPTION + assets_start = 0.0 if wage_params is not None and _has_required_wage_keys(wage_params=wage_params): - assets_start = ( - -_compute_max_annual_labor_income( - wage_params=wage_params, wage_res_grid=wage_res - ) - - MAX_CONSUMPTION + assets_start = -_compute_max_annual_labor_income( + wage_params=wage_params, wage_res_grid=wage_res ) return Grids( From 4ae44469ef99a8cdc26da164f0009743aaa72652 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 11:08:02 +0200 Subject: [PATCH 16/54] Wire pension imputation correction (FJ 2011 Appendix A.5) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two new DAG functions in canwork & ss != "forced" regimes: - target_his(his, labor_supply, is_medicaid_eligible): HIS class of the surviving target regime. Mirrors the cross-HIS branches inside _make_transition_canwork (tied → nongroup when stopping work, Medicaid override → nongroup). - imputed_pension_wealth_next_period(next_aime, target_his, period, ...): computes pw_next_imputed = benefit_imputed(next_pia, next_period, target_his) · epdv_constant_pension[next_period] using bare-name parameters into 1-period-shifted views of the imputation arrays (`*_next_period`). Inlining is required because pylcm's AST shape inference doesn't trace nested calls into pensions.benefit. next_assets continues to consume pension_assets_adjustment, which now sees a real imputed_pension_wealth_next_period via the DAG (previously fixed to 0.0 in aca-estimation). The chained dependency next_aime → imputed_pension_wealth_next_period → pension_assets_adjustment is unblocked by pylcm exempting next_ names from fixed_param extraction (PR pylcm#342). Also drops pension_assets_adjustment from borrowing_constraint: a negative correction at a cross-HIS transition can leave no feasible action and inject `-inf` into V via `argmax_and_max(initial=-inf, where=F_arr)`, which then cancels with `0 * -inf = NaN`. The correction is a post-decision shift on next-period assets and must not gate the current consumption choice. --- src/aca_model/agent/assets_and_income.py | 13 +++++-- src/aca_model/baseline/health_insurance.py | 23 ++++++++++++ src/aca_model/baseline/regimes/_nongroup.py | 4 ++ src/aca_model/baseline/regimes/_retiree.py | 4 ++ src/aca_model/baseline/regimes/_tied.py | 4 ++ src/aca_model/environment/pensions.py | 41 ++++++++++++++++++++- 6 files changed, 85 insertions(+), 4 deletions(-) diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index dfa83ef..46d4b1c 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -69,7 +69,14 @@ def borrowing_constraint( consumption: ContinuousAction, cash_on_hand: FloatND, transfers: FloatND, - pension_assets_adjustment: FloatND, ) -> BoolND: - """Consumption cannot exceed available resources (no borrowing).""" - return consumption <= cash_on_hand + transfers + pension_assets_adjustment + """Consumption cannot exceed available resources (no borrowing). + + `pension_assets_adjustment` is excluded: it can be negative (e.g., + when the imputation overstates next-period pension wealth at a + cross-HIS transition), and including it here can leave no feasible + action at low-asset / mid-AIME corners. The correction enters + `next_assets` instead — a post-decision shift that does not gate + the current consumption choice. + """ + return consumption <= cash_on_hand + transfers diff --git a/src/aca_model/baseline/health_insurance.py b/src/aca_model/baseline/health_insurance.py index 741d160..3732d6d 100644 --- a/src/aca_model/baseline/health_insurance.py +++ b/src/aca_model/baseline/health_insurance.py @@ -246,6 +246,29 @@ def is_medicaid_eligible(is_ssi_eligible: BoolND) -> BoolND: return is_ssi_eligible +def target_his( + his: IntND, + labor_supply: DiscreteAction, + is_medicaid_eligible: BoolND, +) -> IntND: + """Return the HIS class of the surviving target regime. + + Mirrors the cross-HIS branches inside `_make_transition_canwork` (retiree, + tied, nongroup): tied agents who stop working become nongroup, and + Medicaid-eligible agents are overridden to nongroup. Used by + `imputed_pension_wealth_next_period` to look up next-period imputation + coefficients at the target's HIS. + """ + tied_to_ng = (his == HealthInsuranceState.tied) & ( + labor_supply == LaborSupply.do_not_work + ) + return jnp.where( + tied_to_ng | is_medicaid_eligible, + HealthInsuranceState.nongroup, + his, + ).astype(jnp.int32) + + def oop_with_medicaid( primary_oop: FloatND, is_medicaid_eligible: BoolND, diff --git a/src/aca_model/baseline/regimes/_nongroup.py b/src/aca_model/baseline/regimes/_nongroup.py index 5cdb6dc..7ee82ff 100644 --- a/src/aca_model/baseline/regimes/_nongroup.py +++ b/src/aca_model/baseline/regimes/_nongroup.py @@ -99,6 +99,10 @@ def _build_functions(spec: dict[str, str]) -> dict: functions["pension_wealth_next_before_adjustment"] = ( pensions.wealth_next_before_adjustment ) + functions["target_his"] = health_insurance.target_his + functions["imputed_pension_wealth_next_period"] = ( + pensions.imputed_pension_wealth_next_period + ) functions["pension_assets_adjustment"] = pensions.assets_adjustment functions["total_to_pia"] = pensions.total_to_pia diff --git a/src/aca_model/baseline/regimes/_retiree.py b/src/aca_model/baseline/regimes/_retiree.py index ac76bfd..a941fa9 100644 --- a/src/aca_model/baseline/regimes/_retiree.py +++ b/src/aca_model/baseline/regimes/_retiree.py @@ -109,6 +109,10 @@ def _build_functions(spec: dict[str, str]) -> dict: functions["pension_wealth_next_before_adjustment"] = ( pensions.wealth_next_before_adjustment ) + functions["target_his"] = health_insurance.target_his + functions["imputed_pension_wealth_next_period"] = ( + pensions.imputed_pension_wealth_next_period + ) functions["pension_assets_adjustment"] = pensions.assets_adjustment functions["total_to_pia"] = pensions.total_to_pia diff --git a/src/aca_model/baseline/regimes/_tied.py b/src/aca_model/baseline/regimes/_tied.py index 5d59274..4351cf5 100644 --- a/src/aca_model/baseline/regimes/_tied.py +++ b/src/aca_model/baseline/regimes/_tied.py @@ -83,6 +83,10 @@ def _build_functions(spec: dict[str, str]) -> dict: functions["pension_wealth_next_before_adjustment"] = ( pensions.wealth_next_before_adjustment ) + functions["target_his"] = health_insurance.target_his + functions["imputed_pension_wealth_next_period"] = ( + pensions.imputed_pension_wealth_next_period + ) functions["pension_assets_adjustment"] = pensions.assets_adjustment functions["total_to_pia"] = pensions.total_to_pia diff --git a/src/aca_model/environment/pensions.py b/src/aca_model/environment/pensions.py index a23a800..eef72d4 100644 --- a/src/aca_model/environment/pensions.py +++ b/src/aca_model/environment/pensions.py @@ -4,7 +4,7 @@ """ import jax.numpy as jnp -from lcm.typing import FloatND, IntND, Period +from lcm.typing import ContinuousState, FloatND, IntND, Period def benefit( @@ -164,3 +164,42 @@ def assets_adjustment( * unconditional_survival_prob[period] * (pension_wealth_next_before_adjustment - imputed_pension_wealth_next_period) ) + + +def imputed_pension_wealth_next_period( + next_aime: ContinuousState, + target_his: IntND, + period: Period, + pia_table: FloatND, + pia_aime_grid: FloatND, + imp_intercept_next_period: FloatND, + imp_pia_coeff_next_period: FloatND, + imp_pia_kink_0_coeff_next_period: FloatND, + imp_pia_kink_1_coeff_next_period: FloatND, + imp_kink_0_next_period: FloatND, + imp_kink_1_next_period: FloatND, + imp_fraction_receiving_next_period: FloatND, + epdv_constant_pension_next_period: FloatND, +) -> FloatND: + """Imputed pension wealth at next period using the target regime's HIS. + + Mirrors `benefit` and `wealth` but indexes into 1-period-shifted views + of the imputation arrays so all subscripts use bare-name parameters + (`period`, `target_his`). Inlining is required: pylcm's AST shape + inference inspects the registered function's body and does not trace + through nested calls into `benefit`. + """ + next_pia = jnp.interp(next_aime, pia_aime_grid, pia_table) + + intercept = imp_intercept_next_period[period, target_his] + pia_pred = imp_pia_coeff_next_period[period, target_his] * next_pia + kink_0_adj = imp_pia_kink_0_coeff_next_period[period, target_his] * jnp.maximum( + 0.0, next_pia - imp_kink_0_next_period[period] + ) + kink_1_adj = imp_pia_kink_1_coeff_next_period[period, target_his] * jnp.maximum( + 0.0, next_pia - imp_kink_1_next_period[period] + ) + + full_benefit = jnp.maximum(0.0, intercept + pia_pred + kink_0_adj + kink_1_adj) + benefit_next = full_benefit * imp_fraction_receiving_next_period[period] + return benefit_next * epdv_constant_pension_next_period[period] From 83f22500e97a6675aa4cd15235dea359dae94f2d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 11:37:08 +0200 Subject: [PATCH 17/54] =?UTF-8?q?Bump=20pyproject-fmt=20v2.19.0=20?= =?UTF-8?q?=E2=86=92=20v2.21.1=20and=20ruff-pre-commit=20v0.15.6=20?= =?UTF-8?q?=E2=86=92=20v0.15.12?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e36542..f3188ab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: check-hooks-apply - id: check-useless-excludes - repo: https://github.com/tox-dev/pyproject-fmt - rev: v2.19.0 + rev: v2.21.1 hooks: - id: pyproject-fmt - repo: https://github.com/lyz-code/yamlfix @@ -47,7 +47,7 @@ repos: hooks: - id: yamllint - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.6 + rev: v0.15.12 hooks: - id: ruff-check args: From 3453080fd08afa049483f6ddda215a998a55b757 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 20:48:08 +0200 Subject: [PATCH 18/54] get_benchmark_params: filter obsolete imputed_pension_wealth_next_period key The frozen benchmark_params.pkl was generated when aca-estimation's _assemble_params.py still wrote the placeholder `fp["imputed_pension_wealth_next_period"] = 0.0` into fixed_params. Now that the regime registers `imputed_pension_wealth_next_period` as a DAG function (pension imputation correction in 4ae4446), pylcm's `_resolve_fixed_params` rejects the stale key with `InvalidParamsError: Unknown keys: ['imputed_pension_wealth_next_period']`. Drop the key on load so the snapshot stays valid. Regenerating `benchmark_params.pkl` end-to-end would also remove it; the filter is a no-op for a fresh snapshot. --- src/aca_model/benchmark.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 08f5ec6..47cb628 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -122,13 +122,23 @@ def get_benchmark_params( """ with _PARAMS_FILE.open("rb") as fh: data = cloudpickle.load(fh) - fixed_params = data["fixed_params"] + fixed_params = { + k: v for k, v in data["fixed_params"].items() if k not in _STALE_FIXED_KEYS + } params = _truncate_pref_type_indexed(data["params"]) if model is not None: params = inject_consumption_points(params=params, model=model) return fixed_params, params +# Keys that the older aca-estimation `_assemble_params.py` wrote into +# `fixed_params` but that the current regime now resolves as a DAG +# function. Drop them on load so pylcm's `_resolve_fixed_params` does +# not reject the snapshot. Regenerating `benchmark_params.pkl` would +# also remove these — the filter is a no-op when the snapshot is fresh. +_STALE_FIXED_KEYS: frozenset[str] = frozenset({"imputed_pension_wealth_next_period"}) + + def _truncate_pref_type_indexed(params: dict[str, Any]) -> dict[str, Any]: """Return a copy of `params` with pref_type-indexed Series cut to 2 rows. From b2e90bb58a1c6721046a3e860a95a29485b25117 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 21:04:19 +0200 Subject: [PATCH 19/54] get_benchmark_params: synthesise _next_period shifted views The frozen `benchmark_params.pkl` predates aca-data's `_shift_one_period_forward` change, so the 1-period-shifted views the pension correction consumes are missing. Synthesise them on load with the same transformation aca-data applies. Regenerating the snapshot end-to-end would also produce the keys; this filter is a no-op for a fresh snapshot. --- src/aca_model/benchmark.py | 50 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 47cb628..5c24d4f 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -125,6 +125,7 @@ def get_benchmark_params( fixed_params = { k: v for k, v in data["fixed_params"].items() if k not in _STALE_FIXED_KEYS } + fixed_params = _add_shifted_imputation_arrays(fixed_params) params = _truncate_pref_type_indexed(data["params"]) if model is not None: params = inject_consumption_points(params=params, model=model) @@ -139,6 +140,55 @@ def get_benchmark_params( _STALE_FIXED_KEYS: frozenset[str] = frozenset({"imputed_pension_wealth_next_period"}) +# Source → derived key mapping for the 1-period-shifted views of the +# imputation arrays. The current pension correction (`imputed_pension_ +# wealth_next_period`) consumes these. The frozen `benchmark_params.pkl` +# predates aca-data's `_shift_one_period_forward` change, so synthesise +# the shifted views on load. The transformation is deterministic: row +# `period` carries the original at row `period + 1`; the last row holds +# flat. A regenerated snapshot can drop this synthesis (the filter is a +# no-op when the keys already exist). +_SHIFTED_IMPUTATION_KEYS: tuple[str, ...] = ( + "imp_intercept", + "imp_pia_coeff", + "imp_pia_kink_0_coeff", + "imp_pia_kink_1_coeff", + "imp_kink_0", + "imp_kink_1", + "imp_fraction_receiving", + "epdv_constant_pension", +) + + +def _add_shifted_imputation_arrays(fixed_params: dict[str, Any]) -> dict[str, Any]: + """Synthesise `_next_period` views from the source arrays.""" + out = dict(fixed_params) + for key in _SHIFTED_IMPUTATION_KEYS: + next_period_key = f"{key}_next_period" + if next_period_key in out or key not in out: + continue + out[next_period_key] = _shift_one_period_forward(out[key]) + return out + + +def _shift_one_period_forward(sr: pd.Series) -> pd.Series: + """Shift age-axis values forward one position (last row held flat).""" + if isinstance(sr.index, pd.MultiIndex) and sr.index.names[0] == "age": + n_periods = sr.index.levshape[0] + n_other = int( + np.prod([sr.index.levshape[i] for i in range(1, sr.index.nlevels)]) + ) + values = sr.to_numpy().reshape(n_periods, n_other) + shifted = np.concatenate([values[1:], values[-1:]], axis=0) + return pd.Series(shifted.ravel(), index=sr.index) + if sr.index.name == "age": + values = sr.to_numpy() + shifted = np.concatenate([values[1:], values[-1:]]) + return pd.Series(shifted, index=sr.index) + msg = f"Unexpected index for _shift_one_period_forward: {sr.index!r}" + raise ValueError(msg) + + def _truncate_pref_type_indexed(params: dict[str, Any]) -> dict[str, Any]: """Return a copy of `params` with pref_type-indexed Series cut to 2 rows. From 35eddcc9ee06b960c30c6ea09e3f07541f3144a6 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 21:18:43 +0200 Subject: [PATCH 20/54] benchmark: declare target_his as derived categorical MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit target_his is a DAG function returning an HealthInsuranceState int, used to index 2D imputation arrays inside imputed_pension_wealth_next_period. pylcm needs the categorical mapping declared so array_from_series can reshape (age, target_his)-indexed Series correctly. Mirrors the existing 'his' entry — same enum class. --- src/aca_model/benchmark.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 5c24d4f..13dfae4 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -56,6 +56,7 @@ "good_health": DiscreteGrid(GoodHealth), "is_married": DiscreteGrid(IsMarried), "his": DiscreteGrid(HealthInsuranceState), + "target_his": DiscreteGrid(HealthInsuranceState), "pref_type": DiscreteGrid(BenchmarkPrefType), } From 64d656791230ebf20622c247bf2935880de2fcfd Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 21:32:47 +0200 Subject: [PATCH 21/54] _shift_one_period_forward: rename his level to target_his The shifted imputation arrays (`imp_*_next_period`) are consumed by `imputed_pension_wealth_next_period(target_his, period, ...)`. pylcm's `_validate_and_reorder_levels` matches Series MultiIndex level names against the function's parameter names, so the level needs to be `target_his`, not `his`. --- src/aca_model/benchmark.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 13dfae4..8e242b0 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -173,7 +173,12 @@ def _add_shifted_imputation_arrays(fixed_params: dict[str, Any]) -> dict[str, An def _shift_one_period_forward(sr: pd.Series) -> pd.Series: - """Shift age-axis values forward one position (last row held flat).""" + """Shift age-axis values forward one position (last row held flat). + + For (age, his)-indexed inputs, also rename the `his` level to + `target_his` so the resulting Series matches the level naming the + consuming `imputed_pension_wealth_next_period` function expects. + """ if isinstance(sr.index, pd.MultiIndex) and sr.index.names[0] == "age": n_periods = sr.index.levshape[0] n_other = int( @@ -181,7 +186,10 @@ def _shift_one_period_forward(sr: pd.Series) -> pd.Series: ) values = sr.to_numpy().reshape(n_periods, n_other) shifted = np.concatenate([values[1:], values[-1:]], axis=0) - return pd.Series(shifted.ravel(), index=sr.index) + new_index = sr.index.rename( + [_rename_his_level(name) for name in sr.index.names] + ) + return pd.Series(shifted.ravel(), index=new_index) if sr.index.name == "age": values = sr.to_numpy() shifted = np.concatenate([values[1:], values[-1:]]) @@ -190,6 +198,11 @@ def _shift_one_period_forward(sr: pd.Series) -> pd.Series: raise ValueError(msg) +def _rename_his_level(name: str) -> str: + """Rename `his` to `target_his`, leave others alone.""" + return "target_his" if name == "his" else name + + def _truncate_pref_type_indexed(params: dict[str, Any]) -> dict[str, Any]: """Return a copy of `params` with pref_type-indexed Series cut to 2 rows. From f09b5e34102ff42f739b95be5a9d388795b734a1 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 4 May 2026 22:00:55 +0200 Subject: [PATCH 22/54] Per-target next_assets: dead target uses next_assets_terminal (no pension chain) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit state_transitions["assets"] becomes a per-target dict. The dead target gets a simpler `next_assets_terminal` (cash + transfers - consumption - oop) without the `pension_assets_adjustment` chain, because: 1. There is no future for a dead agent — the imputation correction is meaningless. 2. `pension_assets_adjustment` consumes `imputed_pension_wealth_next_period` which consumes `next_aime`. The dead per-target transitions don't include `next_aime` (dead has no aime state), so dags can't resolve it and pylcm leaks `next_aime` into the kernel signature with no value to pass. Non-dead targets keep `assets_and_income.next_assets` (full version with the pension correction). --- src/aca_model/agent/assets_and_income.py | 19 +++++++++++++- src/aca_model/baseline/regimes/_common.py | 30 ++++++++++++++++++++++- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index 46d4b1c..cb89c89 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -55,7 +55,7 @@ def next_assets( consumption: ContinuousAction, oop_costs: FloatND, ) -> ContinuousState: - """Compute beginning-of-next-period assets. + """Compute beginning-of-next-period assets for non-terminal targets. OOP health costs are deducted here (not from cash_on_hand) so that the consumption choice does not condition on the HCC shock realization. @@ -65,6 +65,23 @@ def next_assets( ) +def next_assets_terminal( + cash_on_hand: FloatND, + transfers: FloatND, + consumption: ContinuousAction, + oop_costs: FloatND, +) -> ContinuousState: + """Compute beginning-of-next-period assets for the dead/terminal target. + + No `pension_assets_adjustment` term: with no future, there is no + next-period pension wealth to impute against. Avoiding the dependency + also keeps the `dead` per-target transition's DAG free of `next_aime` + (which would otherwise need to come from a transition `dead` does not + have, since `aime` is not a state in the terminal regime). + """ + return cash_on_hand + transfers - consumption - oop_costs + + def borrowing_constraint( consumption: ContinuousAction, cash_on_hand: FloatND, diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 30198aa..56887d9 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -644,7 +644,7 @@ def build_state_transitions(spec: dict[str, str]) -> dict: """Build the state transitions dict for a non-dead regime.""" transitions: dict = {} transitions["health"] = _build_per_target_health(spec) - transitions["assets"] = assets_and_income.next_assets + transitions["assets"] = _build_per_target_next_assets(spec) transitions["pref_type"] = None transitions["aime"] = ( social_security.next_aime @@ -661,6 +661,34 @@ def build_state_transitions(spec: dict[str, str]) -> dict: return transitions +def _build_per_target_next_assets(spec: dict[str, str]) -> dict: + """Build per-target assets transitions. + + The `dead` target uses `next_assets_terminal` (no + `pension_assets_adjustment`), so the dead per-target DAG does not + pull in the `next_aime`-dependent imputation chain — `dead` has no + `aime` state and pylcm cannot resolve `next_aime` there. Non-dead + targets use the full `next_assets` with the pension correction. + """ + targets = precompute_targets(spec) + id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} + + result: dict = {} + seen_ids: set[int] = set() + + for target_id in targets.values(): + if target_id in seen_ids: + continue + seen_ids.add(target_id) + target_name = id_to_name.get(target_id) + if target_name is None: + continue + result[target_name] = assets_and_income.next_assets + + result["dead"] = assets_and_income.next_assets_terminal + return result + + def _build_per_target_health(spec: dict[str, str]) -> dict: """Build per-target health transitions. From e1a3eb2478c8616317afa762583f50d9c31de86d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Tue, 5 May 2026 20:45:22 +0200 Subject: [PATCH 23/54] create_model: register target_his as derived categorical at base layer The pension imputation correction's `imputed_pension_wealth_next_period` indexes shifted arrays via `arr[period, target_his]`, where `target_his` is a DAG output (computed by `health_insurance.target_his` on nongroup/tied/retiree regimes), not a state. pylcm reads the level name `target_his` off the function body via AST inference and rejects matching `pd.Series` fixed_params unless `target_his` is declared as a derived categorical. Production `task_simulate_baseline` calls `create_model(...)` directly, which previously only forwarded the user's `derived_categoricals` arg. The benchmark module was masking this by injecting target_his via `_DERIVED_CATEGORICALS`. Move the declaration to `create_model` itself so the correction works in production without per-caller setup. Tighten the param annotation: pylcm's `Model.derived_categoricals` is a flat `Mapping[str, DiscreteGrid]`, never the nested form. --- src/aca_model/baseline/model.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index fe181eb..b0a2d79 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -15,6 +15,7 @@ from lcm import AgeGrid, DiscreteGrid, Model +from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.regimes import RegimeId, build_all_regimes from aca_model.baseline.regimes._common import MAX_CONSUMPTION from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig @@ -25,8 +26,7 @@ def create_model( n_subjects: int, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, - derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] - | None = None, + derived_categoricals: Mapping[str, DiscreteGrid] | None = None, grid_config: GridConfig = GRID_CONFIG, pref_type_grid: DiscreteGrid | None = None, ) -> Model: @@ -69,13 +69,24 @@ def create_model( pref_type_grid=pref_type_grid, ) + # `target_his` is a DAG output of `health_insurance.target_his` (set on + # nongroup/tied/retiree regimes). The pension imputation correction + # (`imputed_pension_wealth_next_period`) indexes shifted arrays by + # `arr[period, target_his]`; pylcm needs the categorical declared so + # `pd.Series` fixed_params with a `target_his` index level resolve. + base_derived: dict[str, DiscreteGrid] = { + "target_his": DiscreteGrid(HealthInsuranceState), + } + if derived_categoricals is not None: + base_derived.update(derived_categoricals) + model = Model( regimes=regimes, ages=ages, regime_id_class=RegimeId, description="Baseline structural retirement model (pre-ACA)", fixed_params=fixed_params or {}, - derived_categoricals=derived_categoricals, + derived_categoricals=base_derived, n_subjects=n_subjects, ) # See `MAX_CONSUMPTION` in `baseline.regimes._common` for why this From 00ee7d2236be9c62286c2aeceffc3b2fc5128b4a Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 05:11:54 +0200 Subject: [PATCH 24/54] aca/model.create_model: register target_his at base layer Same fix as baseline.model.create_model e1a3eb2: ACA variant model creation also takes its own path through `Model(...)`, so the production `task_simulate_aca_*` flows hit the same "Unrecognised indexing parameter 'target_his'" error after the pension correction landed. Move the derived-categorical declaration into the function itself rather than relying on per-caller setup. Tighten the param annotation to match pylcm's flat `Mapping[str, DiscreteGrid]`. --- src/aca_model/aca/model.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index ee8efc6..1cc7ff4 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -11,6 +11,7 @@ from aca_model.aca import PolicyVariant from aca_model.aca.regimes import build_all_regimes +from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.regimes import RegimeId from aca_model.baseline.regimes._common import MAX_CONSUMPTION from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig @@ -22,8 +23,7 @@ def create_model( policy: PolicyVariant = PolicyVariant.ACA, fixed_params: Mapping[str, Any] | None = None, wage_params: Mapping[str, Any] | None = None, - derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] - | None = None, + derived_categoricals: Mapping[str, DiscreteGrid] | None = None, grid_config: GridConfig = GRID_CONFIG, ) -> Model: """Create an ACA policy variant model. @@ -61,13 +61,21 @@ def create_model( wage_params=wage_params, ) + # See `baseline.model.create_model` for why `target_his` is declared + # as a base-layer derived categorical. + base_derived: dict[str, DiscreteGrid] = { + "target_his": DiscreteGrid(HealthInsuranceState), + } + if derived_categoricals is not None: + base_derived.update(derived_categoricals) + model = Model( regimes=regimes, ages=ages, regime_id_class=RegimeId, description=f"Structural retirement model ({policy.name})", fixed_params=fixed_params or {}, - derived_categoricals=derived_categoricals, + derived_categoricals=base_derived, n_subjects=n_subjects, ) model.max_consumption = MAX_CONSUMPTION From edfa540ad23299d625fc2c247970014fd31fb91e Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 09:39:33 +0200 Subject: [PATCH 25/54] =?UTF-8?q?tests:=20positive=20regression=20guard=20?= =?UTF-8?q?=E2=80=94=20assets=3D-$1M=20passes=20benchmark=20validation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Asserts that `validate_initial_conditions` admits a subject placed at `assets = -1_000_000` in `retiree_nomc_inelig_canwork` under the benchmark model. Encodes the economic story: with the consumption floor / transfer system, any past assets level is representable — `c = c_floor` is always feasible because `transfers` tops up cash-on-hand to the floor. The test passes today on benchmark params; it doesn't reproduce the gpu-01 failure (production-side, separate setup loaded by `aca-estimation`'s `assemble_fixed_params`). Kept as a permanent regression guard so a future change that re-introduces a constraint shape that rejects extreme negatives is caught immediately at benchmark scale. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../test_initial_conditions_extreme_assets.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 tests/test_initial_conditions_extreme_assets.py diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py new file mode 100644 index 0000000..0007966 --- /dev/null +++ b/tests/test_initial_conditions_extreme_assets.py @@ -0,0 +1,50 @@ +"""Subjects at extreme negative assets must clear `validate_initial_conditions`. + +The transfer system (`agent.assets_and_income.transfers`) tops cash-on-hand +to `consumption_floor * equivalence_scale` at any starting state, so the +lowest consumption-grid point is always a feasible action regardless of +how negative starting assets are. The model's constraints — and pylcm's +`validate_initial_conditions` pass — must reflect this. +""" + +import jax.numpy as jnp +from lcm.simulation.initial_conditions import validate_initial_conditions + +from aca_model.benchmark import ( + create_benchmark_model, + get_benchmark_initial_conditions, + get_benchmark_params, +) + + +def test_extreme_negative_assets_subject_passes_validation() -> None: + """A subject placed at `assets = -1_000_000` clears initial-conditions validation. + + HRS bottom-codes very-large-negative net wealth at exactly $-1{,}000{,}000$. + Such subjects should remain in the simulated population: the consumption + floor / transfer system absorbs them, with `c = c_floor` always feasible. + """ + n_subjects = 1 + model = create_benchmark_model(n_subjects=n_subjects) + _, params = get_benchmark_params(model=model) + + initial_conditions = get_benchmark_initial_conditions( + model=model, n_subjects=n_subjects, seed=0 + ) + initial_conditions = { + **initial_conditions, + "assets": jnp.asarray([-1_000_000.0]), + "regime": jnp.asarray( + [model.regime_names_to_ids["retiree_nomc_inelig_canwork"]], + dtype=jnp.int32, + ), + } + + internal_params = model._process_params(params) # noqa: SLF001 + validate_initial_conditions( + initial_conditions=initial_conditions, + internal_regimes=model.internal_regimes, + regime_names_to_ids=model.regime_names_to_ids, + internal_params=internal_params, + ages=model.ages, + ) From d05df9e19433a7bdeeb30828692f3042608174d0 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 10:00:09 +0200 Subject: [PATCH 26/54] borrowing_constraint: use max(cash_on_hand, floor) to dodge fp32 cancellation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The expression `cash_on_hand + transfers` suffers float32 catastrophic cancellation when `|cash_on_hand|` is much larger than `consumption_floor`. For a subject at $-1{,}000{,}000$ in starting assets: cash_on_hand ≈ -1e6 (dominated by assets) transfers = max(0, c_floor - cash_on_hand) ≈ c_floor + 1e6 cash_on_hand + transfers ≈ c_floor ± 0.1 (fp32 error at 1e6 magnitude) The lowest grid `c` is exactly `c_floor`. With unfavorable rounding, `c_floor <= c_floor - 0.1` is False — every action gets rejected and `validate_initial_conditions` raises. This is exactly the failure gpu-01 hit on `task_simulate_aca_*`: the per-constraint diagnostic showed `borrowing_constraint = False` (rejects every action by itself) while `positive_leisure = True`. The algebraic identity `cash_on_hand + transfers == max(cash_on_hand, floor)` (where `floor = c_floor * equivalence_scale`) holds exactly because `transfers` is defined as `max(0, floor - cash_on_hand)`. Substituting in: cash_on_hand + max(0, floor - cash_on_hand) = max(cash_on_hand, cash_on_hand + floor - cash_on_hand) = max(cash_on_hand, floor) The `max` form has no cancellation: it returns `floor` exactly when `cash_on_hand << floor`, and `cash_on_hand` exactly otherwise. Switch the constraint to take `consumption_floor` and `equivalence_scale` directly and compute `floor = consumption_floor * equivalence_scale` in-line. Add a precision-specific unit test asserting `c = c_floor` is admitted at `cash_on_hand = -$1M` in fp32. The pre-existing benchmark-based regression guard (`test_extreme_negative_assets_subject_passes_ validation`) didn't catch the bug because benchmark params land on the favorable side of the rounding; the new test exercises the exact cancellation case. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/agent/assets_and_income.py | 23 +++++++++++++++---- .../test_initial_conditions_extreme_assets.py | 22 ++++++++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index cb89c89..629fc42 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -85,15 +85,28 @@ def next_assets_terminal( def borrowing_constraint( consumption: ContinuousAction, cash_on_hand: FloatND, - transfers: FloatND, + consumption_floor: float, + equivalence_scale: FloatND, ) -> BoolND: - """Consumption cannot exceed available resources (no borrowing). - - `pension_assets_adjustment` is excluded: it can be negative (e.g., + """Consumption cannot exceed available resources after transfers. + + Post-transfer resources are `max(cash_on_hand, consumption_floor * + equivalence_scale)`: the transfer system tops `cash_on_hand` to the + floor when below, otherwise resources are unchanged. The algebraic + identity is `cash_on_hand + transfers == max(cash_on_hand, floor)`, + but writing it as `cash_on_hand + transfers` triggers float32 + catastrophic cancellation when `|cash_on_hand|` dwarfs + `consumption_floor` — e.g. a subject at $-1{,}000{,}000$ in starting + assets gives `(-1e6) + (c_floor + 1e6)` with ~0.1 of rounding error, + which can wipe out the `c == c_floor` boundary and reject every + feasible action. The `max` form has no cancellation. + + `pension_assets_adjustment` is excluded: it can be negative (e.g. when the imputation overstates next-period pension wealth at a cross-HIS transition), and including it here can leave no feasible action at low-asset / mid-AIME corners. The correction enters `next_assets` instead — a post-decision shift that does not gate the current consumption choice. """ - return consumption <= cash_on_hand + transfers + floor = consumption_floor * equivalence_scale + return consumption <= jnp.maximum(cash_on_hand, floor) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index 0007966..bd88045 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -10,6 +10,7 @@ import jax.numpy as jnp from lcm.simulation.initial_conditions import validate_initial_conditions +from aca_model.agent.assets_and_income import borrowing_constraint from aca_model.benchmark import ( create_benchmark_model, get_benchmark_initial_conditions, @@ -17,6 +18,27 @@ ) +def test_borrowing_constraint_admits_c_floor_at_million_dollar_negative_cash() -> None: + """At `cash_on_hand = -$1M` (fp32), `c = c_floor` remains a feasible choice. + + Computing `cash_on_hand + transfers` directly suffers float32 catastrophic + cancellation: `(-1e6) + (c_floor + 1e6)` loses ~0.1 of precision, enough + to wipe out the `c == c_floor` boundary. The constraint must use the + algebraically equivalent but numerically stable `max(cash_on_hand, floor)` + form. + """ + consumption_floor = 5_000.0 + admitted = bool( + borrowing_constraint( + consumption=jnp.float32(consumption_floor), + cash_on_hand=jnp.float32(-1_000_000.0), + consumption_floor=consumption_floor, + equivalence_scale=jnp.float32(1.0), + ) + ) + assert admitted + + def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. From 4af83596dac2161c6485bea46f16fec0744e69c9 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 10:05:40 +0200 Subject: [PATCH 27/54] ci: bump pylcm pin to e4cae2aa (post-#342, post-#340 diagnostic) The previous pin (6c610d1, "Lock integer dtype to int32 end-to-end") predates pylcm #342, so the test_initial_conditions_extreme_assets test (and any other test that solves a benchmark regime carrying the pension-imputation correction) raised: InvalidParamsError: Missing required parameter: 'retiree_nomc_inelig_canwork__imputed_pension_wealth_next_period__next_aime' #342's `regime_template` change exempts `next_` references inside transition signatures from `fixed_param` extraction, which the correction's `imputed_pension_wealth_next_period(next_aime, ...)` signature relies on. The new pin tracks `feat/simulate-aot-n-subjects`, which carries #342, #339, #340 (n_subjects API used by `create_benchmark_model`), and the per-constraint validation diagnostic. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 110aafe..565245d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,7 +29,7 @@ jobs: - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@6c610d19644d3f524ad112ed16c0621ee2ecd326" + git+https://github.com/OpenSourceEconomics/pylcm.git@e4cae2aa57d4bf568b8ebbade55d44571e3a086f" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest From 0c7f2d589e8dba50dfc115f33dd44ba7e6396ae0 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 10:56:49 +0200 Subject: [PATCH 28/54] =?UTF-8?q?wip:=20debug=20script=20=E2=80=94=20cash?= =?UTF-8?q?=5Fon=5Fhand=20per=20failing=20subject?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug_cash_on_hand.py | 176 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 debug_cash_on_hand.py diff --git a/debug_cash_on_hand.py b/debug_cash_on_hand.py new file mode 100644 index 0000000..bac4a08 --- /dev/null +++ b/debug_cash_on_hand.py @@ -0,0 +1,176 @@ +"""Print cash_on_hand for the failing subjects at every labor_supply choice. + +If `cash_on_hand` evaluates to NaN for any subject, that explains why my +new `borrowing_constraint = c <= max(cash_on_hand, floor)` rejects every +action: `max(NaN, floor) == NaN` and `c <= NaN == False`. + +Usage on gpu-01: + cd ~/aca-dev + pixi run -e cuda12 python aca-model/debug_cash_on_hand.py +""" + +import pickle + +import jax.numpy as jnp +import numpy as np +import pandas as pd +from dags import concatenate_functions + +from aca_data.config import data_catalog +from aca_estimation._assemble_params import ( + _NON_MODEL_KEYS, + assemble_fixed_params, + assemble_params, + broadcast_to_template, +) +from aca_estimation._type_prediction import triple_initdist_by_pref_type +from aca_model.aca import PolicyVariant +from aca_model.aca.model import create_model as create_aca_model +from aca_model.config import GRID_CONFIG_FOR_RUN +from aca_model.consumption_grid import inject_consumption_points + +# Subjects whose `borrowing_constraint=False` in the gpu-01 production +# diagnostic. (subject_id, regime_name) tuples. Subject 1299 is included +# as a positive control: production showed `borrowing_constraint=True` +# for it, so its cash_on_hand should be finite. +_TARGETS: tuple[tuple[int, str], ...] = ( + (1131, "nongroup_nomc_inelig_canwork"), + (1299, "nongroup_nomc_inelig_canwork"), # positive control + (9013, "retiree_nomc_inelig_canwork"), + (10108, "nongroup_dimc_inelig_canwork"), +) + + +def _load_pickle(name: str): + with open(data_catalog[name], "rb") as fh: + return pickle.load(fh) + + +def main() -> None: + ss = _load_pickle("social_security_params") + tax = _load_pickle("tax_params") + ssi = _load_pickle("ssi_medicaid_params") + hi = _load_pickle("health_insurance_params") + pension = _load_pickle("pension_params") + wage = _load_pickle("wage_offer") + transition = _load_pickle("transition_params") + env = _load_pickle("environment_constants") + hcc_insurer = _load_pickle("hcc_insurer_params") + pref = _load_pickle("preference_start_values") + initdist_df = pd.read_pickle(data_catalog["initial_conditions"]) + + n_subjects = 3 * len(initdist_df) + bare_model = create_aca_model( + policy=PolicyVariant.ACA, grid_config=GRID_CONFIG_FOR_RUN, n_subjects=1 + ) + template = bare_model.get_params_template() + fixed_params = assemble_fixed_params( + bare_model=bare_model, + ss_params=ss, + tax_params=tax, + ssi_params=ssi, + hi_params=hi, + pension_params=pension, + wage_params=wage, + transition_params=transition, + env_params=env, + hcc_insurer_params=hcc_insurer, + pref_params=pref, + ) + broadcast_to_template(params=fixed_params, template=template, required=False) + params = assemble_params( + pref_params=pref, base_wage_profile=wage["log_ft_wage_base"] + ) + + model = create_aca_model( + n_subjects=n_subjects, + policy=PolicyVariant.ACA, + fixed_params=fixed_params, + wage_params=wage, + grid_config=GRID_CONFIG_FOR_RUN, + ) + model_params = {k: v for k, v in params.items() if k not in _NON_MODEL_KEYS} + model_params = inject_consumption_points(params=model_params, model=model) + initial = triple_initdist_by_pref_type(initdist_df) + + internal_params = model._process_params(model_params) # noqa: SLF001 + + # Evaluate cash_on_hand and borrowing_constraint for each target subject + # at each labor_supply choice with c = consumption_floor. + consumption_floor = float(model_params["consumption_floor"]) + for subject_id, regime_name in _TARGETS: + regime = model.regimes[regime_name] + internal_regime = model.internal_regimes[regime_name] + functions = internal_regime.simulate_functions.functions + constraints = internal_regime.simulate_functions.constraints + regime_params = { + **internal_regime.resolved_fixed_params, + **dict(internal_params.get(regime_name, {})), + } + + # Build a function returning (cash_on_hand, borrowing_constraint). + targets = ["cash_on_hand"] + if "borrowing_constraint" in constraints: + targets.append("borrowing_constraint") + all_funcs = dict(functions) + all_funcs.update(dict(constraints)) + evaluator = concatenate_functions( + functions=all_funcs, + targets=targets, + return_type="dict", + enforce_signature=False, + set_annotations=True, + ) + + # Per-subject states (single subject; pull idx subject_id from the + # already-tripled initial conditions). + subject_state = { + k: v[subject_id : subject_id + 1] + for k, v in initial.items() + if k != "regime" + } + + labor_supply_grid = np.asarray(regime.actions["labor_supply"].to_jax()) + print(f"\n=== subject {subject_id} ({regime_name}) ===") + print( + f" state: assets={float(subject_state['assets'][0]):.2f}, " + f"aime={float(subject_state['aime'][0]):.2f}, " + f"spousal_income={int(subject_state['spousal_income'][0])}, " + f"health={int(subject_state['health'][0])}, " + f"hcc_persistent={float(subject_state['hcc_persistent'][0]):.4f}, " + f"hcc_transitory={float(subject_state['hcc_transitory'][0]):.4f}" + ) + for ls in labor_supply_grid: + kwargs = { + **{k: v[0] for k, v in subject_state.items()}, + "consumption": jnp.float32(consumption_floor), + "labor_supply": jnp.int32(int(ls)), + "age": jnp.float32(51.0), + "period": jnp.int32(0), + **{k: v for k, v in regime_params.items()}, + } + try: + out = evaluator( + **{ + k: v + for k, v in kwargs.items() + if k in evaluator.__signature__.parameters + } + ) + coh = float(out["cash_on_hand"]) + bc = ( + bool(out.get("borrowing_constraint", True)) + if "borrowing_constraint" in out + else "n/a" + ) + nan_flag = " <-- NaN!" if not np.isfinite(coh) else "" + print( + f" ls={int(ls):d}: cash_on_hand={coh:14.2f} " + f"borrowing_constraint(c=c_floor)={bc}{nan_flag}" + ) + except (KeyError, TypeError) as exc: + print(f" ls={int(ls):d}: eval failed: {exc!r}") + + +if __name__ == "__main__": + main() From 8ffbf5c53063919fff3bbd1f0f49ea4f1691c321 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 11:05:57 +0200 Subject: [PATCH 29/54] wip: fix imports in debug script (broadcast_to_template + ACA_DATA_BLD) --- debug_cash_on_hand.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/debug_cash_on_hand.py b/debug_cash_on_hand.py index bac4a08..c9fc3bc 100644 --- a/debug_cash_on_hand.py +++ b/debug_cash_on_hand.py @@ -15,14 +15,14 @@ import numpy as np import pandas as pd from dags import concatenate_functions +from lcm.params.processing import broadcast_to_template -from aca_data.config import data_catalog from aca_estimation._assemble_params import ( _NON_MODEL_KEYS, assemble_fixed_params, assemble_params, - broadcast_to_template, ) +from aca_estimation.config import ACA_DATA_BLD from aca_estimation._type_prediction import triple_initdist_by_pref_type from aca_model.aca import PolicyVariant from aca_model.aca.model import create_model as create_aca_model @@ -41,23 +41,23 @@ ) -def _load_pickle(name: str): - with open(data_catalog[name], "rb") as fh: +def _load(name: str): + with open(ACA_DATA_BLD / f"{name}.pkl", "rb") as fh: return pickle.load(fh) def main() -> None: - ss = _load_pickle("social_security_params") - tax = _load_pickle("tax_params") - ssi = _load_pickle("ssi_medicaid_params") - hi = _load_pickle("health_insurance_params") - pension = _load_pickle("pension_params") - wage = _load_pickle("wage_offer") - transition = _load_pickle("transition_params") - env = _load_pickle("environment_constants") - hcc_insurer = _load_pickle("hcc_insurer_params") - pref = _load_pickle("preference_start_values") - initdist_df = pd.read_pickle(data_catalog["initial_conditions"]) + ss = _load("social_security_params") + tax = _load("tax_params") + ssi = _load("ssi_medicaid_params") + hi = _load("health_insurance_params") + pension = _load("pension_params") + wage = _load("wage_params") + transition = _load("transition_probs") + env = _load("environment_constants") + hcc_insurer = _load("hcc_insurer_params") + pref = _load("preference_start_values") + initdist_df = pd.read_pickle(ACA_DATA_BLD / "initial_conditions.pkl") n_subjects = 3 * len(initdist_df) bare_model = create_aca_model( From 81cca3c5fe0995b27e4eafceb4352bc9a11c8dc3 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 11:07:26 +0200 Subject: [PATCH 30/54] wip: import GRID_CONFIG_FOR_RUN from aca_estimation --- debug_cash_on_hand.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug_cash_on_hand.py b/debug_cash_on_hand.py index c9fc3bc..2b9e4b4 100644 --- a/debug_cash_on_hand.py +++ b/debug_cash_on_hand.py @@ -26,7 +26,7 @@ from aca_estimation._type_prediction import triple_initdist_by_pref_type from aca_model.aca import PolicyVariant from aca_model.aca.model import create_model as create_aca_model -from aca_model.config import GRID_CONFIG_FOR_RUN +from aca_estimation.config import GRID_CONFIG_FOR_RUN from aca_model.consumption_grid import inject_consumption_points # Subjects whose `borrowing_constraint=False` in the gpu-01 production From e320f41ac4137a14c8f499d88aa6b49838b60f5f Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 11:11:40 +0200 Subject: [PATCH 31/54] wip: pass derived_categoricals to create_aca_model in debug --- debug_cash_on_hand.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/debug_cash_on_hand.py b/debug_cash_on_hand.py index 2b9e4b4..a2b6fd3 100644 --- a/debug_cash_on_hand.py +++ b/debug_cash_on_hand.py @@ -17,18 +17,30 @@ from dags import concatenate_functions from lcm.params.processing import broadcast_to_template +from lcm import DiscreteGrid + from aca_estimation._assemble_params import ( _NON_MODEL_KEYS, assemble_fixed_params, assemble_params, ) -from aca_estimation.config import ACA_DATA_BLD from aca_estimation._type_prediction import triple_initdist_by_pref_type +from aca_estimation.config import ACA_DATA_BLD, GRID_CONFIG_FOR_RUN from aca_model.aca import PolicyVariant from aca_model.aca.model import create_model as create_aca_model -from aca_estimation.config import GRID_CONFIG_FOR_RUN +from aca_model.agent.health import GoodHealth +from aca_model.agent.labor_market import IsMarried +from aca_model.agent.preferences import PrefType +from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.consumption_grid import inject_consumption_points +_DERIVED_CATEGORICALS = { + "good_health": DiscreteGrid(GoodHealth), + "is_married": DiscreteGrid(IsMarried), + "his": DiscreteGrid(HealthInsuranceState), + "pref_type": DiscreteGrid(PrefType), +} + # Subjects whose `borrowing_constraint=False` in the gpu-01 production # diagnostic. (subject_id, regime_name) tuples. Subject 1299 is included # as a positive control: production showed `borrowing_constraint=True` @@ -87,6 +99,7 @@ def main() -> None: policy=PolicyVariant.ACA, fixed_params=fixed_params, wage_params=wage, + derived_categoricals=_DERIVED_CATEGORICALS, grid_config=GRID_CONFIG_FOR_RUN, ) model_params = {k: v for k, v in params.items() if k not in _NON_MODEL_KEYS} From 2208fa6777f46c55bb03d54dc06eb5fd536c8cac Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 11:20:46 +0200 Subject: [PATCH 32/54] wip: augment fixed_params for ACA policy in debug --- debug_cash_on_hand.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/debug_cash_on_hand.py b/debug_cash_on_hand.py index a2b6fd3..72531f3 100644 --- a/debug_cash_on_hand.py +++ b/debug_cash_on_hand.py @@ -89,6 +89,11 @@ def main() -> None: hcc_insurer_params=hcc_insurer, pref_params=pref, ) + from aca_estimation.task_simulate_aca import _augment_fixed_params_for_aca + + _augment_fixed_params_for_aca( + fixed_params=fixed_params, ssi_params=ssi, policy=PolicyVariant.ACA + ) broadcast_to_template(params=fixed_params, template=template, required=False) params = assemble_params( pref_params=pref, base_wage_profile=wage["log_ft_wage_base"] From 8adabda7952b5d0e96218322f3fe7620dbb10a62 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 6 May 2026 11:54:15 +0200 Subject: [PATCH 33/54] borrowing_constraint: cast consumption_floor to consumption's dtype MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Production failure root cause: `consumption_floor` is a Python fp64 float (≈ 1597.0921419521899); `consumption` arrives from the model's fp32 grid (`jnp.geomspace(consumption_floor, ...)`), quantized to 1597.0921630859375 — one fp32 ulp above the input. Without an explicit dtype cast on the floor, `consumption_floor * equivalence_scale` keeps its fp64 type, the comparison promotes to fp64, and the lowest grid point evaluates as 1597.0921630859375 > 1597.0921419521899 → False. Constraint rejects every action. Cast `consumption_floor` to `consumption.dtype` before the multiply so both sides of the `max` use the same precision. Constraint then admits c=c_floor by exact equality in fp32. Diagnosed via the per-constraint admissibility table (pylcm 838473e/ e4cae2a): production showed `borrowing_constraint=False` at modest asset levels (e.g. -$42k), where neither cash_on_hand magnitude nor NaN propagation could explain the rejection. Local repro pinned the ulp mismatch. Add `test_borrowing_constraint_admits_c_floor_with_python_float_floor` as a regression guard at the precise production scenario. Drop the debug script; it served its purpose. Co-Authored-By: Claude Opus 4.7 (1M context) --- debug_cash_on_hand.py | 194 ------------------ src/aca_model/agent/assets_and_income.py | 2 +- .../test_initial_conditions_extreme_assets.py | 28 +++ 3 files changed, 29 insertions(+), 195 deletions(-) delete mode 100644 debug_cash_on_hand.py diff --git a/debug_cash_on_hand.py b/debug_cash_on_hand.py deleted file mode 100644 index 72531f3..0000000 --- a/debug_cash_on_hand.py +++ /dev/null @@ -1,194 +0,0 @@ -"""Print cash_on_hand for the failing subjects at every labor_supply choice. - -If `cash_on_hand` evaluates to NaN for any subject, that explains why my -new `borrowing_constraint = c <= max(cash_on_hand, floor)` rejects every -action: `max(NaN, floor) == NaN` and `c <= NaN == False`. - -Usage on gpu-01: - cd ~/aca-dev - pixi run -e cuda12 python aca-model/debug_cash_on_hand.py -""" - -import pickle - -import jax.numpy as jnp -import numpy as np -import pandas as pd -from dags import concatenate_functions -from lcm.params.processing import broadcast_to_template - -from lcm import DiscreteGrid - -from aca_estimation._assemble_params import ( - _NON_MODEL_KEYS, - assemble_fixed_params, - assemble_params, -) -from aca_estimation._type_prediction import triple_initdist_by_pref_type -from aca_estimation.config import ACA_DATA_BLD, GRID_CONFIG_FOR_RUN -from aca_model.aca import PolicyVariant -from aca_model.aca.model import create_model as create_aca_model -from aca_model.agent.health import GoodHealth -from aca_model.agent.labor_market import IsMarried -from aca_model.agent.preferences import PrefType -from aca_model.baseline.health_insurance import HealthInsuranceState -from aca_model.consumption_grid import inject_consumption_points - -_DERIVED_CATEGORICALS = { - "good_health": DiscreteGrid(GoodHealth), - "is_married": DiscreteGrid(IsMarried), - "his": DiscreteGrid(HealthInsuranceState), - "pref_type": DiscreteGrid(PrefType), -} - -# Subjects whose `borrowing_constraint=False` in the gpu-01 production -# diagnostic. (subject_id, regime_name) tuples. Subject 1299 is included -# as a positive control: production showed `borrowing_constraint=True` -# for it, so its cash_on_hand should be finite. -_TARGETS: tuple[tuple[int, str], ...] = ( - (1131, "nongroup_nomc_inelig_canwork"), - (1299, "nongroup_nomc_inelig_canwork"), # positive control - (9013, "retiree_nomc_inelig_canwork"), - (10108, "nongroup_dimc_inelig_canwork"), -) - - -def _load(name: str): - with open(ACA_DATA_BLD / f"{name}.pkl", "rb") as fh: - return pickle.load(fh) - - -def main() -> None: - ss = _load("social_security_params") - tax = _load("tax_params") - ssi = _load("ssi_medicaid_params") - hi = _load("health_insurance_params") - pension = _load("pension_params") - wage = _load("wage_params") - transition = _load("transition_probs") - env = _load("environment_constants") - hcc_insurer = _load("hcc_insurer_params") - pref = _load("preference_start_values") - initdist_df = pd.read_pickle(ACA_DATA_BLD / "initial_conditions.pkl") - - n_subjects = 3 * len(initdist_df) - bare_model = create_aca_model( - policy=PolicyVariant.ACA, grid_config=GRID_CONFIG_FOR_RUN, n_subjects=1 - ) - template = bare_model.get_params_template() - fixed_params = assemble_fixed_params( - bare_model=bare_model, - ss_params=ss, - tax_params=tax, - ssi_params=ssi, - hi_params=hi, - pension_params=pension, - wage_params=wage, - transition_params=transition, - env_params=env, - hcc_insurer_params=hcc_insurer, - pref_params=pref, - ) - from aca_estimation.task_simulate_aca import _augment_fixed_params_for_aca - - _augment_fixed_params_for_aca( - fixed_params=fixed_params, ssi_params=ssi, policy=PolicyVariant.ACA - ) - broadcast_to_template(params=fixed_params, template=template, required=False) - params = assemble_params( - pref_params=pref, base_wage_profile=wage["log_ft_wage_base"] - ) - - model = create_aca_model( - n_subjects=n_subjects, - policy=PolicyVariant.ACA, - fixed_params=fixed_params, - wage_params=wage, - derived_categoricals=_DERIVED_CATEGORICALS, - grid_config=GRID_CONFIG_FOR_RUN, - ) - model_params = {k: v for k, v in params.items() if k not in _NON_MODEL_KEYS} - model_params = inject_consumption_points(params=model_params, model=model) - initial = triple_initdist_by_pref_type(initdist_df) - - internal_params = model._process_params(model_params) # noqa: SLF001 - - # Evaluate cash_on_hand and borrowing_constraint for each target subject - # at each labor_supply choice with c = consumption_floor. - consumption_floor = float(model_params["consumption_floor"]) - for subject_id, regime_name in _TARGETS: - regime = model.regimes[regime_name] - internal_regime = model.internal_regimes[regime_name] - functions = internal_regime.simulate_functions.functions - constraints = internal_regime.simulate_functions.constraints - regime_params = { - **internal_regime.resolved_fixed_params, - **dict(internal_params.get(regime_name, {})), - } - - # Build a function returning (cash_on_hand, borrowing_constraint). - targets = ["cash_on_hand"] - if "borrowing_constraint" in constraints: - targets.append("borrowing_constraint") - all_funcs = dict(functions) - all_funcs.update(dict(constraints)) - evaluator = concatenate_functions( - functions=all_funcs, - targets=targets, - return_type="dict", - enforce_signature=False, - set_annotations=True, - ) - - # Per-subject states (single subject; pull idx subject_id from the - # already-tripled initial conditions). - subject_state = { - k: v[subject_id : subject_id + 1] - for k, v in initial.items() - if k != "regime" - } - - labor_supply_grid = np.asarray(regime.actions["labor_supply"].to_jax()) - print(f"\n=== subject {subject_id} ({regime_name}) ===") - print( - f" state: assets={float(subject_state['assets'][0]):.2f}, " - f"aime={float(subject_state['aime'][0]):.2f}, " - f"spousal_income={int(subject_state['spousal_income'][0])}, " - f"health={int(subject_state['health'][0])}, " - f"hcc_persistent={float(subject_state['hcc_persistent'][0]):.4f}, " - f"hcc_transitory={float(subject_state['hcc_transitory'][0]):.4f}" - ) - for ls in labor_supply_grid: - kwargs = { - **{k: v[0] for k, v in subject_state.items()}, - "consumption": jnp.float32(consumption_floor), - "labor_supply": jnp.int32(int(ls)), - "age": jnp.float32(51.0), - "period": jnp.int32(0), - **{k: v for k, v in regime_params.items()}, - } - try: - out = evaluator( - **{ - k: v - for k, v in kwargs.items() - if k in evaluator.__signature__.parameters - } - ) - coh = float(out["cash_on_hand"]) - bc = ( - bool(out.get("borrowing_constraint", True)) - if "borrowing_constraint" in out - else "n/a" - ) - nan_flag = " <-- NaN!" if not np.isfinite(coh) else "" - print( - f" ls={int(ls):d}: cash_on_hand={coh:14.2f} " - f"borrowing_constraint(c=c_floor)={bc}{nan_flag}" - ) - except (KeyError, TypeError) as exc: - print(f" ls={int(ls):d}: eval failed: {exc!r}") - - -if __name__ == "__main__": - main() diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index 629fc42..374edf4 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -108,5 +108,5 @@ def borrowing_constraint( `next_assets` instead — a post-decision shift that does not gate the current consumption choice. """ - floor = consumption_floor * equivalence_scale + floor = jnp.asarray(consumption_floor, dtype=consumption.dtype) * equivalence_scale return consumption <= jnp.maximum(cash_on_hand, floor) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index bd88045..6df5e18 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -39,6 +39,34 @@ def test_borrowing_constraint_admits_c_floor_at_million_dollar_negative_cash() - assert admitted +def test_borrowing_constraint_admits_c_floor_with_python_float_floor() -> None: + """Python-fp64 `consumption_floor` against fp32 `consumption` must compare in fp32. + + `consumption_floor` arrives at the constraint as a Python float (fp64), but + `consumption` comes from the model's fp32 grid (`jnp.geomspace(...)`), + quantized to a value that differs from the fp64 input by one fp32 ulp + (~2e-5 at $c_{floor} \\approx 1597$). Without an explicit dtype cast on the + floor, the comparison promotes to fp64 and the lowest grid point fails + the constraint. The fix forces the floor into `consumption.dtype` before + the `max` so both sides use the same precision. + + Reproduces the production failure on gpu-01 where every subject in + `nongroup_nomc_inelig_canwork` (and similar regimes) hit + `borrowing_constraint=False` despite legitimate cash_on_hand values. + """ + consumption_floor = 1597.0921419521899 # production value, fp64 + consumption_fp32 = jnp.float32(consumption_floor) + admitted = bool( + borrowing_constraint( + consumption=consumption_fp32, + cash_on_hand=jnp.float32(-44_937.9), + consumption_floor=consumption_floor, # raw Python float + equivalence_scale=jnp.float32(1.0), + ) + ) + assert admitted + + def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. From c895bd9f8a027355263d589d93dd21cca59af902 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 7 May 2026 06:39:26 +0200 Subject: [PATCH 34/54] borrowing_constraint: drop dtype cast workaround `jnp.asarray(consumption_floor, dtype=consumption.dtype)` quantized the Python-float `consumption_floor` to the action grid's dtype to match the fp32-quantized consumption grid, so the `c == c_floor` boundary compared as exact equality. The pylcm canonical-float boundary cast (#345) routes every continuous-grid `to_jax()` through `canonical_float_dtype()`. Under `jax_enable_x64=True` (set in `aca_model/__init__.py`) that's `fp64`, so the action grid no longer quantizes the floor and Python-float / grid-value cannot disagree on dtype in the first place. Drop the regression test pinned to the cast workaround; the `max(cash_on_hand, floor)` cancellation guard and the full validate- initial-conditions integration test stay in place. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/agent/assets_and_income.py | 2 +- .../test_initial_conditions_extreme_assets.py | 28 ------------------- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index 374edf4..629fc42 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -108,5 +108,5 @@ def borrowing_constraint( `next_assets` instead — a post-decision shift that does not gate the current consumption choice. """ - floor = jnp.asarray(consumption_floor, dtype=consumption.dtype) * equivalence_scale + floor = consumption_floor * equivalence_scale return consumption <= jnp.maximum(cash_on_hand, floor) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index 6df5e18..bd88045 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -39,34 +39,6 @@ def test_borrowing_constraint_admits_c_floor_at_million_dollar_negative_cash() - assert admitted -def test_borrowing_constraint_admits_c_floor_with_python_float_floor() -> None: - """Python-fp64 `consumption_floor` against fp32 `consumption` must compare in fp32. - - `consumption_floor` arrives at the constraint as a Python float (fp64), but - `consumption` comes from the model's fp32 grid (`jnp.geomspace(...)`), - quantized to a value that differs from the fp64 input by one fp32 ulp - (~2e-5 at $c_{floor} \\approx 1597$). Without an explicit dtype cast on the - floor, the comparison promotes to fp64 and the lowest grid point fails - the constraint. The fix forces the floor into `consumption.dtype` before - the `max` so both sides use the same precision. - - Reproduces the production failure on gpu-01 where every subject in - `nongroup_nomc_inelig_canwork` (and similar regimes) hit - `borrowing_constraint=False` despite legitimate cash_on_hand values. - """ - consumption_floor = 1597.0921419521899 # production value, fp64 - consumption_fp32 = jnp.float32(consumption_floor) - admitted = bool( - borrowing_constraint( - consumption=consumption_fp32, - cash_on_hand=jnp.float32(-44_937.9), - consumption_floor=consumption_floor, # raw Python float - equivalence_scale=jnp.float32(1.0), - ) - ) - assert admitted - - def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. From e0cc62211438afd877865504281228c1f205d90e Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 7 May 2026 15:47:19 +0200 Subject: [PATCH 35/54] tests: switch helpers import to relative form MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `from tests.helpers.social_security import …` collided with the sibling `tests/__init__.py` packages in aca-data and aca-estimation when pytest collected from the aca-dev workspace root — whichever `tests` package got imported first shadowed the others. Use a relative import so each test module resolves its own helpers package unambiguously. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_social_security.py | 2 +- tests/test_ss_benefit_integration.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_social_security.py b/tests/test_social_security.py index d75e458..e399d3b 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -9,7 +9,7 @@ from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from tests.helpers.social_security import compute_di_dropout_scale, compute_pia_table +from .helpers.social_security import compute_di_dropout_scale, compute_pia_table ATOL = 0.01 diff --git a/tests/test_ss_benefit_integration.py b/tests/test_ss_benefit_integration.py index 0e77ea5..dadcac9 100644 --- a/tests/test_ss_benefit_integration.py +++ b/tests/test_ss_benefit_integration.py @@ -9,7 +9,7 @@ from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from tests.helpers.social_security import compute_pia_table +from .helpers.social_security import compute_pia_table ATOL = 0.01 From 3d2faf4a8f04b27e60e41f4bb2d3efe4c35e1f49 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 7 May 2026 17:47:58 +0200 Subject: [PATCH 36/54] tests: drop tests/__init__.py; expose helpers via conftest sys.path Reverts the relative-import attempt and instead removes the empty tests/__init__.py (which was colliding with aca-data and aca-estimation's identically named stubs across the aca-dev workspace). A new tests/conftest.py prepends the tests directory to sys.path so `from helpers.social_security import ...` resolves unambiguously. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/__init__.py | 0 tests/conftest.py | 4 ++++ tests/test_social_security.py | 2 +- tests/test_ss_benefit_integration.py | 2 +- 4 files changed, 6 insertions(+), 2 deletions(-) delete mode 100644 tests/__init__.py create mode 100644 tests/conftest.py diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1da1dcf --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,4 @@ +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent)) diff --git a/tests/test_social_security.py b/tests/test_social_security.py index e399d3b..d612f7d 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -9,7 +9,7 @@ from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from .helpers.social_security import compute_di_dropout_scale, compute_pia_table +from helpers.social_security import compute_di_dropout_scale, compute_pia_table ATOL = 0.01 diff --git a/tests/test_ss_benefit_integration.py b/tests/test_ss_benefit_integration.py index dadcac9..488df32 100644 --- a/tests/test_ss_benefit_integration.py +++ b/tests/test_ss_benefit_integration.py @@ -9,7 +9,7 @@ from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from .helpers.social_security import compute_pia_table +from helpers.social_security import compute_pia_table ATOL = 0.01 From 97c84cd02bc08461e1a4316a013c2ddf24f13261 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 7 May 2026 22:16:09 +0200 Subject: [PATCH 37/54] Drop precision-related workarounds and function defaults MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cleanup driven by pylcm's canonical-float boundary cast (#345). With every input pinned to fp64 under `jax_enable_x64=True` (which `aca_model/__init__.py` sets at import), aca-side precision workarounds no longer have a hook. Source: - `borrowing_constraint`: switch from `consumption <= max(cash_on_hand, floor)` to `consumption <= cash_on_hand + transfers`. The two are algebraically identical (`cash_on_hand + transfers == max(cash_on_hand, floor)`); the `max` form was justified by float32 catastrophic cancellation at extreme negative cash_on_hand, which cannot occur under fp64. The constraint now consumes `transfers` directly instead of recomputing `consumption_floor * equivalence_scale` — `transfers` is already a DAG node, so the resolved interface is shorter. Defaults dropped (callers must pass everything explicitly): - `aca_model.benchmark.create_benchmark_model`: `pref_type_grid`. - `aca_model.benchmark.get_benchmark_params`: `model`. - `aca_model.benchmark.get_benchmark_initial_conditions`: `n_subjects`, `seed`. - `aca_model.baseline.model.create_model`: `fixed_params`, `wage_params`, `derived_categoricals`, `grid_config`, `pref_type_grid`. - `aca_model.aca.model.create_model`: `policy`, `fixed_params`, `wage_params`, `derived_categoricals`, `grid_config`. - `aca_model.baseline.regimes.build_all_regimes`: same five. - `aca_model.aca.regimes.build_all_regimes`: same four. - `aca_model.baseline.regimes._common.build_grids`: same four. - Drop `GRID_CONFIG` import where it was only used as a default value. Tests: - New `tests/helpers/model.py` exposes `make_baseline_model` and `make_aca_model` factories that wrap `create_model` with `None` for every optional input. Tests that don't need fixed params reach the factories through the helper rather than spelling out six `None`s each. Production code stays default-free. - New `test_benchmark_simulate_obeys_borrowing_constraint`: pins the invariant `consumption <= cash_on_hand + transfers` on every alive row of the benchmark simulation. Catches a regression that drops the constraint from a regime, replaces transfers with something looser, or lets an action grid skip the floor. - `test_initial_conditions_extreme_assets`: drop the fp32-specific cancellation regression test (the runtime no longer reaches that path); replace with a pair of unit tests for the new `borrowing_constraint(consumption, cash_on_hand, transfers)` signature. --- src/aca_model/aca/model.py | 38 ++++++++------- src/aca_model/aca/regimes/__init__.py | 15 +++--- src/aca_model/agent/assets_and_income.py | 33 ++++--------- src/aca_model/baseline/model.py | 45 +++++++++--------- src/aca_model/baseline/regimes/__init__.py | 15 +++--- src/aca_model/baseline/regimes/_common.py | 10 ++-- src/aca_model/benchmark.py | 26 +++++----- tests/helpers/model.py | 38 +++++++++++++++ tests/test_benchmark.py | 47 ++++++++++++++++++- .../test_initial_conditions_extreme_assets.py | 46 ++++++++++++------ tests/test_model_creation.py | 39 ++++++++++----- tests/test_social_security.py | 2 +- tests/test_ss_benefit_integration.py | 2 +- 13 files changed, 235 insertions(+), 121 deletions(-) create mode 100644 tests/helpers/model.py diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 1cc7ff4..b76adc6 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -14,36 +14,40 @@ from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.regimes import RegimeId from aca_model.baseline.regimes._common import MAX_CONSUMPTION -from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.config import MODEL_CONFIG, GridConfig def create_model( *, n_subjects: int, - policy: PolicyVariant = PolicyVariant.ACA, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - derived_categoricals: Mapping[str, DiscreteGrid] | None = None, - grid_config: GridConfig = GRID_CONFIG, + policy: PolicyVariant, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, + derived_categoricals: Mapping[str, DiscreteGrid] | None, + grid_config: GridConfig, ) -> Model: """Create an ACA policy variant model. Args: n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. - policy: Which ACA policy combination to apply. - fixed_params: Parameters to fix at model creation time. These are - partialled into compiled functions and removed from the params - template. Pass data-derived constants here; only estimation - parameters should go through `model.simulate(params=...)`. + policy: Which ACA policy combination to apply (e.g. + `PolicyVariant.ACA`). + fixed_params: Parameters to fix at model creation time, or `None` + to skip. Fixed params are partialled into compiled functions + and removed from the params template. Pass data-derived + constants here; only estimation parameters should go through + `model.simulate(params=...)`. wage_params: Data-derived wage profile dict (`log_ft_wage_mean`, `log_ft_wage_std`, `adj_wage_hours_*`) used only at grid-build time to size the assets-floor to `-max_annual_labor_income`. - Not routed to the pylcm Model. - derived_categoricals: Extra categorical mappings for derived variables - not in the model's state/action grids. Needed when `fixed_params` - contains `pd.Series` indexed by DAG function outputs. - grid_config: Continuous-grid point counts. Defaults to production - values. + Not routed to the pylcm Model. `None` skips the floor sizing. + derived_categoricals: Extra categorical mappings for derived + variables not in the model's state/action grids, or `None`. + Needed when `fixed_params` contains `pd.Series` indexed by DAG + function outputs. + grid_config: Continuous-grid point counts. Pass `GRID_CONFIG` for + production values or `BENCHMARK_GRID_CONFIG` for the + fast-but-structurally-faithful benchmark. Returns: pylcm Model with ACA-specific function overrides. diff --git a/src/aca_model/aca/regimes/__init__.py b/src/aca_model/aca/regimes/__init__.py index 2c143bd..5b9f4bf 100644 --- a/src/aca_model/aca/regimes/__init__.py +++ b/src/aca_model/aca/regimes/__init__.py @@ -10,19 +10,22 @@ from aca_model.aca.regimes._overrides import apply_aca_overrides from aca_model.baseline.regimes import build_all_regimes as baseline_build_all_regimes from aca_model.baseline.regimes._common import REGIME_SPECS -from aca_model.config import GRID_CONFIG, GridConfig +from aca_model.config import GridConfig def build_all_regimes( - policy: PolicyVariant, - grid_config: GridConfig = GRID_CONFIG, *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, + policy: PolicyVariant, + grid_config: GridConfig, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, ) -> dict[str, Regime]: """Build all 19 regimes with ACA policy overrides.""" regimes = baseline_build_all_regimes( - grid_config, fixed_params=fixed_params, wage_params=wage_params + grid_config=grid_config, + fixed_params=fixed_params, + wage_params=wage_params, + pref_type_grid=None, ) result = {} for name, regime in regimes.items(): diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index 629fc42..c07fdb4 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -85,28 +85,15 @@ def next_assets_terminal( def borrowing_constraint( consumption: ContinuousAction, cash_on_hand: FloatND, - consumption_floor: float, - equivalence_scale: FloatND, + transfers: FloatND, ) -> BoolND: - """Consumption cannot exceed available resources after transfers. - - Post-transfer resources are `max(cash_on_hand, consumption_floor * - equivalence_scale)`: the transfer system tops `cash_on_hand` to the - floor when below, otherwise resources are unchanged. The algebraic - identity is `cash_on_hand + transfers == max(cash_on_hand, floor)`, - but writing it as `cash_on_hand + transfers` triggers float32 - catastrophic cancellation when `|cash_on_hand|` dwarfs - `consumption_floor` — e.g. a subject at $-1{,}000{,}000$ in starting - assets gives `(-1e6) + (c_floor + 1e6)` with ~0.1 of rounding error, - which can wipe out the `c == c_floor` boundary and reject every - feasible action. The `max` form has no cancellation. - - `pension_assets_adjustment` is excluded: it can be negative (e.g. - when the imputation overstates next-period pension wealth at a - cross-HIS transition), and including it here can leave no feasible - action at low-asset / mid-AIME corners. The correction enters - `next_assets` instead — a post-decision shift that does not gate - the current consumption choice. + """Consumption cannot exceed post-transfer resources. + + `pension_assets_adjustment` is excluded from the constraint: it can + be negative (e.g. when the imputation overstates next-period pension + wealth at a cross-HIS transition), and including it here can leave + no feasible action at low-asset / mid-AIME corners. The correction + enters `next_assets` instead — a post-decision shift that does not + gate the current consumption choice. """ - floor = consumption_floor * equivalence_scale - return consumption <= jnp.maximum(cash_on_hand, floor) + return consumption <= cash_on_hand + transfers diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index b0a2d79..1185eeb 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -18,39 +18,42 @@ from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.regimes import RegimeId, build_all_regimes from aca_model.baseline.regimes._common import MAX_CONSUMPTION -from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.config import MODEL_CONFIG, GridConfig def create_model( *, n_subjects: int, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - derived_categoricals: Mapping[str, DiscreteGrid] | None = None, - grid_config: GridConfig = GRID_CONFIG, - pref_type_grid: DiscreteGrid | None = None, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, + derived_categoricals: Mapping[str, DiscreteGrid] | None, + grid_config: GridConfig, + pref_type_grid: DiscreteGrid | None, ) -> Model: """Create the baseline structural retirement model. Args: n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. - fixed_params: Parameters to fix at model creation time. These are - partialled into compiled functions and removed from the params - template. Pass data-derived constants here; only estimation - parameters should go through `model.simulate(params=...)`. + fixed_params: Parameters to fix at model creation time, or `None` + to skip. Fixed params are partialled into compiled functions + and removed from the params template. Pass data-derived + constants here; only estimation parameters should go through + `model.simulate(params=...)`. wage_params: Data-derived wage profile dict (`log_ft_wage_mean`, `log_ft_wage_std`, `adj_wage_hours_*`) used only at grid-build time to size the assets-floor to `-max_annual_labor_income`. - Not routed to the pylcm Model. - derived_categoricals: Extra categorical mappings for derived variables - not in the model's state/action grids. Needed when `fixed_params` - contains `pd.Series` indexed by DAG function outputs. - grid_config: Continuous-grid point counts. Defaults to production - values; pass `BENCHMARK_GRID_CONFIG` for a fast-but-structurally- - faithful benchmark. - pref_type_grid: Optional override for the `pref_type` `DiscreteGrid`. - Defaults to `DiscreteGrid(PrefType)`. Used by the benchmark to - substitute a 2-type variant with `DispatchStrategy.PARTITION_SCAN`. + Not routed to the pylcm Model. `None` skips the floor sizing. + derived_categoricals: Extra categorical mappings for derived + variables not in the model's state/action grids, or `None`. + Needed when `fixed_params` contains `pd.Series` indexed by DAG + function outputs. + grid_config: Continuous-grid point counts. Pass `GRID_CONFIG` for + production values or `BENCHMARK_GRID_CONFIG` for the + fast-but-structurally-faithful benchmark. + pref_type_grid: Pref-type `DiscreteGrid`, or `None` to use + `DiscreteGrid(PrefType)`. Pass a custom grid (e.g. with a + `DispatchStrategy.PARTITION_SCAN` strategy) to substitute the + production layout. Returns: A pylcm Model with 19 regimes (18 non-terminal + dead) spanning @@ -63,7 +66,7 @@ def create_model( step="Y", ) regimes = build_all_regimes( - grid_config, + grid_config=grid_config, fixed_params=fixed_params, wage_params=wage_params, pref_type_grid=pref_type_grid, diff --git a/src/aca_model/baseline/regimes/__init__.py b/src/aca_model/baseline/regimes/__init__.py index a0eaf9e..02e8a05 100644 --- a/src/aca_model/baseline/regimes/__init__.py +++ b/src/aca_model/baseline/regimes/__init__.py @@ -25,7 +25,7 @@ build_dead_regime, build_grids, ) -from aca_model.config import GRID_CONFIG, GridConfig +from aca_model.config import GridConfig __all__ = [ "REGIME_SPECS", @@ -58,11 +58,11 @@ def build_regime(name: str, grids: Grids) -> Regime: def build_all_regimes( - grid_config: GridConfig = GRID_CONFIG, *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - pref_type_grid: DiscreteGrid | None = None, + grid_config: GridConfig, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, + pref_type_grid: DiscreteGrid | None, ) -> dict[str, Regime]: """Build all 19 baseline regimes (18 non-terminal + dead). @@ -71,10 +71,11 @@ def build_all_regimes( either being `None` keeps the corresponding static fallback. `pref_type_grid` lets callers inject a compact or partition-lifted `DiscreteGrid(...)` (e.g. the benchmark uses a 2-type - `BenchmarkPrefType` with `DispatchStrategy.PARTITION_SCAN`). + `BenchmarkPrefType` with `DispatchStrategy.PARTITION_SCAN`); `None` + falls back to `DiscreteGrid(PrefType)`. """ grids = build_grids( - grid_config, + grid_config=grid_config, fixed_params=fixed_params, wage_params=wage_params, pref_type_grid=pref_type_grid, diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 56887d9..25347c6 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -36,7 +36,7 @@ from aca_model.agent.preferences import PrefType from aca_model.baseline import health_insurance from aca_model.baseline.health_insurance import BuyPrivate -from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.config import MODEL_CONFIG, GridConfig from aca_model.environment import social_security, taxes from aca_model.environment.social_security import ClaimedSS @@ -208,11 +208,11 @@ class Grids: def build_grids( - grid_config: GridConfig = GRID_CONFIG, *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - pref_type_grid: DiscreteGrid | None = None, + grid_config: GridConfig, + fixed_params: Mapping[str, Any] | None, + wage_params: Mapping[str, Any] | None, + pref_type_grid: DiscreteGrid | None, ) -> Grids: """Build continuous-state/action grids from a `GridConfig`. diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 8e242b0..19416f2 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -75,7 +75,7 @@ def create_benchmark_model( *, n_subjects: int, - pref_type_grid: DiscreteGrid | None = None, + pref_type_grid: DiscreteGrid, ) -> Model: """Create the aca baseline with `BENCHMARK_GRID_CONFIG` and frozen fixed_params. @@ -85,22 +85,21 @@ def create_benchmark_model( `n_aime_batch_size = 0`). Args: - pref_type_grid: Override for the pref_type grid. Default is a plain - `DiscreteGrid(BenchmarkPrefType)` (fused vmap). Pass - `DiscreteGrid(BenchmarkPrefType, dispatch=DispatchStrategy.PARTITION_SCAN)` - (or `PARTITION_VMAP`) to get the partition-lifted kernel — the - recommended production setting for aca-model at scale, but only - supported on pylcm versions that expose `DispatchStrategy`. n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. When set, the first matching `simulate(...)` call AOT-compiles all simulate functions for that batch shape. + pref_type_grid: Pref-type grid. Pass `DiscreteGrid(BenchmarkPrefType)` + for plain fused-vmap, or + `DiscreteGrid(BenchmarkPrefType, dispatch=DispatchStrategy.PARTITION_SCAN)` + (or `PARTITION_VMAP`) for the partition-lifted kernel — the + recommended production setting for aca-model at scale, but only + supported on pylcm versions that expose `DispatchStrategy`. """ - if pref_type_grid is None: - pref_type_grid = DiscreteGrid(BenchmarkPrefType) - fixed_params, _ = get_benchmark_params() + fixed_params, _ = get_benchmark_params(model=None) return create_model( grid_config=BENCHMARK_GRID_CONFIG, fixed_params=fixed_params, + wage_params=None, derived_categoricals=_DERIVED_CATEGORICALS, pref_type_grid=pref_type_grid, n_subjects=n_subjects, @@ -108,7 +107,7 @@ def create_benchmark_model( def get_benchmark_params( - *, model: Model | None = None + *, model: Model | None ) -> tuple[dict[str, Any], dict[str, Any]]: """Load the frozen `(fixed_params, params)` snapshot. @@ -119,7 +118,8 @@ def get_benchmark_params( When `model` is provided, consumption gridpoints are injected into `params` for each regime that declares `consumption` as an `IrregSpacedGrid` with runtime-supplied points. The lower bound is - read from `params["consumption_floor"]`. + read from `params["consumption_floor"]`. Pass `model=None` to skip + injection (e.g. when constructing the model with `fixed_params`). """ with _PARAMS_FILE.open("rb") as fh: data = cloudpickle.load(fh) @@ -222,7 +222,7 @@ def _truncate_pref_type_indexed(params: dict[str, Any]) -> dict[str, Any]: def get_benchmark_initial_conditions( - *, model: Model, n_subjects: int = 100, seed: int = 42 + *, model: Model, n_subjects: int, seed: int ) -> dict[str, Array]: """Draw random feasible initial conditions across five age-51 regimes. diff --git a/tests/helpers/model.py b/tests/helpers/model.py new file mode 100644 index 0000000..930c33e --- /dev/null +++ b/tests/helpers/model.py @@ -0,0 +1,38 @@ +"""Tiny factories that wrap `create_model` with `None` for every optional input. + +Used by tests that don't need fixed params, wage params, or a custom pref-type +grid. These helpers exist so production `create_model` factories can stay +default-free without forcing every test call site to spell out +`fixed_params=None, wage_params=None, ...` six times. +""" + +from lcm import Model + +from aca_model.aca.health_insurance import PolicyVariant +from aca_model.aca.model import create_model as _create_aca_model +from aca_model.baseline.model import create_model as _create_baseline_model +from aca_model.config import GRID_CONFIG + + +def make_baseline_model(*, n_subjects: int) -> Model: + """Baseline model with `GRID_CONFIG` and no fixed/wage/derived params.""" + return _create_baseline_model( + n_subjects=n_subjects, + fixed_params=None, + wage_params=None, + derived_categoricals=None, + grid_config=GRID_CONFIG, + pref_type_grid=None, + ) + + +def make_aca_model(*, n_subjects: int, policy: PolicyVariant) -> Model: + """ACA model with `GRID_CONFIG` and no fixed/wage/derived params.""" + return _create_aca_model( + n_subjects=n_subjects, + policy=policy, + fixed_params=None, + wage_params=None, + derived_categoricals=None, + grid_config=GRID_CONFIG, + ) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index adafd66..5e5a68d 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -1,7 +1,9 @@ """Integration test: the benchmark-sized baseline solves + simulates end-to-end.""" import pytest +from lcm import DiscreteGrid +from aca_model.agent.preferences import BenchmarkPrefType from aca_model.benchmark import ( create_benchmark_model, get_benchmark_initial_conditions, @@ -12,7 +14,10 @@ @pytest.mark.long_running def test_benchmark_model_simulates_end_to_end() -> None: n_subjects = 20 - model = create_benchmark_model(n_subjects=n_subjects) + model = create_benchmark_model( + n_subjects=n_subjects, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 @@ -31,3 +36,43 @@ def test_benchmark_model_simulates_end_to_end() -> None: # Period 0 rows reflect initial conditions — no NaN in continuous states. period_0 = df.loc[df["period"] == 0] assert not period_0[["assets", "aime", "value"]].isna().any().any() + + +@pytest.mark.long_running +def test_benchmark_simulate_obeys_borrowing_constraint() -> None: + """`consumption <= cash_on_hand + transfers` holds for every alive row. + + The simulator only ever picks feasible actions — the borrowing + constraint must hold post-hoc on the simulated panel. A regression + that drops the constraint from a regime, replaces transfers with + something looser, or lets an action grid skip the floor would + surface as a row with `consumption > cash_on_hand + transfers`. + """ + n_subjects = 4 + model = create_benchmark_model( + n_subjects=n_subjects, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + _, params = get_benchmark_params(model=model) + initial_conditions = get_benchmark_initial_conditions( + model=model, n_subjects=n_subjects, seed=0 + ) + + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + log_level="off", + check_initial_conditions=False, + ) + + df = result.to_dataframe(additional_targets=["cash_on_hand", "transfers"]) + alive = df.loc[df["regime"] != "dead"].copy() + slack = (alive["cash_on_hand"] + alive["transfers"]) - alive["consumption"] + # Non-negative within fp64 tolerance; allow 1e-6 of the magnitude scale + # to absorb the float64 rounding budget. + assert (slack >= -1e-6).all(), ( + f"borrowing_constraint violated on " + f"{int((slack < -1e-6).sum())} row(s); " + f"min slack = {slack.min():.6g}" + ) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index bd88045..47aeb6a 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -8,9 +8,11 @@ """ import jax.numpy as jnp +from lcm import DiscreteGrid from lcm.simulation.initial_conditions import validate_initial_conditions from aca_model.agent.assets_and_income import borrowing_constraint +from aca_model.agent.preferences import BenchmarkPrefType from aca_model.benchmark import ( create_benchmark_model, get_benchmark_initial_conditions, @@ -18,27 +20,40 @@ ) -def test_borrowing_constraint_admits_c_floor_at_million_dollar_negative_cash() -> None: - """At `cash_on_hand = -$1M` (fp32), `c = c_floor` remains a feasible choice. +def test_borrowing_constraint_admits_consumption_at_post_transfer_resources() -> None: + """`consumption == cash_on_hand + transfers` is feasible by equality.""" + cash_on_hand = jnp.asarray(-50_000.0) + transfers = jnp.asarray(55_000.0) + consumption = cash_on_hand + transfers - Computing `cash_on_hand + transfers` directly suffers float32 catastrophic - cancellation: `(-1e6) + (c_floor + 1e6)` loses ~0.1 of precision, enough - to wipe out the `c == c_floor` boundary. The constraint must use the - algebraically equivalent but numerically stable `max(cash_on_hand, floor)` - form. - """ - consumption_floor = 5_000.0 admitted = bool( borrowing_constraint( - consumption=jnp.float32(consumption_floor), - cash_on_hand=jnp.float32(-1_000_000.0), - consumption_floor=consumption_floor, - equivalence_scale=jnp.float32(1.0), + consumption=consumption, + cash_on_hand=cash_on_hand, + transfers=transfers, ) ) assert admitted +def test_borrowing_constraint_rejects_consumption_above_post_transfer_resources() -> ( + None +): + """`consumption > cash_on_hand + transfers` is rejected.""" + cash_on_hand = jnp.asarray(-50_000.0) + transfers = jnp.asarray(55_000.0) + consumption = cash_on_hand + transfers + 1.0 + + admitted = bool( + borrowing_constraint( + consumption=consumption, + cash_on_hand=cash_on_hand, + transfers=transfers, + ) + ) + assert not admitted + + def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. @@ -47,7 +62,10 @@ def test_extreme_negative_assets_subject_passes_validation() -> None: floor / transfer system absorbs them, with `c = c_floor` always feasible. """ n_subjects = 1 - model = create_benchmark_model(n_subjects=n_subjects) + model = create_benchmark_model( + n_subjects=n_subjects, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index 75a87d9..fca2ef6 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -3,17 +3,32 @@ from collections.abc import Mapping import pytest +from helpers.model import make_aca_model, make_baseline_model from aca_model.aca import health_insurance as aca_hi from aca_model.aca.health_insurance import PolicyVariant -from aca_model.aca.model import create_model as create_aca_model -from aca_model.aca.regimes import build_all_regimes as build_aca_regimes -from aca_model.baseline.model import create_model +from aca_model.aca.regimes import build_all_regimes as _build_aca_regimes from aca_model.baseline.regimes import REGIME_SPECS, RegimeId from aca_model.baseline.regimes import build_regime as _build_regime from aca_model.baseline.regimes._common import MAX_CONSUMPTION, build_grids +from aca_model.config import GRID_CONFIG -_GRIDS = build_grids() + +def build_aca_regimes(policy: PolicyVariant) -> dict: + return _build_aca_regimes( + policy=policy, + grid_config=GRID_CONFIG, + fixed_params=None, + wage_params=None, + ) + + +_GRIDS = build_grids( + grid_config=GRID_CONFIG, + fixed_params=None, + wage_params=None, + pref_type_grid=None, +) def build_regime(name: str): @@ -21,24 +36,24 @@ def build_regime(name: str): def test_model_creates_successfully() -> None: - model = create_model(n_subjects=1) + model = make_baseline_model(n_subjects=1) assert len(model.regimes) == 19 assert model.n_periods == 45 def test_model_age_range() -> None: - model = create_model(n_subjects=1) + model = make_baseline_model(n_subjects=1) assert model.ages.values[0] == 51.0 assert model.ages.values[-1] == 95.0 def test_dead_regime_is_terminal() -> None: - model = create_model(n_subjects=1) + model = make_baseline_model(n_subjects=1) assert model.regimes["dead"].terminal def test_non_terminal_regimes_not_terminal() -> None: - model = create_model(n_subjects=1) + model = make_baseline_model(n_subjects=1) for name in REGIME_SPECS: assert not model.regimes[name].terminal @@ -170,7 +185,7 @@ def test_hcc_persistent_and_transitory_are_shock_grids() -> None: def test_aca_model_creates_successfully() -> None: - model = create_aca_model(n_subjects=1) + model = make_aca_model(n_subjects=1, policy=PolicyVariant.ACA) assert len(model.regimes) == 19 assert model.n_periods == 45 @@ -211,7 +226,7 @@ def test_aca_other_regimes_have_no_aca_policy_keys() -> None: @pytest.mark.parametrize("policy", list(PolicyVariant)) def test_all_policy_variants_create(policy: PolicyVariant) -> None: """All policy variants create valid models.""" - model = create_aca_model(n_subjects=1, policy=policy) + model = make_aca_model(n_subjects=1, policy=policy) assert len(model.regimes) == 19 @@ -251,10 +266,10 @@ def test_aca_only_medicaid_expansion() -> None: def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" - model = create_model(n_subjects=1) + model = make_baseline_model(n_subjects=1) assert len(model.regimes) == 19 def test_max_consumption_attached_from_canonical_constant() -> None: - model = create_model(n_subjects=1) + model = make_baseline_model(n_subjects=1) assert model.max_consumption == MAX_CONSUMPTION diff --git a/tests/test_social_security.py b/tests/test_social_security.py index d612f7d..90c5128 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -5,11 +5,11 @@ import jax.numpy as jnp import numpy as np +from helpers.social_security import compute_di_dropout_scale, compute_pia_table from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from helpers.social_security import compute_di_dropout_scale, compute_pia_table ATOL = 0.01 diff --git a/tests/test_ss_benefit_integration.py b/tests/test_ss_benefit_integration.py index 488df32..ef09775 100644 --- a/tests/test_ss_benefit_integration.py +++ b/tests/test_ss_benefit_integration.py @@ -5,11 +5,11 @@ """ import jax.numpy as jnp +from helpers.social_security import compute_pia_table from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from helpers.social_security import compute_pia_table ATOL = 0.01 From 9d59174143dc065e791da909c18db15ca0856aac Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 06:18:40 +0200 Subject: [PATCH 38/54] borrowing_constraint: restore max() form for kink-stability at extreme cash MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `consumption <= cash_on_hand + transfers` form (algebraically identical to `consumption <= max(cash_on_hand, floor)`) rounds short by sub-ULP at extreme `|cash_on_hand|` ~ 1e6 — for HRS-bottom-coded subjects at `assets=-$1{,}000{,}000$`, the additive RHS comes in at `floor - 5.7e-11` (fp64), flipping the kink-boundary `<=` for the lowest consumption gridpoint. Production task_simulate_aca_no_mandate on HPC fails at validate_initial_conditions for those subjects. The `max(cash_on_hand, floor)` form has no cancellation and returns `floor` exactly when `cash_on_hand < floor`. This is a general floating-point precision concern at extreme operands, not an fp32-specific workaround. Docstring updated accordingly. Reverts the signature back to `(consumption, cash_on_hand, consumption_floor, equivalence_scale)`. Tests: - `test_borrowing_constraint_admits_floor_at_million_dollar_negative_cash`: unit-level reproducer of the production failure — passes only with the `max` form. - The two new `_at_floor` / `_above_post_transfer_resources` unit tests switch back to the new signature. - `test_benchmark_simulate_obeys_borrowing_constraint`: post-hoc check uses `max(cash_on_hand, floor)` rather than `cash_on_hand + transfers` (the additive form has the same sub-ULP issue and would spuriously trip on the same rows). --- src/aca_model/agent/assets_and_income.py | 15 +++++- tests/test_benchmark.py | 29 +++++++---- .../test_initial_conditions_extreme_assets.py | 50 +++++++++++++++---- 3 files changed, 71 insertions(+), 23 deletions(-) diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index c07fdb4..b0ee689 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -85,10 +85,20 @@ def next_assets_terminal( def borrowing_constraint( consumption: ContinuousAction, cash_on_hand: FloatND, - transfers: FloatND, + consumption_floor: float, + equivalence_scale: FloatND, ) -> BoolND: """Consumption cannot exceed post-transfer resources. + Post-transfer resources are `max(cash_on_hand, consumption_floor * + equivalence_scale)`: the transfer system tops `cash_on_hand` to the + floor when below, otherwise resources are unchanged. The algebraic + identity is `cash_on_hand + transfers == max(cash_on_hand, floor)`; + the `max` form is preferred because the additive form rounds to + `floor + ε` (with `|ε| ~ ULP(|cash_on_hand|)`) at extreme cash, which + flips the kink-boundary comparison for HRS-bottom-coded subjects at + `assets=-$1{,}000{,}000`. The `max` form returns `floor` exactly. + `pension_assets_adjustment` is excluded from the constraint: it can be negative (e.g. when the imputation overstates next-period pension wealth at a cross-HIS transition), and including it here can leave @@ -96,4 +106,5 @@ def borrowing_constraint( enters `next_assets` instead — a post-decision shift that does not gate the current consumption choice. """ - return consumption <= cash_on_hand + transfers + floor = consumption_floor * equivalence_scale + return consumption <= jnp.maximum(cash_on_hand, floor) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 5e5a68d..8b1ed88 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -40,14 +40,21 @@ def test_benchmark_model_simulates_end_to_end() -> None: @pytest.mark.long_running def test_benchmark_simulate_obeys_borrowing_constraint() -> None: - """`consumption <= cash_on_hand + transfers` holds for every alive row. + """`consumption <= max(cash_on_hand, floor)` holds for every alive row. The simulator only ever picks feasible actions — the borrowing constraint must hold post-hoc on the simulated panel. A regression - that drops the constraint from a regime, replaces transfers with + that drops the constraint from a regime, replaces the floor with something looser, or lets an action grid skip the floor would - surface as a row with `consumption > cash_on_hand + transfers`. + surface as a row with `consumption > max(cash_on_hand, floor)`. + + The constraint's RHS is `max(cash_on_hand, floor)` rather than + `cash_on_hand + transfers`: the additive form rounds short by + sub-ULP at extreme `|cash_on_hand|`, so the post-hoc check would + also flip on the same kink. """ + import numpy as np + n_subjects = 4 model = create_benchmark_model( n_subjects=n_subjects, @@ -66,13 +73,15 @@ def test_benchmark_simulate_obeys_borrowing_constraint() -> None: check_initial_conditions=False, ) - df = result.to_dataframe(additional_targets=["cash_on_hand", "transfers"]) + df = result.to_dataframe( + additional_targets=["cash_on_hand", "equivalence_scale"] + ) alive = df.loc[df["regime"] != "dead"].copy() - slack = (alive["cash_on_hand"] + alive["transfers"]) - alive["consumption"] - # Non-negative within fp64 tolerance; allow 1e-6 of the magnitude scale - # to absorb the float64 rounding budget. - assert (slack >= -1e-6).all(), ( - f"borrowing_constraint violated on " - f"{int((slack < -1e-6).sum())} row(s); " + consumption_floor = float(params["consumption_floor"]) + floor = consumption_floor * alive["equivalence_scale"].to_numpy() + rhs = np.maximum(alive["cash_on_hand"].to_numpy(), floor) + slack = rhs - alive["consumption"].to_numpy() + assert (slack >= 0).all(), ( + f"borrowing_constraint violated on {int((slack < 0).sum())} row(s); " f"min slack = {slack.min():.6g}" ) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index 47aeb6a..6078547 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -20,17 +20,18 @@ ) -def test_borrowing_constraint_admits_consumption_at_post_transfer_resources() -> None: - """`consumption == cash_on_hand + transfers` is feasible by equality.""" - cash_on_hand = jnp.asarray(-50_000.0) - transfers = jnp.asarray(55_000.0) - consumption = cash_on_hand + transfers +def test_borrowing_constraint_admits_consumption_at_floor() -> None: + """`consumption == consumption_floor` at the kink is feasible by equality.""" + consumption_floor = 5_000.0 + equivalence_scale = jnp.asarray(1.0) + cash_on_hand = jnp.asarray(-50_000.0) # below floor — RHS = floor admitted = bool( borrowing_constraint( - consumption=consumption, + consumption=jnp.asarray(consumption_floor), cash_on_hand=cash_on_hand, - transfers=transfers, + consumption_floor=consumption_floor, + equivalence_scale=equivalence_scale, ) ) assert admitted @@ -39,21 +40,48 @@ def test_borrowing_constraint_admits_consumption_at_post_transfer_resources() -> def test_borrowing_constraint_rejects_consumption_above_post_transfer_resources() -> ( None ): - """`consumption > cash_on_hand + transfers` is rejected.""" + """`consumption > max(cash_on_hand, floor)` is rejected.""" + consumption_floor = 5_000.0 + equivalence_scale = jnp.asarray(1.0) cash_on_hand = jnp.asarray(-50_000.0) - transfers = jnp.asarray(55_000.0) - consumption = cash_on_hand + transfers + 1.0 + consumption = jnp.asarray(consumption_floor + 1.0) admitted = bool( borrowing_constraint( consumption=consumption, cash_on_hand=cash_on_hand, - transfers=transfers, + consumption_floor=consumption_floor, + equivalence_scale=equivalence_scale, ) ) assert not admitted +def test_borrowing_constraint_admits_floor_at_million_dollar_negative_cash() -> None: + """The kink-boundary check survives sub-ULP rounding at `|cash_on_hand| ~ 1e6`. + + Reproduces the production failure mode at `assets=-$1{,}000{,}000$` (HRS + bottom-code): the algebraically equivalent `cash_on_hand + transfers` + form rounds to `floor - 5.7e-11` at fp64, flipping `consumption <= ...` + for the lowest consumption gridpoint. The `max(cash_on_hand, floor)` + form returns `floor` exactly. + """ + consumption_floor = 1597.0921419521899 # production value + equivalence_scale = jnp.asarray(1.0) + cash_on_hand = jnp.asarray(-1_000_000.0) + consumption = jnp.asarray(consumption_floor) # lowest grid point + + admitted = bool( + borrowing_constraint( + consumption=consumption, + cash_on_hand=cash_on_hand, + consumption_floor=consumption_floor, + equivalence_scale=equivalence_scale, + ) + ) + assert admitted + + def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. From 67edfe0f54a305c23297f17ec53aee07b7d90496 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 08:00:39 +0200 Subject: [PATCH 39/54] consumption_grid: pin first gridpoint to consumption_floor exactly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `jnp.geomspace(consumption_floor, max_consumption, num=n)` returns `consumption_floor * r^0 == consumption_floor` mathematically, but some XLA backends drift the first point by sub-ULP. CUDA at n=70 produces `consumption_floor + 2.27e-13`. The borrowing_constraint compares `consumption[0]` against `max(cash_on_hand, consumption_floor)` and any positive drift above `consumption_floor` flips the kink- boundary `<=` for subjects with very negative cash — explaining the HPC-only `task_simulate` failures (~250 subjects) that didn't reproduce on CPU. Pin the first gridpoint back to `consumption_floor` after geomspace. The same drift exists at the upper end (`pts[-1] != max_consumption` exactly) but doesn't flip any constraint comparison, so it's left alone. `tests/test_consumption_grid.py` parametrises the invariant over `n_points = 5, 16, 64, 70, 100` so a future XLA / JAX upgrade that introduces drift at any of these counts surfaces here rather than at `validate_initial_conditions` on HPC. --- src/aca_model/consumption_grid.py | 16 +++++++++-- tests/test_consumption_grid.py | 48 +++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 tests/test_consumption_grid.py diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_grid.py index 7123c1f..7e004fa 100644 --- a/src/aca_model/consumption_grid.py +++ b/src/aca_model/consumption_grid.py @@ -69,5 +69,17 @@ def _compute_consumption_points( max_consumption: float, n_points: int, ) -> Array: - """Return log-spaced consumption gridpoints from floor to max.""" - return jnp.geomspace(consumption_floor, max_consumption, num=n_points) + """Return log-spaced consumption gridpoints from floor to max. + + `jnp.geomspace` computes intermediate points as `start * r^i` with + `r = (stop/start)^(1/(n-1))`; the first point is `start * r^0`, + which is `start` mathematically but can be off by sub-ULP under + some XLA backends (CUDA + 70 points: `start + 2.27e-13`). The + borrowing constraint compares the first action against + `max(cash_on_hand, consumption_floor)`, and any positive drift + above `consumption_floor` flips the kink-boundary `<=` for + subjects with very negative cash. Pin the first element back to + `consumption_floor` exactly. + """ + pts = jnp.geomspace(consumption_floor, max_consumption, num=n_points) + return pts.at[0].set(consumption_floor) diff --git a/tests/test_consumption_grid.py b/tests/test_consumption_grid.py new file mode 100644 index 0000000..40b7caa --- /dev/null +++ b/tests/test_consumption_grid.py @@ -0,0 +1,48 @@ +"""Consumption-grid invariants required by the borrowing constraint. + +The borrowing constraint in `agent.assets_and_income.borrowing_constraint` +compares the lowest consumption action against +`max(cash_on_hand, consumption_floor * equivalence_scale)`. For subjects +with cash below the floor (HRS bottom-coded `assets=-$1{,}000{,}000$`, +moderate-negative-asset retirees etc.) this RHS collapses to exactly +`consumption_floor` for singles. The constraint is feasible iff the +lowest consumption gridpoint is `<= consumption_floor`. + +`jnp.geomspace(start, stop, num=n)` returns `start * r^i` with +`r = (stop/start)^(1/(n-1))`; mathematically `r^0 == 1` so the first +point equals `start`, but XLA backends can drift by sub-ULP for some +`(start, stop, n)` combinations (observed: CUDA, n=70, drift +2.27e-13). +A positive drift above `consumption_floor` flips the kink-boundary `<=` +and rejects every action for those subjects. + +`_compute_consumption_points` therefore pins the first point back to +`consumption_floor` after `geomspace`. Test that invariant directly. +""" + +import jax.numpy as jnp +import pytest + +from aca_model.consumption_grid import _compute_consumption_points + + +@pytest.mark.parametrize("n_points", [5, 16, 64, 70, 100]) +def test_compute_consumption_points_first_equals_floor_exactly(n_points: int) -> None: + """The first gridpoint equals `consumption_floor` exactly under any `n_points`.""" + consumption_floor = 1597.0921419521899 # production value + pts = _compute_consumption_points( + consumption_floor=consumption_floor, + max_consumption=300_000.0, + n_points=n_points, + ) + assert float(pts[0]) == consumption_floor + + +def test_compute_consumption_points_strictly_increasing() -> None: + """Gridpoints are strictly increasing — no kink-pinning ties.""" + pts = _compute_consumption_points( + consumption_floor=1597.0921419521899, + max_consumption=300_000.0, + n_points=70, + ) + diffs = jnp.diff(pts) + assert bool((diffs > 0).all()) From d9339ab1a00861b2d8f4b5c3f70aa216b9cbd0a6 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 17:55:40 +0200 Subject: [PATCH 40/54] ci: bump pylcm pin to 2f486dc Sweeps in the dtype-barrier polish, simulate AOT-during-solve, and the persistence/benchmark fixes from feat/canonical-float-dtype. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 565245d..fa0891f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,7 +29,7 @@ jobs: - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@e4cae2aa57d4bf568b8ebbade55d44571e3a086f" + git+https://github.com/OpenSourceEconomics/pylcm.git@2f486dc36425ca6339a36cc8214ab4aef1d85df2" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest From 9e39a0679aba8f63b0a7ff03ebf4db1273c4a423 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 8 May 2026 18:00:25 +0200 Subject: [PATCH 41/54] ci: bump pylcm pin to ca66ba9 Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index fa0891f..ec60050 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,7 +29,7 @@ jobs: - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@2f486dc36425ca6339a36cc8214ab4aef1d85df2" + git+https://github.com/OpenSourceEconomics/pylcm.git@ca66ba9" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest From de4d16fd03a28e7298cd7a22ded16a4638889612 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 10 May 2026 17:14:29 +0200 Subject: [PATCH 42/54] =?UTF-8?q?Rename=20consumption=20=E2=86=92=20consum?= =?UTF-8?q?ption=5Funequiv=20across=20model=20+=20drop=20stale=20docstring?= =?UTF-8?q?s?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bare `consumption` is the raw-$ household action that gets divided by `equivalence_scale` to compute the per-equivalent input to utility. Renaming it `consumption_unequiv` (and the local `equiv_cons` → `consumption_equiv`) makes the equiv/unequiv distinction explicit at every reference. `consumption_weight` (the CES weight α) and `coefficient_rra` carry no equiv/unequiv meaning and stay as-is. Renames applied (case-sensitive substring + word-boundary token): - consumption (action) → consumption_unequiv - consumption_floor → consumption_unequiv_floor - max_consumption / MAX_CONSUMPTION → max_consumption_unequiv / MAX_CONSUMPTION_UNEQUIV - n_consumption_gridpoints → n_consumption_unequiv_gridpoints - inject_consumption_points → inject_consumption_unequiv_points - _compute_consumption_points → _compute_consumption_unequiv_points - average_consumption → average_consumption_unequiv - equiv_cons → consumption_equiv - module file consumption_grid.py → consumption_unequiv_grid.py - test file test_consumption_grid.py → test_consumption_unequiv_grid.py Also drop pref_type_grid docstring references to `DispatchStrategy`, `PARTITION_SCAN`, `PARTITION_VMAP`, and `FUSED_VMAP` from baseline/model.py, baseline/regimes/__init__.py, and baseline/regimes/_common.py — that pylcm enum was never merged. `benchmark.py` and `_benchmark_data/` are out of scope for this commit; their `_STALE_FIXED_KEYS` and `_rename_his_level` workarounds plus the DispatchStrategy docstrings in benchmark.py will be cleaned up when the benchmark snapshot is regenerated and benchmarks become a true special case of the production model. The benchmark.py imports get the minimum updates needed for it to load. `tests/test_initial_conditions_extreme_assets.py::test_extreme_negative_assets_subject_passes_validation` is skipped because it uses the benchmark snapshot, which still carries `average_consumption` (renamed to `average_consumption_unequiv`). Unskip after the snapshot is regenerated. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/aca/health_insurance.py | 2 +- src/aca_model/aca/model.py | 4 +- src/aca_model/agent/assets_and_income.py | 32 +++++------ src/aca_model/agent/preferences.py | 14 ++--- src/aca_model/agent/utility.py | 4 +- src/aca_model/baseline/model.py | 11 ++-- src/aca_model/baseline/regimes/__init__.py | 7 +-- src/aca_model/baseline/regimes/_common.py | 18 +++--- src/aca_model/benchmark.py | 14 ++--- src/aca_model/config.py | 4 +- ...on_grid.py => consumption_unequiv_grid.py} | 56 +++++++++---------- tests/test_benchmark.py | 10 ++-- tests/test_budget_chain_integration.py | 6 +- tests/test_consumption_grid.py | 48 ---------------- tests/test_consumption_unequiv_grid.py | 48 ++++++++++++++++ .../test_initial_conditions_extreme_assets.py | 46 ++++++++------- tests/test_model_components.py | 4 +- tests/test_model_creation.py | 8 +-- tests/test_pension_integration.py | 2 +- tests/test_preferences.py | 24 ++++---- 20 files changed, 183 insertions(+), 179 deletions(-) rename src/aca_model/{consumption_grid.py => consumption_unequiv_grid.py} (50%) delete mode 100644 tests/test_consumption_grid.py create mode 100644 tests/test_consumption_unequiv_grid.py diff --git a/src/aca_model/aca/health_insurance.py b/src/aca_model/aca/health_insurance.py index 1aa4133..cc62a6d 100644 --- a/src/aca_model/aca/health_insurance.py +++ b/src/aca_model/aca/health_insurance.py @@ -120,7 +120,7 @@ def cash_on_hand( OOP health costs are NOT deducted here — they are deducted from next-period assets instead, matching the timing where HCC shocks are - integrated over (agent does not condition consumption on OOP). + integrated over (agent does not condition consumption_unequiv on OOP). """ return ( assets diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index b76adc6..ba7d347 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -13,7 +13,7 @@ from aca_model.aca.regimes import build_all_regimes from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.regimes import RegimeId -from aca_model.baseline.regimes._common import MAX_CONSUMPTION +from aca_model.baseline.regimes._common import MAX_CONSUMPTION_UNEQUIV from aca_model.config import MODEL_CONFIG, GridConfig @@ -82,5 +82,5 @@ def create_model( derived_categoricals=base_derived, n_subjects=n_subjects, ) - model.max_consumption = MAX_CONSUMPTION + model.max_consumption_unequiv = MAX_CONSUMPTION_UNEQUIV return model diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index b0ee689..ebba99e 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -26,25 +26,25 @@ def cash_on_hand( ssi_benefit: FloatND, hic_premium: FloatND, ) -> FloatND: - """Compute cash on hand available for consumption and saving. + """Compute cash on hand available for consumption_unequiv and saving. OOP health costs are NOT deducted here — they are deducted from next-period assets instead, matching the timing where HCC shocks are - integrated over (agent does not condition consumption on OOP). + integrated over (agent does not condition consumption_unequiv on OOP). """ return assets + after_tax_income + ssi_benefit - hic_premium def transfers( cash_on_hand: FloatND, - consumption_floor: float, + consumption_unequiv_floor: float, equivalence_scale: FloatND, ) -> FloatND: - """Government transfers to enforce consumption floor. + """Government transfers to enforce consumption_unequiv floor. tr = max{0, C_min * equivalence_scale - cash_on_hand} """ - floor = consumption_floor * equivalence_scale + floor = consumption_unequiv_floor * equivalence_scale return jnp.maximum(0.0, floor - cash_on_hand) @@ -52,23 +52,23 @@ def next_assets( cash_on_hand: FloatND, transfers: FloatND, pension_assets_adjustment: FloatND, - consumption: ContinuousAction, + consumption_unequiv: ContinuousAction, oop_costs: FloatND, ) -> ContinuousState: """Compute beginning-of-next-period assets for non-terminal targets. OOP health costs are deducted here (not from cash_on_hand) so that the - consumption choice does not condition on the HCC shock realization. + consumption_unequiv choice does not condition on the HCC shock realization. """ return ( - cash_on_hand + transfers + pension_assets_adjustment - consumption - oop_costs + cash_on_hand + transfers + pension_assets_adjustment - consumption_unequiv - oop_costs ) def next_assets_terminal( cash_on_hand: FloatND, transfers: FloatND, - consumption: ContinuousAction, + consumption_unequiv: ContinuousAction, oop_costs: FloatND, ) -> ContinuousState: """Compute beginning-of-next-period assets for the dead/terminal target. @@ -79,18 +79,18 @@ def next_assets_terminal( (which would otherwise need to come from a transition `dead` does not have, since `aime` is not a state in the terminal regime). """ - return cash_on_hand + transfers - consumption - oop_costs + return cash_on_hand + transfers - consumption_unequiv - oop_costs def borrowing_constraint( - consumption: ContinuousAction, + consumption_unequiv: ContinuousAction, cash_on_hand: FloatND, - consumption_floor: float, + consumption_unequiv_floor: float, equivalence_scale: FloatND, ) -> BoolND: """Consumption cannot exceed post-transfer resources. - Post-transfer resources are `max(cash_on_hand, consumption_floor * + Post-transfer resources are `max(cash_on_hand, consumption_unequiv_floor * equivalence_scale)`: the transfer system tops `cash_on_hand` to the floor when below, otherwise resources are unchanged. The algebraic identity is `cash_on_hand + transfers == max(cash_on_hand, floor)`; @@ -104,7 +104,7 @@ def borrowing_constraint( wealth at a cross-HIS transition), and including it here can leave no feasible action at low-asset / mid-AIME corners. The correction enters `next_assets` instead — a post-decision shift that does not - gate the current consumption choice. + gate the current consumption_unequiv choice. """ - floor = consumption_floor * equivalence_scale - return consumption <= jnp.maximum(cash_on_hand, floor) + floor = consumption_unequiv_floor * equivalence_scale + return consumption_unequiv <= jnp.maximum(cash_on_hand, floor) diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 28a8367..336cbd7 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -122,7 +122,7 @@ def leisure_retired( def utility( - consumption: ContinuousAction, + consumption_unequiv: ContinuousAction, leisure: FloatND, pref_type: DiscreteState, consumption_weight: FloatND, @@ -130,7 +130,7 @@ def utility( equivalence_scale: FloatND, utility_scale_factor: FloatND, ) -> FloatND: - """Within-period utility: CES aggregator over consumption and leisure. + """Within-period utility: CES aggregator over consumption_unequiv and leisure. u = utility_scale_factor * ((c/eq_scale)^α * l^(1-α))^(1-γ) / (1-γ) with log case for γ=1. `consumption_weight` and `coefficient_rra` are @@ -141,8 +141,8 @@ def utility( """ alpha = consumption_weight[pref_type] gamma = coefficient_rra[pref_type] - equiv_cons = consumption / equivalence_scale - composite = equiv_cons**alpha * leisure ** (1.0 - alpha) + consumption_equiv = consumption_unequiv / equivalence_scale + composite = consumption_equiv**alpha * leisure ** (1.0 - alpha) one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) u = jnp.where( @@ -168,7 +168,7 @@ def discount_factor( def utility_scale_factor( pref_type: DiscreteState, - average_consumption: float, + average_consumption_unequiv: float, consumption_weight: FloatND, coefficient_rra: FloatND, time_endowment: float, @@ -184,7 +184,7 @@ def utility_scale_factor( pattern: take the state as input, return a per-cell scalar. Registering this as a regime function and then doing `utility_scale_factor[pref_type]` in a downstream consumer is invalid — pylcm broadcasts function outputs to - per-cell scalars before consumption, and the validator in + per-cell scalars before consumption_unequiv, and the validator in `lcm.regime_building.validation` raises on that clash. """ alpha = consumption_weight[pref_type] @@ -195,7 +195,7 @@ def utility_scale_factor( - scale_reference_hours - (fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * age_offset) ) - u_cons = average_consumption**alpha + u_cons = average_consumption_unequiv**alpha u_leisure = average_leisure ** (1.0 - alpha) one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) diff --git a/src/aca_model/agent/utility.py b/src/aca_model/agent/utility.py index fd7bf16..7ba0e2a 100644 --- a/src/aca_model/agent/utility.py +++ b/src/aca_model/agent/utility.py @@ -19,7 +19,7 @@ def retired( - consumption: ContinuousAction, + consumption_unequiv: ContinuousAction, good_health: IntND, equivalence_scale: FloatND, pref_type: DiscreteState, @@ -36,7 +36,7 @@ def retired( leisure_cost_of_bad_health=leisure_cost_of_bad_health, ) return preferences.utility( - consumption=consumption, + consumption_unequiv=consumption_unequiv, leisure=lei, pref_type=pref_type, consumption_weight=consumption_weight, diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index 1185eeb..a8533f2 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -17,7 +17,7 @@ from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.regimes import RegimeId, build_all_regimes -from aca_model.baseline.regimes._common import MAX_CONSUMPTION +from aca_model.baseline.regimes._common import MAX_CONSUMPTION_UNEQUIV from aca_model.config import MODEL_CONFIG, GridConfig @@ -51,9 +51,8 @@ def create_model( production values or `BENCHMARK_GRID_CONFIG` for the fast-but-structurally-faithful benchmark. pref_type_grid: Pref-type `DiscreteGrid`, or `None` to use - `DiscreteGrid(PrefType)`. Pass a custom grid (e.g. with a - `DispatchStrategy.PARTITION_SCAN` strategy) to substitute the - production layout. + `DiscreteGrid(PrefType)`. Pass a custom grid to substitute + the production layout (e.g. the 2-type benchmark variant). Returns: A pylcm Model with 19 regimes (18 non-terminal + dead) spanning @@ -92,7 +91,7 @@ def create_model( derived_categoricals=base_derived, n_subjects=n_subjects, ) - # See `MAX_CONSUMPTION` in `baseline.regimes._common` for why this + # See `MAX_CONSUMPTION_UNEQUIV` in `baseline.regimes._common` for why this # rides on the Model instance instead of `fixed_params`. - model.max_consumption = MAX_CONSUMPTION + model.max_consumption_unequiv = MAX_CONSUMPTION_UNEQUIV return model diff --git a/src/aca_model/baseline/regimes/__init__.py b/src/aca_model/baseline/regimes/__init__.py index 02e8a05..2473489 100644 --- a/src/aca_model/baseline/regimes/__init__.py +++ b/src/aca_model/baseline/regimes/__init__.py @@ -69,10 +69,9 @@ def build_all_regimes( `fixed_params` is forwarded to `build_grids` for data-driven AIME breakpoints; `wage_params` for the data-driven assets floor; either being `None` keeps the corresponding static fallback. - `pref_type_grid` lets callers inject a compact or partition-lifted - `DiscreteGrid(...)` (e.g. the benchmark uses a 2-type - `BenchmarkPrefType` with `DispatchStrategy.PARTITION_SCAN`); `None` - falls back to `DiscreteGrid(PrefType)`. + `pref_type_grid` lets callers inject a compact `DiscreteGrid(...)` + (e.g. the benchmark's 2-type `BenchmarkPrefType`); `None` falls + back to `DiscreteGrid(PrefType)`. """ grids = build_grids( grid_config=grid_config, diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 25347c6..1bb73fd 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -183,7 +183,7 @@ class RegimeId: class Grids: assets: LinSpacedGrid aime: ContinuousGrid - consumption: ContinuousGrid + consumption_unequiv: ContinuousGrid wage_res: Any hcc_persistent: Any hcc_transitory: Any @@ -195,12 +195,12 @@ class Grids: _AIME_PIECE_N_POINTS: tuple[int, int, int] = (10, 11, 11) -MAX_CONSUMPTION: float = 300_000.0 -"""Upper bound of the runtime consumption grid in $/year. +MAX_CONSUMPTION_UNEQUIV: float = 300_000.0 +"""Upper bound of the runtime consumption_unequiv grid in $/year. Lives here next to the other grid bounds (assets `stop=500_000.0`, AIME `stop=8_000.0`). The `create_model` factories attach this onto -`model.max_consumption` so `inject_consumption_points` can read it +`model.max_consumption_unequiv` so `inject_consumption_unequiv_points` can read it back at runtime. Routed via a Model attribute rather than `fixed_params` because pylcm validates `fixed_params` keys against the regime DAG and rejects entries no function consumes. @@ -231,9 +231,9 @@ def build_grids( the grid floor must still be known at build time. `pref_type_grid` lets callers (e.g. the benchmark) substitute a - compact or partition-lifted `DiscreteGrid(...)` for the production + compact `DiscreteGrid(...)` for the production `DiscreteGrid(PrefType)`. When `None`, defaults to the production - 3-type grid with the default `DispatchStrategy.FUSED_VMAP`. + 3-type grid. """ # Unit-variance standardised shocks: the total_costs / wage # formulas rescale these by fixed_params-level std parameters @@ -276,8 +276,8 @@ def build_grids( batch_size=grid_config.n_assets_batch_size, ), aime=_build_aime_grid(grid_config=grid_config, fixed_params=fixed_params), - consumption=IrregSpacedGrid( - n_points=grid_config.n_consumption_gridpoints, + consumption_unequiv=IrregSpacedGrid( + n_points=grid_config.n_consumption_unequiv_gridpoints, ), wage_res=wage_res, hcc_persistent=hcc_persistent, @@ -425,7 +425,7 @@ def build_actions(spec: dict[str, str], grids: Grids) -> dict: actions["labor_supply"] = DiscreteGrid(LaborSupply) if spec["his"] == "nongroup" and spec["mc"] == "nomc": actions["buy_private"] = DiscreteGrid(BuyPrivate) - actions["consumption"] = grids.consumption + actions["consumption_unequiv"] = grids.consumption_unequiv return actions diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 19416f2..685b5cd 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -44,7 +44,7 @@ from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.model import create_model from aca_model.config import BENCHMARK_GRID_CONFIG -from aca_model.consumption_grid import inject_consumption_points +from aca_model.consumption_unequiv_grid import inject_consumption_unequiv_points _PARAMS_FILE = ( Path(__file__).resolve().parent / "_benchmark_data" / "benchmark_params.pkl" @@ -115,11 +115,11 @@ def get_benchmark_params( `_N_BENCHMARK_PREF_TYPES` rows so they line up with `BenchmarkPrefType`'s categories. - When `model` is provided, consumption gridpoints are injected into - `params` for each regime that declares `consumption` as an - `IrregSpacedGrid` with runtime-supplied points. The lower bound is - read from `params["consumption_floor"]`. Pass `model=None` to skip - injection (e.g. when constructing the model with `fixed_params`). + When `model` is provided, consumption_unequiv gridpoints are injected + into `params` for each regime that declares `consumption_unequiv` as + an `IrregSpacedGrid` with runtime-supplied points. The lower bound is + read from `params["consumption_unequiv_floor"]`. Pass `model=None` to + skip injection (e.g. when constructing the model with `fixed_params`). """ with _PARAMS_FILE.open("rb") as fh: data = cloudpickle.load(fh) @@ -129,7 +129,7 @@ def get_benchmark_params( fixed_params = _add_shifted_imputation_arrays(fixed_params) params = _truncate_pref_type_indexed(data["params"]) if model is not None: - params = inject_consumption_points(params=params, model=model) + params = inject_consumption_unequiv_points(params=params, model=model) return fixed_params, params diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 37fc0c8..cfa132d 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -29,7 +29,7 @@ class ModelConfig: class GridConfig: n_assets_gridpoints: int = 24 n_aime_gridpoints: int = 12 - n_consumption_gridpoints: int = 70 + n_consumption_unequiv_gridpoints: int = 70 n_wage_res_gridpoints: int = 5 n_hcc_persistent_gridpoints: int = 3 n_hcc_transitory_gridpoints: int = 5 @@ -47,7 +47,7 @@ class GridConfig: BENCHMARK_GRID_CONFIG = GridConfig( n_assets_gridpoints=3, n_aime_gridpoints=3, - n_consumption_gridpoints=5, + n_consumption_unequiv_gridpoints=5, n_wage_res_gridpoints=3, n_hcc_persistent_gridpoints=3, n_hcc_transitory_gridpoints=3, diff --git a/src/aca_model/consumption_grid.py b/src/aca_model/consumption_unequiv_grid.py similarity index 50% rename from src/aca_model/consumption_grid.py rename to src/aca_model/consumption_unequiv_grid.py index 7e004fa..e0e9cd3 100644 --- a/src/aca_model/consumption_grid.py +++ b/src/aca_model/consumption_unequiv_grid.py @@ -1,13 +1,13 @@ -"""Runtime-supplied gridpoints for the consumption action. +"""Runtime-supplied gridpoints for the consumption_unequiv action. Consumption is declared as `IrregSpacedGrid(n_points=N)` in `baseline.regimes._common.build_grids` so the bounds can track runtime parameters: the lower bound from the per-iteration -`consumption_floor` parameter, the upper bound from -`MAX_CONSUMPTION` in `baseline.regimes._common`, which the -`create_model` factories attach to `model.max_consumption`. +`consumption_unequiv_floor` parameter, the upper bound from +`MAX_CONSUMPTION_UNEQUIV` in `baseline.regimes._common`, which the +`create_model` factories attach to `model.max_consumption_unequiv`. Callers must inject the actual gridpoints into `params` via -`inject_consumption_points` before calling `model.solve()` / +`inject_consumption_unequiv_points` before calling `model.solve()` / `model.simulate()`. """ @@ -19,20 +19,20 @@ from lcm import IrregSpacedGrid, Model -def inject_consumption_points( +def inject_consumption_unequiv_points( *, params: Mapping[str, Any], model: Model, ) -> dict[str, Any]: - """Inject consumption gridpoints into per-regime params. + """Inject consumption_unequiv gridpoints into per-regime params. Walks every regime, finds the action whose grid is an `IrregSpacedGrid` with runtime-supplied points, and writes - `params[regime_name]["consumption"] = {"points": }`. + `params[regime_name]["consumption_unequiv"] = {"points": }`. - Lower bound: `params["consumption_floor"]` (varies per iteration). - Upper bound: `model.max_consumption` (set by the `create_model` - factory from `MAX_CONSUMPTION` in `baseline.regimes._common`). + Lower bound: `params["consumption_unequiv_floor"]` (varies per iteration). + Upper bound: `model.max_consumption_unequiv` (set by the `create_model` + factory from `MAX_CONSUMPTION_UNEQUIV` in `baseline.regimes._common`). Args: params: Existing params mapping. Returned as a new dict; the input is @@ -40,46 +40,46 @@ def inject_consumption_points( model: Model whose regime specs determine which regimes need points. Returns: - New params dict with consumption points injected. + New params dict with consumption_unequiv points injected. """ - consumption_floor = float(params["consumption_floor"]) - max_consumption = float(model.max_consumption) + consumption_unequiv_floor = float(params["consumption_unequiv_floor"]) + max_consumption_unequiv = float(model.max_consumption_unequiv) out: dict[str, Any] = dict(params) for regime_name, regime in model.regimes.items(): - grid = regime.actions.get("consumption") + grid = regime.actions.get("consumption_unequiv") if not (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime): continue # Runtime-points grids always have `n_points` set (the constructor # rejects the (points=None, n_points=None) combo); narrow for ty. assert grid.n_points is not None - points = _compute_consumption_points( - consumption_floor=consumption_floor, - max_consumption=max_consumption, + points = _compute_consumption_unequiv_points( + consumption_unequiv_floor=consumption_unequiv_floor, + max_consumption_unequiv=max_consumption_unequiv, n_points=grid.n_points, ) regime_entry = dict(out.get(regime_name, {})) - regime_entry["consumption"] = {"points": points} + regime_entry["consumption_unequiv"] = {"points": points} out[regime_name] = regime_entry return out -def _compute_consumption_points( +def _compute_consumption_unequiv_points( *, - consumption_floor: float, - max_consumption: float, + consumption_unequiv_floor: float, + max_consumption_unequiv: float, n_points: int, ) -> Array: - """Return log-spaced consumption gridpoints from floor to max. + """Return log-spaced consumption_unequiv gridpoints from floor to max. `jnp.geomspace` computes intermediate points as `start * r^i` with `r = (stop/start)^(1/(n-1))`; the first point is `start * r^0`, which is `start` mathematically but can be off by sub-ULP under some XLA backends (CUDA + 70 points: `start + 2.27e-13`). The borrowing constraint compares the first action against - `max(cash_on_hand, consumption_floor)`, and any positive drift - above `consumption_floor` flips the kink-boundary `<=` for + `max(cash_on_hand, consumption_unequiv_floor)`, and any positive drift + above `consumption_unequiv_floor` flips the kink-boundary `<=` for subjects with very negative cash. Pin the first element back to - `consumption_floor` exactly. + `consumption_unequiv_floor` exactly. """ - pts = jnp.geomspace(consumption_floor, max_consumption, num=n_points) - return pts.at[0].set(consumption_floor) + pts = jnp.geomspace(consumption_unequiv_floor, max_consumption_unequiv, num=n_points) + return pts.at[0].set(consumption_unequiv_floor) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 8b1ed88..b626b50 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -40,13 +40,13 @@ def test_benchmark_model_simulates_end_to_end() -> None: @pytest.mark.long_running def test_benchmark_simulate_obeys_borrowing_constraint() -> None: - """`consumption <= max(cash_on_hand, floor)` holds for every alive row. + """`consumption_unequiv <= max(cash_on_hand, floor)` holds for every alive row. The simulator only ever picks feasible actions — the borrowing constraint must hold post-hoc on the simulated panel. A regression that drops the constraint from a regime, replaces the floor with something looser, or lets an action grid skip the floor would - surface as a row with `consumption > max(cash_on_hand, floor)`. + surface as a row with `consumption_unequiv > max(cash_on_hand, floor)`. The constraint's RHS is `max(cash_on_hand, floor)` rather than `cash_on_hand + transfers`: the additive form rounds short by @@ -77,10 +77,10 @@ def test_benchmark_simulate_obeys_borrowing_constraint() -> None: additional_targets=["cash_on_hand", "equivalence_scale"] ) alive = df.loc[df["regime"] != "dead"].copy() - consumption_floor = float(params["consumption_floor"]) - floor = consumption_floor * alive["equivalence_scale"].to_numpy() + consumption_unequiv_floor = float(params["consumption_unequiv_floor"]) + floor = consumption_unequiv_floor * alive["equivalence_scale"].to_numpy() rhs = np.maximum(alive["cash_on_hand"].to_numpy(), floor) - slack = rhs - alive["consumption"].to_numpy() + slack = rhs - alive["consumption_unequiv"].to_numpy() assert (slack >= 0).all(), ( f"borrowing_constraint violated on {int((slack < 0).sum())} row(s); " f"min slack = {slack.min():.6g}" diff --git a/tests/test_budget_chain_integration.py b/tests/test_budget_chain_integration.py index 8bf206c..b511f8c 100644 --- a/tests/test_budget_chain_integration.py +++ b/tests/test_budget_chain_integration.py @@ -108,7 +108,7 @@ def test_retired_agent_with_pension() -> None: def test_transfers_kick_in_below_floor() -> None: - """When cash_on_hand < consumption_floor, transfers fill the gap.""" + """When cash_on_hand < consumption_unequiv_floor, transfers fill the gap.""" functions = { "cash_on_hand": assets_and_income.cash_on_hand, "transfers": assets_and_income.transfers, @@ -126,10 +126,10 @@ def test_transfers_kick_in_below_floor() -> None: ssi_benefit=jnp.array(0.0), hic_premium=jnp.array(0.0), oop_costs=jnp.array(0.0), - consumption_floor=5000.0, + consumption_unequiv_floor=5000.0, equivalence_scale=jnp.array(1.0), pension_assets_adjustment=jnp.array(0.0), - consumption=jnp.array(4000.0), + consumption_unequiv=jnp.array(4000.0), ) # cash_on_hand = 500 + 200 = 700 diff --git a/tests/test_consumption_grid.py b/tests/test_consumption_grid.py deleted file mode 100644 index 40b7caa..0000000 --- a/tests/test_consumption_grid.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Consumption-grid invariants required by the borrowing constraint. - -The borrowing constraint in `agent.assets_and_income.borrowing_constraint` -compares the lowest consumption action against -`max(cash_on_hand, consumption_floor * equivalence_scale)`. For subjects -with cash below the floor (HRS bottom-coded `assets=-$1{,}000{,}000$`, -moderate-negative-asset retirees etc.) this RHS collapses to exactly -`consumption_floor` for singles. The constraint is feasible iff the -lowest consumption gridpoint is `<= consumption_floor`. - -`jnp.geomspace(start, stop, num=n)` returns `start * r^i` with -`r = (stop/start)^(1/(n-1))`; mathematically `r^0 == 1` so the first -point equals `start`, but XLA backends can drift by sub-ULP for some -`(start, stop, n)` combinations (observed: CUDA, n=70, drift +2.27e-13). -A positive drift above `consumption_floor` flips the kink-boundary `<=` -and rejects every action for those subjects. - -`_compute_consumption_points` therefore pins the first point back to -`consumption_floor` after `geomspace`. Test that invariant directly. -""" - -import jax.numpy as jnp -import pytest - -from aca_model.consumption_grid import _compute_consumption_points - - -@pytest.mark.parametrize("n_points", [5, 16, 64, 70, 100]) -def test_compute_consumption_points_first_equals_floor_exactly(n_points: int) -> None: - """The first gridpoint equals `consumption_floor` exactly under any `n_points`.""" - consumption_floor = 1597.0921419521899 # production value - pts = _compute_consumption_points( - consumption_floor=consumption_floor, - max_consumption=300_000.0, - n_points=n_points, - ) - assert float(pts[0]) == consumption_floor - - -def test_compute_consumption_points_strictly_increasing() -> None: - """Gridpoints are strictly increasing — no kink-pinning ties.""" - pts = _compute_consumption_points( - consumption_floor=1597.0921419521899, - max_consumption=300_000.0, - n_points=70, - ) - diffs = jnp.diff(pts) - assert bool((diffs > 0).all()) diff --git a/tests/test_consumption_unequiv_grid.py b/tests/test_consumption_unequiv_grid.py new file mode 100644 index 0000000..f06d571 --- /dev/null +++ b/tests/test_consumption_unequiv_grid.py @@ -0,0 +1,48 @@ +"""Consumption-grid invariants required by the borrowing constraint. + +The borrowing constraint in `agent.assets_and_income.borrowing_constraint` +compares the lowest consumption_unequiv action against +`max(cash_on_hand, consumption_unequiv_floor * equivalence_scale)`. For subjects +with cash below the floor (HRS bottom-coded `assets=-$1{,}000{,}000$`, +moderate-negative-asset retirees etc.) this RHS collapses to exactly +`consumption_unequiv_floor` for singles. The constraint is feasible iff the +lowest consumption_unequiv gridpoint is `<= consumption_unequiv_floor`. + +`jnp.geomspace(start, stop, num=n)` returns `start * r^i` with +`r = (stop/start)^(1/(n-1))`; mathematically `r^0 == 1` so the first +point equals `start`, but XLA backends can drift by sub-ULP for some +`(start, stop, n)` combinations (observed: CUDA, n=70, drift +2.27e-13). +A positive drift above `consumption_unequiv_floor` flips the kink-boundary `<=` +and rejects every action for those subjects. + +`_compute_consumption_unequiv_points` therefore pins the first point back to +`consumption_unequiv_floor` after `geomspace`. Test that invariant directly. +""" + +import jax.numpy as jnp +import pytest + +from aca_model.consumption_unequiv_grid import _compute_consumption_unequiv_points + + +@pytest.mark.parametrize("n_points", [5, 16, 64, 70, 100]) +def test_compute_consumption_unequiv_points_first_equals_floor_exactly(n_points: int) -> None: + """The first gridpoint equals `consumption_unequiv_floor` exactly under any `n_points`.""" + consumption_unequiv_floor = 1597.0921419521899 # production value + pts = _compute_consumption_unequiv_points( + consumption_unequiv_floor=consumption_unequiv_floor, + max_consumption_unequiv=300_000.0, + n_points=n_points, + ) + assert float(pts[0]) == consumption_unequiv_floor + + +def test_compute_consumption_unequiv_points_strictly_increasing() -> None: + """Gridpoints are strictly increasing — no kink-pinning ties.""" + pts = _compute_consumption_unequiv_points( + consumption_unequiv_floor=1597.0921419521899, + max_consumption_unequiv=300_000.0, + n_points=70, + ) + diffs = jnp.diff(pts) + assert bool((diffs > 0).all()) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index 6078547..1994b4d 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -1,13 +1,14 @@ """Subjects at extreme negative assets must clear `validate_initial_conditions`. The transfer system (`agent.assets_and_income.transfers`) tops cash-on-hand -to `consumption_floor * equivalence_scale` at any starting state, so the -lowest consumption-grid point is always a feasible action regardless of +to `consumption_unequiv_floor * equivalence_scale` at any starting state, so the +lowest consumption_unequiv-grid point is always a feasible action regardless of how negative starting assets are. The model's constraints — and pylcm's `validate_initial_conditions` pass — must reflect this. """ import jax.numpy as jnp +import pytest from lcm import DiscreteGrid from lcm.simulation.initial_conditions import validate_initial_conditions @@ -20,37 +21,37 @@ ) -def test_borrowing_constraint_admits_consumption_at_floor() -> None: - """`consumption == consumption_floor` at the kink is feasible by equality.""" - consumption_floor = 5_000.0 +def test_borrowing_constraint_admits_consumption_unequiv_at_floor() -> None: + """`consumption_unequiv == consumption_unequiv_floor` at the kink is feasible by equality.""" + consumption_unequiv_floor = 5_000.0 equivalence_scale = jnp.asarray(1.0) cash_on_hand = jnp.asarray(-50_000.0) # below floor — RHS = floor admitted = bool( borrowing_constraint( - consumption=jnp.asarray(consumption_floor), + consumption_unequiv=jnp.asarray(consumption_unequiv_floor), cash_on_hand=cash_on_hand, - consumption_floor=consumption_floor, + consumption_unequiv_floor=consumption_unequiv_floor, equivalence_scale=equivalence_scale, ) ) assert admitted -def test_borrowing_constraint_rejects_consumption_above_post_transfer_resources() -> ( +def test_borrowing_constraint_rejects_consumption_unequiv_above_post_transfer_resources() -> ( None ): - """`consumption > max(cash_on_hand, floor)` is rejected.""" - consumption_floor = 5_000.0 + """`consumption_unequiv > max(cash_on_hand, floor)` is rejected.""" + consumption_unequiv_floor = 5_000.0 equivalence_scale = jnp.asarray(1.0) cash_on_hand = jnp.asarray(-50_000.0) - consumption = jnp.asarray(consumption_floor + 1.0) + consumption_unequiv = jnp.asarray(consumption_unequiv_floor + 1.0) admitted = bool( borrowing_constraint( - consumption=consumption, + consumption_unequiv=consumption_unequiv, cash_on_hand=cash_on_hand, - consumption_floor=consumption_floor, + consumption_unequiv_floor=consumption_unequiv_floor, equivalence_scale=equivalence_scale, ) ) @@ -62,31 +63,36 @@ def test_borrowing_constraint_admits_floor_at_million_dollar_negative_cash() -> Reproduces the production failure mode at `assets=-$1{,}000{,}000$` (HRS bottom-code): the algebraically equivalent `cash_on_hand + transfers` - form rounds to `floor - 5.7e-11` at fp64, flipping `consumption <= ...` - for the lowest consumption gridpoint. The `max(cash_on_hand, floor)` + form rounds to `floor - 5.7e-11` at fp64, flipping `consumption_unequiv <= ...` + for the lowest consumption_unequiv gridpoint. The `max(cash_on_hand, floor)` form returns `floor` exactly. """ - consumption_floor = 1597.0921419521899 # production value + consumption_unequiv_floor = 1597.0921419521899 # production value equivalence_scale = jnp.asarray(1.0) cash_on_hand = jnp.asarray(-1_000_000.0) - consumption = jnp.asarray(consumption_floor) # lowest grid point + consumption_unequiv = jnp.asarray(consumption_unequiv_floor) # lowest grid point admitted = bool( borrowing_constraint( - consumption=consumption, + consumption_unequiv=consumption_unequiv, cash_on_hand=cash_on_hand, - consumption_floor=consumption_floor, + consumption_unequiv_floor=consumption_unequiv_floor, equivalence_scale=equivalence_scale, ) ) assert admitted +@pytest.mark.skip( + reason="benchmark_params.pkl predates the consumption_unequiv rename " + "(carries `average_consumption` instead of `average_consumption_unequiv`). " + "Unskip once the benchmark snapshot is regenerated." +) def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. HRS bottom-codes very-large-negative net wealth at exactly $-1{,}000{,}000$. - Such subjects should remain in the simulated population: the consumption + Such subjects should remain in the simulated population: the consumption_unequiv floor / transfer system absorbs them, with `c = c_floor` always feasible. """ n_subjects = 1 diff --git a/tests/test_model_components.py b/tests/test_model_components.py index 9876ac8..3153195 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -78,7 +78,7 @@ def test_leisure_bad_health() -> None: def test_utility_positive_leisure() -> None: result = preferences.utility( - consumption=jnp.array(10000.0), + consumption_unequiv=jnp.array(10000.0), leisure=jnp.array(3000.0), pref_type=jnp.array(0), consumption_weight=jnp.array([0.4, 0.4, 0.4]), @@ -91,7 +91,7 @@ def test_utility_positive_leisure() -> None: def test_utility_log_case() -> None: result = preferences.utility( - consumption=jnp.array(10000.0), + consumption_unequiv=jnp.array(10000.0), leisure=jnp.array(3000.0), pref_type=jnp.array(0), consumption_weight=jnp.array([0.4, 0.4, 0.4]), diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index fca2ef6..0cf9655 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -10,7 +10,7 @@ from aca_model.aca.regimes import build_all_regimes as _build_aca_regimes from aca_model.baseline.regimes import REGIME_SPECS, RegimeId from aca_model.baseline.regimes import build_regime as _build_regime -from aca_model.baseline.regimes._common import MAX_CONSUMPTION, build_grids +from aca_model.baseline.regimes._common import MAX_CONSUMPTION_UNEQUIV, build_grids from aca_model.config import GRID_CONFIG @@ -70,7 +70,7 @@ def test_forcedout_regimes_no_labor_supply(name: str) -> None: regime = build_regime(name) assert "labor_supply" not in regime.actions assert "log_ft_wage_res" not in regime.states - assert "consumption" in regime.actions + assert "consumption_unequiv" in regime.actions @pytest.mark.parametrize( @@ -270,6 +270,6 @@ def test_baseline_model_creates() -> None: assert len(model.regimes) == 19 -def test_max_consumption_attached_from_canonical_constant() -> None: +def test_max_consumption_unequiv_attached_from_canonical_constant() -> None: model = make_baseline_model(n_subjects=1) - assert model.max_consumption == MAX_CONSUMPTION + assert model.max_consumption_unequiv == MAX_CONSUMPTION_UNEQUIV diff --git a/tests/test_pension_integration.py b/tests/test_pension_integration.py index 287cab6..9a9176d 100644 --- a/tests/test_pension_integration.py +++ b/tests/test_pension_integration.py @@ -95,7 +95,7 @@ def test_next_assets_includes_pension_adjustment() -> None: cash_on_hand=jnp.array(100_000.0), transfers=jnp.array(0.0), pension_assets_adjustment=jnp.array(5_000.0), - consumption=jnp.array(80_000.0), + consumption_unequiv=jnp.array(80_000.0), oop_costs=jnp.array(0.0), ) assert jnp.isclose(result, 25_000.0, atol=ATOL) diff --git a/tests/test_preferences.py b/tests/test_preferences.py index 4ff2266..2afe6b8 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -34,7 +34,7 @@ def test_utility_scale_factor_crra() -> None: result = preferences.utility_scale_factor( pref_type=jnp.array(0), - average_consumption=AVERAGE_CONSUMPTION, + average_consumption_unequiv=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, time_endowment=TIME_ENDOWMENT, @@ -50,7 +50,7 @@ def test_utility_scale_factor_crra() -> None: def test_utility_scale_factor_log() -> None: result = preferences.utility_scale_factor( pref_type=jnp.array(0), - average_consumption=AVERAGE_CONSUMPTION, + average_consumption_unequiv=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_1_BY_TYPE, time_endowment=TIME_ENDOWMENT, @@ -108,7 +108,7 @@ def test_scaled_bequest_weight_zero() -> None: def test_utility_log_regression() -> None: scale = preferences.utility_scale_factor( pref_type=jnp.array(0), - average_consumption=AVERAGE_CONSUMPTION, + average_consumption_unequiv=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_1_BY_TYPE, time_endowment=TIME_ENDOWMENT, @@ -119,7 +119,7 @@ def test_utility_log_regression() -> None: scale_reference_age=SCALE_REFERENCE_AGE, ) result = preferences.utility( - consumption=jnp.array(50000.0), + consumption_unequiv=jnp.array(50000.0), leisure=jnp.array(400.0), pref_type=jnp.array(0), consumption_weight=WEIGHT_BY_TYPE, @@ -133,7 +133,7 @@ def test_utility_log_regression() -> None: def test_utility_crra_regression() -> None: scale = preferences.utility_scale_factor( pref_type=jnp.array(0), - average_consumption=AVERAGE_CONSUMPTION, + average_consumption_unequiv=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, time_endowment=TIME_ENDOWMENT, @@ -144,7 +144,7 @@ def test_utility_crra_regression() -> None: scale_reference_age=SCALE_REFERENCE_AGE, ) result = preferences.utility( - consumption=jnp.array(50000.0), + consumption_unequiv=jnp.array(50000.0), leisure=jnp.array(400.0), pref_type=jnp.array(0), consumption_weight=WEIGHT_BY_TYPE, @@ -156,10 +156,10 @@ def test_utility_crra_regression() -> None: def test_utility_married_equivalence() -> None: - """Married with equiv-scaled consumption should equal single utility.""" + """Married with equiv-scaled consumption_unequiv should equal single utility.""" scale = preferences.utility_scale_factor( pref_type=jnp.array(0), - average_consumption=AVERAGE_CONSUMPTION, + average_consumption_unequiv=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, time_endowment=TIME_ENDOWMENT, @@ -170,7 +170,7 @@ def test_utility_married_equivalence() -> None: scale_reference_age=SCALE_REFERENCE_AGE, ) single = preferences.utility( - consumption=jnp.array(50000.0), + consumption_unequiv=jnp.array(50000.0), leisure=jnp.array(400.0), pref_type=jnp.array(0), consumption_weight=WEIGHT_BY_TYPE, @@ -179,7 +179,7 @@ def test_utility_married_equivalence() -> None: utility_scale_factor=scale, ) married = preferences.utility( - consumption=jnp.array(50000.0 * 2**0.7), + consumption_unequiv=jnp.array(50000.0 * 2**0.7), leisure=jnp.array(400.0), pref_type=jnp.array(0), consumption_weight=WEIGHT_BY_TYPE, @@ -196,7 +196,7 @@ def test_utility_married_equivalence() -> None: def test_bequest_log_regression() -> None: scale = preferences.utility_scale_factor( pref_type=jnp.array(0), - average_consumption=AVERAGE_CONSUMPTION, + average_consumption_unequiv=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_1_BY_TYPE, time_endowment=TIME_ENDOWMENT, @@ -229,7 +229,7 @@ def test_bequest_log_regression() -> None: def test_bequest_crra_regression() -> None: scale = preferences.utility_scale_factor( pref_type=jnp.array(0), - average_consumption=AVERAGE_CONSUMPTION, + average_consumption_unequiv=AVERAGE_CONSUMPTION, consumption_weight=WEIGHT_BY_TYPE, coefficient_rra=RRA_5_BY_TYPE, time_endowment=TIME_ENDOWMENT, From c31b5dd81701d99a1d5efc64a959dee5ac26432b Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 10 May 2026 18:04:58 +0200 Subject: [PATCH 43/54] benchmark: regenerate snapshot, drop rename/shift workarounds MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The frozen `benchmark_params.pkl` is now produced by `aca-dev/scripts/regen_benchmark_params.py` against the current aca-data + aca-estimation + aca-model code. With a fresh snapshot: - `consumption_unequiv_floor` and `average_consumption_unequiv` keys match the post-rename model — no more old-key carryover. - `imp_*_next_period` shifted views come from aca-data's `_shift_one_period_forward` directly — no synthesis needed at load. - `imputed_pension_wealth_next_period` is no longer a fixed param (regimes resolve it via DAG) — no filter needed. - pref-type-indexed Series are pre-truncated to BenchmarkPrefType's two rows — no truncation needed at load. Drop from `benchmark.py`: - `_STALE_FIXED_KEYS` filter - `_SHIFTED_IMPUTATION_KEYS` + `_add_shifted_imputation_arrays` synthesis - `_shift_one_period_forward` + `_rename_his_level` helpers - `_truncate_pref_type_indexed` - The `DispatchStrategy` / `PARTITION_SCAN` / `PARTITION_VMAP` docstring references — that pylcm enum was never merged. Drop the unused `pandas` import. Unskip `test_extreme_negative_assets_subject_passes_validation` — the benchmark snapshot it loads now matches the model's expected keys. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../_benchmark_data/benchmark_params.pkl | Bin 54669 -> 65428 bytes src/aca_model/benchmark.py | 128 ++---------------- .../test_initial_conditions_extreme_assets.py | 6 - 3 files changed, 11 insertions(+), 123 deletions(-) diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index af9dbc24375d2b7c51f31793e6c4f92d30d70feb..0d4005e1262a18000419eeb8e54f6c54b6167992 100644 GIT binary patch delta 13486 zcmds83tUr2)+Y}V-tvA5@>Jfakf5l`Ahf7Ku%fj;D-9@tP*5&W78Fawi%R5W>Lm)5tVAwW%O$01Svj=h z-;C&7jb%L8hBCMU{Nya21 zJc+;G6#P<2`~g1<%#TQgFNS51Z$)RZ*>W~)Rt32}<XxjqQ%&u?8-VONO_?!pMDb zyICwb^q(FA*Y?bT*Afc3ThJ)DA}iFJ*o8_IAa9}qu?OEe*sx;k7`q1BMJ0ujG7TA? z-o_#`#$RNy&#}q!CwH^>=TNtqEi8CEBGgHahmiPol1=_iaKfodO}AESq&av^7UXH9 z-o#q;6036^dLsO#-cnx-vOm#^UPu9f$x!&~TiWOrOxvhsKizV&zePfvMnz0EGjb{! zlfIE+wJT${j^%7E%l%AAh-+v>Ht``nj+(w-3L6`Vu_ZL2ZhA|&&W4I<0e;o^rIJ{M zABIY*fm^2u?R5gfFQ-MgZ$O>?qHW+?8 zJ;YltWjc>`gBs!-8-~pAL#lz3XZS^7F-5%oc}z+fhUtxm%V*4>G^Ou7h@#xkIMa&q zrdu#!BSl5hE#v!J3>B5mb1*4G){tdDR0n9v0nga)_lfE7k|-eDtfCND+3JN2&ax#6 z$WjLz=xk!vDb{J2mJ4CV_;9#y=7#VVL|9Z~4%rISN;HwB+0|Tso#qjF+SFgbrVfY- z*!*meM$KNyCwP6uydQ|M~;?GvQ;8r$sKTMpeQ^x?QJo@1sKyai zqH2)~t5I!Zpn8x(HF+>w^~ScUwNs`a_3$KiAUA@8t<1>vQgo2U7kNsa&w^Pcc0Agi zjJ6zC;!_3}e}T=)fMGB(&dBYUVdAC&F|nl7EEC@|nCQbyqz{HfC`O`4S+Iyo8Yv}< zB(iduyx5o_4nhU?DWk~fvS}=M-OCIQZiFK`T#@H3s#n{nU z#7BMz31tBHByZ3FeA%xSL{b5Gm313d>0<#;#sIHzQ1EV5&oFgtqfcxoq$KLlvCSLS ztcs!|Oq`?IhvvYEFp>V4XV4!FlBhH&`6bP(E345gG&B%{v{(m&(RX2fZ9vFiOVkgx zqTE^CG7k=Mlm~a##0=yqb~rpir|nUYsCBb+L}{r_W~0e2)|!|sYdDip;hg~hHhe?> zV9UoTI}41Ehw8fYK{P|zq&}bJLhl8)Q}ug*j-km6hSDg8EQkf!4H5c`2|2c5BbzHhBp^y(jVI)&M5Pf* zQGrmR8#b*~*r*yDi%Uuh#ZpPZ^D3!6fX`_Eki5T!#bU4KuiN%{^lCIj)NL-`BF0qJ z__mTy^4IO=0S=rx$b2J(4y@)k9YU}4T zhck!Hp@}#@7tf$OZnyJc>~<<|%RbA86I(;sukgvd)-qnzpHV+O4CeADk5*J|-YPRP zy&LAVUwnj8_b{Zsilz4V@5C>aL>GP-W@ygx$^7k^Y`u-aj-AzXCe^v$E|6lGnbdHj z#OWyP34^tsjWRzXi~|Fvhv;&z@q>>r%wNL5ju&yZwXLHTKVtUQ(n(i!KT8&puH@Id zPIBn1tFHSXDh;rstk*<(JeY1Vk#&uyTMW6@t5gjmJ1Uq0%%n1|I*u#|fN?bThMssf zoh4`6t$WMgLTWtmEm80Tr0Glma~(Wi<-SQg#5LM2AywHRYViLr`c|<~LI0Y=sBv)( ztotf6eQUyf!j{p-W7PPJnjJN+Vkb2lVUKI6UA)Hq2Rq#e2deG0wcY5tKsZjtRUKDO zRM#!Ghgt3N1i~5D(q6u{61NjBs8&aeKz9*~a77Qo4c}CD@Xy^ou!;25%Dd%|d-Asj zZ<>?Wl!X6#;{ExDWTIh9sM%kHFRg^p;RhrB3-xF0 zlHiGb9&pwnkD=-CSN2B3GgV`61m_*@*J#7?zs6QKmzvu`DKT|>wkfCEoZ*>0_ouDH z-cZZ)DYDuPwNvs-{YM9W&~B*hmlN?tIHnA1MKUxBkK21X=)aL5S0C|W(UI<`{;-l6 z>7p~%D5JR5e2|MeM_KG_7%9T@r|WQd`1F}I`0@3bHFSJi+0TfD@Q}begS z-$Dc$77YccMrW`Mbc>1jc01i-$hAJct%M((Ud|Mtii&S_99a;+RD9cjXVb;EXK+m> zeO_cEsuZzm(DyCs0yX;fY@t*^nETFvRRe$Q?02`nZ5y^~scLtXIy>lkwqMNkuaxAk zy!TJ*m+Ja}Kh1scO!}~wswR(Kj&yYYwko&j?o5JSs=AQf(JC>udkhjY`rPlsBUvS; zHY`G7rk{^wqhrjH^D)|O=C^ITv(VBsIMLvmg#(S&|LSZP9ofEP5TFil30ld2+vaSQ zVssJ@!Hv3gjgR7teU(5~{OQF2*8!TfdvDyM)~qS{ zhUAPa$@i=|`3$w!$Dm*=tHGjBs~dQV0%gIHGW=aM9PxfQ9J$I5eH9uB^P6JfjUKm1 zqaOeo>=jB$kyy1LUYNW5D4^}sagjLzuqir!-8Yi+rVL}4l79&M0 z5xZJLj6f|^BG?@j2oH2Knu72_?~Id5Z1hgQ=?38%z3H-zHr)nZ8Txz#Hbw{Xg{T+$ zoDYh6(zA(2!w~%osw)h>^f-l`!Voscys$(PZu3$S0{k@0l?v@)@a0&fO&{XXimnZ( zLp>|I=Iyze>Qv&B-4S;y_3>$bF$0WcgMs015)9f9?5|FRKwR{ENr_UE z2{wHH+^`#;T~-C6Tz&gA{Y~~rs^~rB${t7VBU#ZiKFPIO{LOC)`LP=e8s=L`+g@c z5$rl2#Ts{MJh)pgKe9bD5q#3Ix%KJBo=E(yf;rsi-`k^np-4V*hO z2|Sm2isvX!0x|9$S8g{af#ugWCmdOk1msZDD0jS~p^`~EQgy2T;x9ZiE0ua@B zDdlyg5d34??ajTf3P7-W(*1L5gy2=%j{;{O7l6VWvA1M9gy3@P;V$kC0hlM7?Qrmb z5Ui~#?ArT-oe)U6!@srd6oSC$Imtc=Lh#LBg#O3h6M`3>c%n7?X(1^7zx8JoSA<|` z{m4-Hav^v?xVUy-6@n(tp7oOu;k&tCZrh03M@&9?WO0uW{5<}tl~*naLChVO+gtmD zK*8UeSbb3lQk;FmmxU$+d(f`>=DZM8&GUHQ|D_qpfZXt5{5g{Rhp~U(cwamb91fZDTJK!}NZa`3C!@P3 zf-98^0;1~$p!(}=3l~jI0rjJz8#nj~fZ&}?AtT>S0S~@vi2Qq768K`r1ee3vlfe9r zIX<^TlYm#=XCvRdISIV;%u_GPTHO=D-12*Ad9{;4%8}XYLtGL8JTZGToE2dYA4c*5 z2L4p{0{wyXPJ8H<8|d){S{vw0h?fYpFW_I_W*S6r~#*1fGt-EjXic}d8_0D}A;{~;XSdoy?Lox9DQd*?m% zQ_g}-9E({NOFuaNxy21G-!XZTOp%-_OB|b-p;V=3<)me1AiFLXM;wio)I18J_L=E^ zx;9YfsAs6bKOz!nE`(RgP?2VwNmJiBS_L%aSMN$P+VUQ*v;hA}w>ATB;|Z z5%6Rj>NLup;ASm9DVHT+N`7*={32R2!5Qtf@}h2`->rhFIt{<;REqi)Ewpyl4}Z`& z{3#jEm)FY`^1tOO%*^tvzLCX26F|XqJrMJ3kV@bo<#+-?0>1cFhFgJ}X&9Q^V8We!c|AIw6RHaq_UemF+|mfr+@ z6bpS!v*=r#!Yu2O6V!64-!eIcs%*w+Vw@d}73pr9#iFHZj+W+-OPRUTTs+u6Umb__ zD}W1e1KSIhAUh9Fw4u8fwGu!1qWbP0R2hES*8Fg1^p{7PcDS=9qel=;+t8@~HXnfl z<}ux9Gqa)uOxFcJKv!}R;EOCror0%LC{oT-r^o^9j{8(|vCAdnUJUV*kApni*JxHh zL`i;qX!yogJ7_W~YqFtuGJyaNBAp5-AvYsWgaEzo6@?Bq_@WYD zFT}n7Ij`7aZ00zbQjG>a8H;*6Sj1g~abicK{ei-~Z2So4;fsuskg;G5?jXNQ+c$h) z$`1uoqHpe9TGtmO?Q_XqJu(!y%)BA|(IFI=`ILQUwImdzph*EU<3hpIuL9{E)uA9{ zV1in;FccK5M0ZmgLct{A@!PqF@%>xPM<>0*fcxPQ=d+x`Kxp2MSrhgQ3j=SIzW-8PII7pRN$Hv!%g8|WNxF`CD zg9qjDqt@7pKwd-UDF>_3;h;PqcKfw~A^?N7t-UfW9Gv$XHl%*M2-KK$b>Hw&IEb&g zY4hQH5y<}~Y}&-V;lO7Puc&gP2pmlmpLzTDa4_|7%;ptUB5>+y#+g)}2vmGM=n*_F z0waf5PG1`+0^$!Uqe?D_Kvl`{&;Gqh1javlSZoz9&I3`y*T!EsECQw0eSh4aCI%|H zl;?I=1eVGQ$0p1XgON6Uwy$;;gIP0juKlz|47Mh2bNez}3{Lk;=`(1L7z8-ZtXLB- z26uLzp8xo)7|d0aB@~PkgZL#|R_1d0feA^Q%s*6$LFL-7WpP3Mz`o1#?7x{V28BZ= z+r2(EuOIMAO^$y(R}3m1nw5*z^#kSAH_{(qgu>c8R+Ikd2MD8GQp5nx+v__asz2DO z#+*fqfye#X`Ioo!2L*o&-{H|6GjMCYuv_43ZENSX6`=#htL0(E^!R@2qQZGAFu zcnzcJ8X9Fz@UT0?!Z|~ud?#1DWeI6`6^-&0TSH^!HJ|xzr>Gi~*y$6(2n&^N-6z#{+)++&15HUTFVON>5Uh&~1IArS^ZLPo1etZ5y@3=_o=Lc0mXnSprfAz1quosgCK z4c3STtRmc_e2J50M!+(IB)y(r6686OZ+DY$rcVsu01lcS;pTdobSmH#ax;#ySJBo8 z(JKbUqw5hq4)^e&ScBq-kk|^IG}M}bw&b9!!SPy~2p#DHH8j{|8kP&@qC0~x4Y6pwBVD72G_hR0S+L}I1V#SH2xCxF_tRP|kIbRy_xLn}QX4IpF;=8wmk za`P0DJ_Yn5H{*rtjXQ9FE#{-_nSA6F=kzk^ar=Ai)8jNE<%{|l-&R4$81%S|=uyzT z9?`uqAup5if|J{)ysC~--cFYC&bFpJRcatb0t6o=*mXg`o2RKxg%nMhqx^JN9i-<{ zD_b)QW1H|0g<7=9wlp7_W1*={n{I-ZfJLiEYqSne2*I?F3i)gLOnl1(y$}{XQET)v zmj=E}{_@H4_W8@3MCun_Ytgqc2z>?~MV^@)7d_d|2S5vN6D0qPmSO1{3=VF<~R7qJ!D+Ww8;^3L6^nr~CE6 zhkk>rEs3No3U6{rN#q7w*sH$W;?tP(X)bJ8MbkN4R5jSgVFo6p&5mZ_2bhm9sPw)E zP6r#lKdIS@Ybm;vi@vBn*Q*qRVh=Qz7G@CaIpl)4nq6OP<|aLx7wQU~bUX?N5Q5Pp zb5=4CLZ%6%h=CwIgcP6+-=8A99sS{4+gT*Vs9g#nYS{0$bJ6+z+QO0pPjb;;Ukm99 zuEu_^lAE^$4{Q{@xzw&>#nWZexXw9%I6oYEAxY7}jjrhGetW?QBC!J2lACc>c#^B( z|0ty!T?CE$xqvC?to0VhCNnPRyx8^Ibdpx*vI0pS2POOU)YX4{Q^QfuWiHs$4j=h| z6tGDL3(4)=(L#T1QKZt(+_VHSs^~PU4t>L6N;<>vR_xQOi`M;2RNz4j#6(Fqkb#&k z=~Srh$Yizv-p?ZdDcvHy2xWs0CXg0$Afk$;=HbH>Rk8{n_ov9SVDqK;{%@bR?CsJz>-VS$nv0vkk5j3W3`;}SWRaw;^d85 z4xkHZN%nIW?5M1^hpt*!MacpDI|v0VJU8vgGK<~R#~DG;owOpB*UlkiV0YXqf*$x9 zOTixa278jrY#*T&uRo^pKQM$ zXxD;d*QpS{#WhP|yS>8HTIhJM(CNU0SD=-r{5yvq=;DEqAE=65a?){NUl)_WQyA zc}&!M1?qNAzc&`oywba1adsQ>gVuAT#ZmmYckY$9oGqBz7bYcX_i$adP%V2r|Mbo~ zx@_U&qoL&2#vF;|2XZU?4FP%aWFxotM}g;gjTyd1a=shIr1(Abi8edM7Ya?hrfD|A zH~mtFN$~To6p{Sc^6DaTTXl62li=s-=WQv$x4-Z}m*5+Qx02wO;k1lMf-9&dRpbT+ zVj{u+4+Als;OCLOO>@KoK-wE|nf z5U~#Zp^sS0N;8w@Jb2Wht96;%w0EAo7uojJs!m{5Bva9Rkuzf9tznQ^t%f?KZ50hW zi(_=4%jU2A*%>1iBswT|Wxi=c)O8Pfb{FJ%J)pBL2}tp@1@0@lDidpC#El@GHagwn zFrkg2o5BBpZ#ZGElrkp3hYrm7hGFj>MSzkH-@NfsK#S|$(CBkEWD)GH5BUeoZc0!? z*dxbNQpx`&?m-?ZsdaPHo#!yBR;h})S;>mzbR|h*(5BmwD9>*sD!FBEQ_Rhonk7r@ ztJc1_K$~y5d%5WBQ+upnZs3_~l69Ku*@?k_Nd)&SpIYC(DN)NmMZp_nkviSC1rAig}U#q#=s{jB1 diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 685b5cd..d215b8b 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -7,20 +7,15 @@ The benchmark substitutes a 2-type `BenchmarkPrefType` for the production 3-type `PrefType`, which saves ~33% of the compile + -execution volume over all 18 regimes. By default the pref_type axis -is handled via pylcm's fused-vmap dispatch (no `DispatchStrategy` -imported — this module stays compatible with pylcm versions that -pre-date the enum). Callers that want partition-lifted dispatch -(`PARTITION_SCAN` / `PARTITION_VMAP`) construct the grid themselves -and pass it via `pref_type_grid`. +execution volume over all 18 regimes. -Parameters (`fixed_params` + `params`) are a committed stub fixture -packaged alongside the module at -`src/aca_model/_benchmark_data/benchmark_params.pkl` — aggregate-level -values (policy schedules, transition probabilities, fitted -coefficients) with no runtime dependency on aca-data or any data-prep -package. The pref-type-indexed entries in `params` are truncated to -two rows on load to match `BenchmarkPrefType`. +Parameters (`fixed_params` + `params`) are a committed snapshot at +`src/aca_model/_benchmark_data/benchmark_params.pkl`, generated by +`scripts/regen_benchmark_params.py` against the current aca-data + +aca-estimation + aca-model code. Pref-type-indexed Series in `params` +are pre-truncated to two rows so the snapshot loads with no further +reshaping; regenerate after any change that affects `fixed_params` +shape (regime DAGs, aca-data outputs, key renames). Initial conditions are drawn randomly per call — assets/aime/wage_res from their grid ranges, discrete states from their categories, regimes @@ -34,7 +29,6 @@ import cloudpickle import jax.numpy as jnp import numpy as np -import pandas as pd from jax import Array from lcm import DiscreteGrid, Model @@ -88,12 +82,7 @@ def create_benchmark_model( n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. When set, the first matching `simulate(...)` call AOT-compiles all simulate functions for that batch shape. - pref_type_grid: Pref-type grid. Pass `DiscreteGrid(BenchmarkPrefType)` - for plain fused-vmap, or - `DiscreteGrid(BenchmarkPrefType, dispatch=DispatchStrategy.PARTITION_SCAN)` - (or `PARTITION_VMAP`) for the partition-lifted kernel — the - recommended production setting for aca-model at scale, but only - supported on pylcm versions that expose `DispatchStrategy`. + pref_type_grid: Pref-type grid; pass `DiscreteGrid(BenchmarkPrefType)`. """ fixed_params, _ = get_benchmark_params(model=None) return create_model( @@ -111,10 +100,6 @@ def get_benchmark_params( ) -> tuple[dict[str, Any], dict[str, Any]]: """Load the frozen `(fixed_params, params)` snapshot. - Pref-type-indexed `pd.Series` in `params` are truncated to - `_N_BENCHMARK_PREF_TYPES` rows so they line up with - `BenchmarkPrefType`'s categories. - When `model` is provided, consumption_unequiv gridpoints are injected into `params` for each regime that declares `consumption_unequiv` as an `IrregSpacedGrid` with runtime-supplied points. The lower bound is @@ -123,104 +108,13 @@ def get_benchmark_params( """ with _PARAMS_FILE.open("rb") as fh: data = cloudpickle.load(fh) - fixed_params = { - k: v for k, v in data["fixed_params"].items() if k not in _STALE_FIXED_KEYS - } - fixed_params = _add_shifted_imputation_arrays(fixed_params) - params = _truncate_pref_type_indexed(data["params"]) + fixed_params = data["fixed_params"] + params = data["params"] if model is not None: params = inject_consumption_unequiv_points(params=params, model=model) return fixed_params, params -# Keys that the older aca-estimation `_assemble_params.py` wrote into -# `fixed_params` but that the current regime now resolves as a DAG -# function. Drop them on load so pylcm's `_resolve_fixed_params` does -# not reject the snapshot. Regenerating `benchmark_params.pkl` would -# also remove these — the filter is a no-op when the snapshot is fresh. -_STALE_FIXED_KEYS: frozenset[str] = frozenset({"imputed_pension_wealth_next_period"}) - - -# Source → derived key mapping for the 1-period-shifted views of the -# imputation arrays. The current pension correction (`imputed_pension_ -# wealth_next_period`) consumes these. The frozen `benchmark_params.pkl` -# predates aca-data's `_shift_one_period_forward` change, so synthesise -# the shifted views on load. The transformation is deterministic: row -# `period` carries the original at row `period + 1`; the last row holds -# flat. A regenerated snapshot can drop this synthesis (the filter is a -# no-op when the keys already exist). -_SHIFTED_IMPUTATION_KEYS: tuple[str, ...] = ( - "imp_intercept", - "imp_pia_coeff", - "imp_pia_kink_0_coeff", - "imp_pia_kink_1_coeff", - "imp_kink_0", - "imp_kink_1", - "imp_fraction_receiving", - "epdv_constant_pension", -) - - -def _add_shifted_imputation_arrays(fixed_params: dict[str, Any]) -> dict[str, Any]: - """Synthesise `_next_period` views from the source arrays.""" - out = dict(fixed_params) - for key in _SHIFTED_IMPUTATION_KEYS: - next_period_key = f"{key}_next_period" - if next_period_key in out or key not in out: - continue - out[next_period_key] = _shift_one_period_forward(out[key]) - return out - - -def _shift_one_period_forward(sr: pd.Series) -> pd.Series: - """Shift age-axis values forward one position (last row held flat). - - For (age, his)-indexed inputs, also rename the `his` level to - `target_his` so the resulting Series matches the level naming the - consuming `imputed_pension_wealth_next_period` function expects. - """ - if isinstance(sr.index, pd.MultiIndex) and sr.index.names[0] == "age": - n_periods = sr.index.levshape[0] - n_other = int( - np.prod([sr.index.levshape[i] for i in range(1, sr.index.nlevels)]) - ) - values = sr.to_numpy().reshape(n_periods, n_other) - shifted = np.concatenate([values[1:], values[-1:]], axis=0) - new_index = sr.index.rename( - [_rename_his_level(name) for name in sr.index.names] - ) - return pd.Series(shifted.ravel(), index=new_index) - if sr.index.name == "age": - values = sr.to_numpy() - shifted = np.concatenate([values[1:], values[-1:]]) - return pd.Series(shifted, index=sr.index) - msg = f"Unexpected index for _shift_one_period_forward: {sr.index!r}" - raise ValueError(msg) - - -def _rename_his_level(name: str) -> str: - """Rename `his` to `target_his`, leave others alone.""" - return "target_his" if name == "his" else name - - -def _truncate_pref_type_indexed(params: dict[str, Any]) -> dict[str, Any]: - """Return a copy of `params` with pref_type-indexed Series cut to 2 rows. - - A Series is pref_type-indexed when its index labels start with - `"type_"`. The first `_N_BENCHMARK_PREF_TYPES` rows are kept so the - Series aligns with `BenchmarkPrefType.type_0`, `type_1`, ... - """ - out: dict[str, Any] = {} - for key, value in params.items(): - if isinstance(value, pd.Series) and all( - str(label).startswith("type_") for label in value.index - ): - out[key] = value.iloc[:_N_BENCHMARK_PREF_TYPES] - else: - out[key] = value - return out - - def get_benchmark_initial_conditions( *, model: Model, n_subjects: int, seed: int ) -> dict[str, Array]: diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index 1994b4d..1238171 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -8,7 +8,6 @@ """ import jax.numpy as jnp -import pytest from lcm import DiscreteGrid from lcm.simulation.initial_conditions import validate_initial_conditions @@ -83,11 +82,6 @@ def test_borrowing_constraint_admits_floor_at_million_dollar_negative_cash() -> assert admitted -@pytest.mark.skip( - reason="benchmark_params.pkl predates the consumption_unequiv rename " - "(carries `average_consumption` instead of `average_consumption_unequiv`). " - "Unskip once the benchmark snapshot is regenerated." -) def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. From 99badcdfd49eaef8ed03bbd1f87b6841f8725895 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 10 May 2026 19:59:11 +0200 Subject: [PATCH 44/54] Decompose consumption floor into equiv-param + unequiv-DAG; fix prose MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Major refactor splitting the consumption floor into a per-equivalent estimation parameter (`consumption_equiv_floor`, ~$1.6k/yr) and a household-$ DAG-derived value (`consumption_unequiv_floor = consumption_equiv_floor * equivalence_scale`). Same for consumption: `consumption_unequiv` is the action variable, `consumption_equiv = consumption_unequiv / equivalence_scale` is a new DAG function fed into `utility`. `transfers` and `borrowing_constraint` now consume `consumption_unequiv_floor` (DAG output) directly — no inline multiplication. `utility` consumes `consumption_equiv` directly — no inline division. `MAX_CONSUMPTION_UNEQUIV` is no longer attached as a dynamic Model attribute; `inject_consumption_unequiv_points` imports it as a module constant. The runtime grid's first two points are pinned exactly to the singles' floor and the married floor (`* 2 ** exponent`), with geomspace from there to `MAX_CONSUMPTION_UNEQUIV` — so both household-floor levels land on a feasible action regardless of sub-ULP drift in `jnp.geomspace`. Pref-type-indexed Series renamed to plural form (`consumption_weights`, `coefficients_rra`); the per-cell scalars (after `[pref_type]` indexing) take the singular form. Drops the `alpha` / `gamma` aliases. `scaled_bequest_weight` continues to consume the per-cell scalars (pylcm broadcasts pref-type-indexed Series to per-cell scalars before consumption). Also tighten `_build_per_target_*` annotations to `dict[RegimeName, ...]`, drop stale `consumption_unequiv` prose in docstrings/comments where bare "consumption" was meant, and update all callers + tests for the new signatures. The benchmark snapshot is regenerated against the new keys/structure. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../_benchmark_data/benchmark_params.pkl | Bin 65428 -> 65428 bytes src/aca_model/aca/health_insurance.py | 2 +- src/aca_model/aca/model.py | 5 +- src/aca_model/agent/assets_and_income.py | 54 ++++++---- src/aca_model/agent/preferences.py | 100 +++++++++++------- src/aca_model/agent/utility.py | 27 +++-- src/aca_model/baseline/model.py | 7 +- src/aca_model/baseline/regimes/_common.py | 38 ++++--- src/aca_model/consumption_unequiv_grid.py | 89 ++++++++++------ src/aca_model/environment/social_security.py | 2 +- tests/test_budget_chain_integration.py | 4 +- tests/test_consumption_unequiv_grid.py | 73 ++++++++++--- .../test_initial_conditions_extreme_assets.py | 40 ++++--- tests/test_model_components.py | 22 ++-- tests/test_model_creation.py | 9 +- tests/test_preferences.py | 64 ++++++----- 16 files changed, 315 insertions(+), 221 deletions(-) diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index 0d4005e1262a18000419eeb8e54f6c54b6167992..29e65625274786cade19dce6e926e95d09500617 100644 GIT binary patch delta 190 zcmbR8pLxoE<^`Vx1(Neq)6z1NGgI?QisOrl5+_fr6P;}PO>}a-hSuhk&l{OUgp>30 zic50~N;32F;>%Mr(=$qnL8>;}ezRem%=ANJ@}2K$O5(_>QVUBn%i_~=^7D(PxI4`F z5*Jmv=l5jBAEJ|&e;1MDE6L1FjZe)>$uG|ZS_@RMZtYWVjrv2AcYRlq;!gtVPAx8p fFV4tJD@g?!;3$&$;Yu#2%VfVFDw}(Mu(JaI`BhUx delta 149 zcmbR8pLxoE<^`X5_>=Qf)6z1NGgI?QHuHYwWEK)i&d)0@%`GU&%+HH2Pt8ovn4GAf zJ^8^m_sO%rJ&@!p$;?fSPt8loFV6+)nd0uSZtYWVjrv2AUw_w}toK84^5XAmldHdL ts7N9kRGODsSejWDpO%xKUj)=N<4asr>7L({Gk>T|_WL2dx%US FloatND: - """Compute cash on hand available for consumption_unequiv and saving. + """Compute cash on hand available for consumption and saving. OOP health costs are NOT deducted here — they are deducted from next-period assets instead, matching the timing where HCC shocks are - integrated over (agent does not condition consumption_unequiv on OOP). + integrated over (agent does not condition consumption on OOP). """ return assets + after_tax_income + ssi_benefit - hic_premium +def consumption_unequiv_floor( + consumption_equiv_floor: float, + equivalence_scale: FloatND, +) -> FloatND: + """Per-household $-floor on consumption. + + Lifts the per-equivalent floor parameter to the household-$ level + by scaling with `equivalence_scale`. Singles keep + `consumption_equiv_floor`, married households face + `consumption_equiv_floor * 2 ** exponent` — the same two values + that get pinned exactly on the runtime consumption_unequiv grid + (see `aca_model.consumption_unequiv_grid`). + """ + return consumption_equiv_floor * equivalence_scale + + def transfers( cash_on_hand: FloatND, - consumption_unequiv_floor: float, - equivalence_scale: FloatND, + consumption_unequiv_floor: FloatND, ) -> FloatND: - """Government transfers to enforce consumption_unequiv floor. + """Government transfers to enforce the consumption floor. - tr = max{0, C_min * equivalence_scale - cash_on_hand} + tr = max{0, consumption_unequiv_floor - cash_on_hand} """ - floor = consumption_unequiv_floor * equivalence_scale - return jnp.maximum(0.0, floor - cash_on_hand) + return jnp.maximum(0.0, consumption_unequiv_floor - cash_on_hand) def next_assets( @@ -58,7 +72,7 @@ def next_assets( """Compute beginning-of-next-period assets for non-terminal targets. OOP health costs are deducted here (not from cash_on_hand) so that the - consumption_unequiv choice does not condition on the HCC shock realization. + consumption choice does not condition on the HCC shock realization. """ return ( cash_on_hand + transfers + pension_assets_adjustment - consumption_unequiv - oop_costs @@ -85,18 +99,17 @@ def next_assets_terminal( def borrowing_constraint( consumption_unequiv: ContinuousAction, cash_on_hand: FloatND, - consumption_unequiv_floor: float, - equivalence_scale: FloatND, + consumption_unequiv_floor: FloatND, ) -> BoolND: """Consumption cannot exceed post-transfer resources. - Post-transfer resources are `max(cash_on_hand, consumption_unequiv_floor * - equivalence_scale)`: the transfer system tops `cash_on_hand` to the - floor when below, otherwise resources are unchanged. The algebraic - identity is `cash_on_hand + transfers == max(cash_on_hand, floor)`; - the `max` form is preferred because the additive form rounds to - `floor + ε` (with `|ε| ~ ULP(|cash_on_hand|)`) at extreme cash, which - flips the kink-boundary comparison for HRS-bottom-coded subjects at + Post-transfer resources are `max(cash_on_hand, consumption_unequiv_floor)`: + the transfer system tops `cash_on_hand` to the floor when below, + otherwise resources are unchanged. The algebraic identity is + `cash_on_hand + transfers == max(cash_on_hand, floor)`; the `max` + form is preferred because the additive form rounds to `floor + ε` + (with `|ε| ~ ULP(|cash_on_hand|)`) at extreme cash, which flips + the kink-boundary comparison for HRS-bottom-coded subjects at `assets=-$1{,}000{,}000`. The `max` form returns `floor` exactly. `pension_assets_adjustment` is excluded from the constraint: it can @@ -104,7 +117,6 @@ def borrowing_constraint( wealth at a cross-HIS transition), and including it here can leave no feasible action at low-asset / mid-AIME corners. The correction enters `next_assets` instead — a post-decision shift that does not - gate the current consumption_unequiv choice. + gate the current consumption choice. """ - floor = consumption_unequiv_floor * equivalence_scale - return consumption_unequiv <= jnp.maximum(cash_on_hand, floor) + return consumption_unequiv <= jnp.maximum(cash_on_hand, consumption_unequiv_floor) diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 336cbd7..212802b 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -121,34 +121,47 @@ def leisure_retired( return time_endowment - health_loss -def utility( +def consumption_equiv( consumption_unequiv: ContinuousAction, + equivalence_scale: FloatND, +) -> FloatND: + """Per-equivalent consumption: total $ divided by household equivalence scale.""" + return consumption_unequiv / equivalence_scale + + +def utility( + consumption_equiv: FloatND, leisure: FloatND, pref_type: DiscreteState, - consumption_weight: FloatND, - coefficient_rra: FloatND, - equivalence_scale: FloatND, + consumption_weights: FloatND, + coefficients_rra: FloatND, utility_scale_factor: FloatND, ) -> FloatND: - """Within-period utility: CES aggregator over consumption_unequiv and leisure. - - u = utility_scale_factor * ((c/eq_scale)^α * l^(1-α))^(1-γ) / (1-γ) - with log case for γ=1. `consumption_weight` and `coefficient_rra` are - pref-type-indexed Series sourced directly from params; `utility_scale_factor` - is a regime-function output (already a per-cell scalar — must NOT be - re-indexed by pref_type, see `aca_model.agent.preferences.utility_scale_factor` - for why). + """Within-period utility: CES aggregator over consumption and leisure. + + u = utility_scale_factor * + (consumption_equiv^consumption_weight * leisure^(1 - consumption_weight))^(1 - coefficient_rra) + / (1 - coefficient_rra) + with log case for coefficient_rra=1. `consumption_weights` and + `coefficients_rra` are pref-type-indexed Series sourced directly + from params; `utility_scale_factor` is a regime-function output + (already a per-cell scalar — must NOT be re-indexed by pref_type, + see `aca_model.agent.preferences.utility_scale_factor` for why). """ - alpha = consumption_weight[pref_type] - gamma = coefficient_rra[pref_type] - consumption_equiv = consumption_unequiv / equivalence_scale - composite = consumption_equiv**alpha * leisure ** (1.0 - alpha) + consumption_weight = consumption_weights[pref_type] + coefficient_rra = coefficients_rra[pref_type] + composite = ( + consumption_equiv**consumption_weight + * leisure ** (1.0 - consumption_weight) + ) - one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) + one_minus_rra = jnp.where( + jnp.isclose(coefficient_rra, 1.0), 1.0, 1.0 - coefficient_rra + ) u = jnp.where( - jnp.isclose(gamma, 1.0), + jnp.isclose(coefficient_rra, 1.0), jnp.log(composite), - composite**one_minus_gamma / one_minus_gamma, + composite**one_minus_rra / one_minus_rra, ) return u * utility_scale_factor @@ -169,8 +182,8 @@ def discount_factor( def utility_scale_factor( pref_type: DiscreteState, average_consumption_unequiv: float, - consumption_weight: FloatND, - coefficient_rra: FloatND, + consumption_weights: FloatND, + coefficients_rra: FloatND, time_endowment: float, fixed_cost_of_work_intercept: float, fixed_cost_of_work_age_trend: float, @@ -184,25 +197,27 @@ def utility_scale_factor( pattern: take the state as input, return a per-cell scalar. Registering this as a regime function and then doing `utility_scale_factor[pref_type]` in a downstream consumer is invalid — pylcm broadcasts function outputs to - per-cell scalars before consumption_unequiv, and the validator in + per-cell scalars before consumption, and the validator in `lcm.regime_building.validation` raises on that clash. """ - alpha = consumption_weight[pref_type] - gamma = coefficient_rra[pref_type] + consumption_weight = consumption_weights[pref_type] + coefficient_rra = coefficients_rra[pref_type] age_offset = scale_reference_age - reference_age average_leisure = ( time_endowment - scale_reference_hours - (fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * age_offset) ) - u_cons = average_consumption_unequiv**alpha - u_leisure = average_leisure ** (1.0 - alpha) + u_cons = average_consumption_unequiv**consumption_weight + u_leisure = average_leisure ** (1.0 - consumption_weight) - one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) + one_minus_rra = jnp.where( + jnp.isclose(coefficient_rra, 1.0), 1.0, 1.0 - coefficient_rra + ) raw = jnp.where( - jnp.isclose(gamma, 1.0), + jnp.isclose(coefficient_rra, 1.0), jnp.log(u_cons * u_leisure), - (u_cons * u_leisure) ** one_minus_gamma / one_minus_gamma, + (u_cons * u_leisure) ** one_minus_rra / one_minus_rra, ) return jnp.abs(1.0 / raw) @@ -237,25 +252,30 @@ def bequest( pref_type: DiscreteState, bequest_shifter: float, scaled_bequest_weight: float, - consumption_weight: FloatND, - coefficient_rra: FloatND, + consumption_weights: FloatND, + coefficients_rra: FloatND, utility_scale_factor: FloatND, ) -> FloatND: """Bequest function for terminal/dead states. - bequest = scale * bwt * (max(0,a) + shifter)^(α*(1-γ)) / (1-γ) - `consumption_weight` and `coefficient_rra` are pref-type-indexed Series - from params; `utility_scale_factor` is a regime-function output (already a - per-cell scalar — must NOT be re-indexed by pref_type). + bequest = scale * bwt * + (max(0,a) + shifter)^(consumption_weight*(1 - coefficient_rra)) + / (1 - coefficient_rra) + `consumption_weights` and `coefficients_rra` are pref-type-indexed + Series from params; `utility_scale_factor` is a regime-function + output (already a per-cell scalar — must NOT be re-indexed by + pref_type). """ - alpha = consumption_weight[pref_type] - gamma = coefficient_rra[pref_type] + consumption_weight = consumption_weights[pref_type] + coefficient_rra = coefficients_rra[pref_type] assets_shifted = jnp.maximum(0.0, assets) + bequest_shifter - one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) + one_minus_rra = jnp.where( + jnp.isclose(coefficient_rra, 1.0), 1.0, 1.0 - coefficient_rra + ) val = jnp.where( - jnp.isclose(gamma, 1.0), + jnp.isclose(coefficient_rra, 1.0), jnp.log(assets_shifted), - assets_shifted ** (one_minus_gamma * alpha) / one_minus_gamma, + assets_shifted ** (one_minus_rra * consumption_weight) / one_minus_rra, ) return val * scaled_bequest_weight * utility_scale_factor diff --git a/src/aca_model/agent/utility.py b/src/aca_model/agent/utility.py index 7ba0e2a..d7817a1 100644 --- a/src/aca_model/agent/utility.py +++ b/src/aca_model/agent/utility.py @@ -8,7 +8,6 @@ """ from lcm.typing import ( - ContinuousAction, ContinuousState, DiscreteState, FloatND, @@ -19,29 +18,27 @@ def retired( - consumption_unequiv: ContinuousAction, + consumption_equiv: FloatND, good_health: IntND, - equivalence_scale: FloatND, pref_type: DiscreteState, - consumption_weight: FloatND, - coefficient_rra: FloatND, + consumption_weights: FloatND, + coefficients_rra: FloatND, utility_scale_factor: FloatND, time_endowment: float, leisure_cost_of_bad_health: float, ) -> FloatND: """Utility for forcedout regimes (no work).""" - lei = preferences.leisure_retired( + leisure = preferences.leisure_retired( good_health=good_health, time_endowment=time_endowment, leisure_cost_of_bad_health=leisure_cost_of_bad_health, ) return preferences.utility( - consumption_unequiv=consumption_unequiv, - leisure=lei, + consumption_equiv=consumption_equiv, + leisure=leisure, pref_type=pref_type, - consumption_weight=consumption_weight, - coefficient_rra=coefficient_rra, - equivalence_scale=equivalence_scale, + consumption_weights=consumption_weights, + coefficients_rra=coefficients_rra, utility_scale_factor=utility_scale_factor, ) @@ -51,8 +48,8 @@ def dead( pref_type: DiscreteState, bequest_shifter: float, scaled_bequest_weight: float, - consumption_weight: FloatND, - coefficient_rra: FloatND, + consumption_weights: FloatND, + coefficients_rra: FloatND, utility_scale_factor: FloatND, ) -> FloatND: """Terminal bequest utility for dead regime.""" @@ -61,7 +58,7 @@ def dead( pref_type=pref_type, bequest_shifter=bequest_shifter, scaled_bequest_weight=scaled_bequest_weight, - consumption_weight=consumption_weight, - coefficient_rra=coefficient_rra, + consumption_weights=consumption_weights, + coefficients_rra=coefficients_rra, utility_scale_factor=utility_scale_factor, ) diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index a8533f2..f1f822f 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -17,7 +17,6 @@ from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.regimes import RegimeId, build_all_regimes -from aca_model.baseline.regimes._common import MAX_CONSUMPTION_UNEQUIV from aca_model.config import MODEL_CONFIG, GridConfig @@ -82,7 +81,7 @@ def create_model( if derived_categoricals is not None: base_derived.update(derived_categoricals) - model = Model( + return Model( regimes=regimes, ages=ages, regime_id_class=RegimeId, @@ -91,7 +90,3 @@ def create_model( derived_categoricals=base_derived, n_subjects=n_subjects, ) - # See `MAX_CONSUMPTION_UNEQUIV` in `baseline.regimes._common` for why this - # rides on the Model instance instead of `fixed_params`. - model.max_consumption_unequiv = MAX_CONSUMPTION_UNEQUIV - return model diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 1bb73fd..f5174e6 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -22,7 +22,7 @@ ) from lcm.grids.continuous import ContinuousGrid from lcm.grids.piecewise import Piece, PiecewiseLinSpacedGrid -from lcm.typing import BoolND, FloatND +from lcm.typing import BoolND, FloatND, RegimeName from aca_model.agent import ( assets_and_income, @@ -199,11 +199,9 @@ class Grids: """Upper bound of the runtime consumption_unequiv grid in $/year. Lives here next to the other grid bounds (assets `stop=500_000.0`, -AIME `stop=8_000.0`). The `create_model` factories attach this onto -`model.max_consumption_unequiv` so `inject_consumption_unequiv_points` can read it -back at runtime. Routed via a Model attribute rather than -`fixed_params` because pylcm validates `fixed_params` keys against -the regime DAG and rejects entries no function consumes. +AIME `stop=8_000.0`). `inject_consumption_unequiv_points` imports it +directly — pylcm rejects `fixed_params` entries no DAG function +consumes, so this stays a module constant. """ @@ -547,12 +545,14 @@ def build_common_functions(spec: dict[str, str]) -> dict: # Cash on hand and transfers functions["cash_on_hand"] = assets_and_income.cash_on_hand + functions["consumption_unequiv_floor"] = assets_and_income.consumption_unequiv_floor functions["transfers"] = assets_and_income.transfers + functions["consumption_equiv"] = preferences.consumption_equiv return functions -def precompute_targets(spec: dict[str, str]) -> dict[str, int]: +def precompute_targets(spec: Mapping[str, str]) -> dict[str, int]: """Pre-compute target regime IDs for each next-age bracket.""" def _resolve(his_val: str, mc_val: str, ss_val: str, canwork_val: str) -> int: @@ -661,7 +661,9 @@ def build_state_transitions(spec: dict[str, str]) -> dict: return transitions -def _build_per_target_next_assets(spec: dict[str, str]) -> dict: +def _build_per_target_next_assets( + spec: Mapping[str, str], +) -> dict[RegimeName, Callable[..., FloatND]]: """Build per-target assets transitions. The `dead` target uses `next_assets_terminal` (no @@ -673,7 +675,7 @@ def _build_per_target_next_assets(spec: dict[str, str]) -> dict: targets = precompute_targets(spec) id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} - result: dict = {} + result: dict[RegimeName, Callable[..., FloatND]] = {} seen_ids: set[int] = set() for target_id in targets.values(): @@ -689,7 +691,9 @@ def _build_per_target_next_assets(spec: dict[str, str]) -> dict: return result -def _build_per_target_health(spec: dict[str, str]) -> dict: +def _build_per_target_health( + spec: Mapping[str, str], +) -> dict[RegimeName, MarkovTransition]: """Build per-target health transitions. Pre-65 regimes use HealthWithDisability (3-state), post-65 use Health (2-state). @@ -698,7 +702,7 @@ def _build_per_target_health(spec: dict[str, str]) -> dict: targets = precompute_targets(spec) id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} - result: dict[str, MarkovTransition] = {} + result: dict[RegimeName, MarkovTransition] = {} seen_ids: set[int] = set() for target_id in targets.values(): @@ -719,7 +723,9 @@ def _build_per_target_health(spec: dict[str, str]) -> dict: return result -def _build_per_target_claimed_ss(spec: dict[str, str]) -> dict: +def _build_per_target_claimed_ss( + spec: Mapping[str, str], +) -> dict[RegimeName, Callable[..., BoolND]]: """Build per-target claimed_ss transitions. - `choose` regimes (source has `claimed_ss`): absorbing transition. @@ -732,7 +738,7 @@ def _build_per_target_claimed_ss(spec: dict[str, str]) -> dict: targets = precompute_targets(spec) id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} - result: dict = {} + result: dict[RegimeName, Callable[..., BoolND]] = {} seen_ids: set[int] = set() for target_id in targets.values(): @@ -754,7 +760,9 @@ def _build_per_target_claimed_ss(spec: dict[str, str]) -> dict: return result -def _build_per_target_lagged_labor_supply(spec: dict[str, str]) -> dict: +def _build_per_target_lagged_labor_supply( + spec: Mapping[str, str], +) -> dict[RegimeName, Callable[..., BoolND]]: """Build per-target lagged_labor_supply transitions. `lagged_labor_supply` exists in canwork non-tied regimes. Tied regimes @@ -771,7 +779,7 @@ def _build_per_target_lagged_labor_supply(spec: dict[str, str]) -> dict: targets = precompute_targets(spec) id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} - result: dict = {} + result: dict[RegimeName, Callable[..., BoolND]] = {} seen_ids: set[int] = set() for target_id in targets.values(): diff --git a/src/aca_model/consumption_unequiv_grid.py b/src/aca_model/consumption_unequiv_grid.py index e0e9cd3..4fcd225 100644 --- a/src/aca_model/consumption_unequiv_grid.py +++ b/src/aca_model/consumption_unequiv_grid.py @@ -3,12 +3,20 @@ Consumption is declared as `IrregSpacedGrid(n_points=N)` in `baseline.regimes._common.build_grids` so the bounds can track runtime parameters: the lower bound from the per-iteration -`consumption_unequiv_floor` parameter, the upper bound from -`MAX_CONSUMPTION_UNEQUIV` in `baseline.regimes._common`, which the -`create_model` factories attach to `model.max_consumption_unequiv`. -Callers must inject the actual gridpoints into `params` via -`inject_consumption_unequiv_points` before calling `model.solve()` / -`model.simulate()`. +`consumption_equiv_floor` parameter (and its couples-scaled twin), +the upper bound from `MAX_CONSUMPTION_UNEQUIV` in +`baseline.regimes._common`. Callers must inject the actual gridpoints +into `params` via `inject_consumption_unequiv_points` before calling +`model.solve()` / `model.simulate()`. + +The grid pins the two regime-relevant transfer-floor levels exactly +on the action grid so the borrowing constraint's +`max(cash_on_hand, floor)` boundary lands on a feasible action for +both single and married households: + +- `pts[0] = consumption_equiv_floor` (single household: equiv_scale=1) +- `pts[1] = consumption_equiv_floor * 2 ** exponent` (married) +- `pts[2:] = geomspace(pts[1], MAX_CONSUMPTION_UNEQUIV, n_points - 1)` """ from collections.abc import Mapping @@ -18,6 +26,8 @@ from jax import Array from lcm import IrregSpacedGrid, Model +from aca_model.baseline.regimes._common import MAX_CONSUMPTION_UNEQUIV + def inject_consumption_unequiv_points( *, @@ -30,20 +40,24 @@ def inject_consumption_unequiv_points( `IrregSpacedGrid` with runtime-supplied points, and writes `params[regime_name]["consumption_unequiv"] = {"points": }`. - Lower bound: `params["consumption_unequiv_floor"]` (varies per iteration). - Upper bound: `model.max_consumption_unequiv` (set by the `create_model` - factory from `MAX_CONSUMPTION_UNEQUIV` in `baseline.regimes._common`). + The lower two gridpoints are the single and married unequiv + transfer floors (`consumption_equiv_floor` and + `consumption_equiv_floor * 2 ** exponent`); the rest are + geomspaced from the married floor up to `MAX_CONSUMPTION_UNEQUIV`. Args: - params: Existing params mapping. Returned as a new dict; the input is - not mutated. - model: Model whose regime specs determine which regimes need points. + params: Existing params mapping with `consumption_equiv_floor` + (per-equivalent floor, varies per iteration). Returned as a + new dict; the input is not mutated. + model: Model whose regime specs determine which regimes need points + and whose `fixed_params["exponent"]` sets the married + equivalence-scale exponent. Returns: New params dict with consumption_unequiv points injected. """ - consumption_unequiv_floor = float(params["consumption_unequiv_floor"]) - max_consumption_unequiv = float(model.max_consumption_unequiv) + consumption_equiv_floor = jnp.asarray(params["consumption_equiv_floor"]) + exponent = jnp.asarray(model.fixed_params["exponent"]) out: dict[str, Any] = dict(params) for regime_name, regime in model.regimes.items(): grid = regime.actions.get("consumption_unequiv") @@ -53,8 +67,8 @@ def inject_consumption_unequiv_points( # rejects the (points=None, n_points=None) combo); narrow for ty. assert grid.n_points is not None points = _compute_consumption_unequiv_points( - consumption_unequiv_floor=consumption_unequiv_floor, - max_consumption_unequiv=max_consumption_unequiv, + consumption_equiv_floor=consumption_equiv_floor, + exponent=exponent, n_points=grid.n_points, ) regime_entry = dict(out.get(regime_name, {})) @@ -65,21 +79,34 @@ def inject_consumption_unequiv_points( def _compute_consumption_unequiv_points( *, - consumption_unequiv_floor: float, - max_consumption_unequiv: float, + consumption_equiv_floor: Array, + exponent: Array, n_points: int, ) -> Array: - """Return log-spaced consumption_unequiv gridpoints from floor to max. - - `jnp.geomspace` computes intermediate points as `start * r^i` with - `r = (stop/start)^(1/(n-1))`; the first point is `start * r^0`, - which is `start` mathematically but can be off by sub-ULP under - some XLA backends (CUDA + 70 points: `start + 2.27e-13`). The - borrowing constraint compares the first action against - `max(cash_on_hand, consumption_unequiv_floor)`, and any positive drift - above `consumption_unequiv_floor` flips the kink-boundary `<=` for - subjects with very negative cash. Pin the first element back to - `consumption_unequiv_floor` exactly. + """Return log-spaced consumption_unequiv gridpoints with both floors pinned. + + Single and married households face different unequiv (in-$) floors + (`consumption_equiv_floor` and `consumption_equiv_floor * + 2 ** exponent` respectively). Both must land exactly on the action + grid so the borrowing constraint's `max(cash_on_hand, floor)` kink + boundary is a feasible action; otherwise sub-ULP drift can flip + the `<=` comparison for subjects with very negative cash. The + geomspace tail starts at the married floor and runs to + `MAX_CONSUMPTION_UNEQUIV` so the two pinned points stay strictly + increasing. + + All arithmetic stays in jax — multiplying `consumption_equiv_floor` + by `2 ** exponent` in jnp keeps both pinned floors at the canonical + float dtype the model uses everywhere else. """ - pts = jnp.geomspace(consumption_unequiv_floor, max_consumption_unequiv, num=n_points) - return pts.at[0].set(consumption_unequiv_floor) + married_unequiv_floor = consumption_equiv_floor * jnp.asarray(2.0) ** exponent + tail = jnp.geomspace( + married_unequiv_floor, MAX_CONSUMPTION_UNEQUIV, num=n_points - 1 + ) + pts = jnp.concatenate([consumption_equiv_floor[None], tail]) + # `jnp.geomspace` returns `start * r^0` for the first tail element, + # which mathematically equals `married_unequiv_floor` but drifts by + # sub-ULP on some XLA backends. Pin the slot back to the exact + # arithmetic value so the borrowing-constraint kink boundary at the + # married floor is exactly representable. + return pts.at[1].set(married_unequiv_floor) diff --git a/src/aca_model/environment/social_security.py b/src/aca_model/environment/social_security.py index c9ce1f5..e3574cf 100644 --- a/src/aca_model/environment/social_security.py +++ b/src/aca_model/environment/social_security.py @@ -30,7 +30,7 @@ def next_claimed_ss( def enter_claimed_ss() -> DiscreteState: """Initial claimed_ss when entering the SS eligibility window.""" - return ClaimedSS.no + return jnp.int32(ClaimedSS.no) # --- PIA lookup (DAG functions) --- diff --git a/tests/test_budget_chain_integration.py b/tests/test_budget_chain_integration.py index b511f8c..87ab670 100644 --- a/tests/test_budget_chain_integration.py +++ b/tests/test_budget_chain_integration.py @@ -126,14 +126,12 @@ def test_transfers_kick_in_below_floor() -> None: ssi_benefit=jnp.array(0.0), hic_premium=jnp.array(0.0), oop_costs=jnp.array(0.0), - consumption_unequiv_floor=5000.0, - equivalence_scale=jnp.array(1.0), + consumption_unequiv_floor=jnp.array(5000.0), pension_assets_adjustment=jnp.array(0.0), consumption_unequiv=jnp.array(4000.0), ) # cash_on_hand = 500 + 200 = 700 - # floor = 5000 * 1.0 = 5000 # transfers = max(0, 5000 - 700) = 4300 assert jnp.isclose(result["transfers"], 4300.0, atol=ATOL) # next_assets = 700 + 4300 + 0 - 4000 = 1000 diff --git a/tests/test_consumption_unequiv_grid.py b/tests/test_consumption_unequiv_grid.py index f06d571..89e8bb9 100644 --- a/tests/test_consumption_unequiv_grid.py +++ b/tests/test_consumption_unequiv_grid.py @@ -2,21 +2,29 @@ The borrowing constraint in `agent.assets_and_income.borrowing_constraint` compares the lowest consumption_unequiv action against -`max(cash_on_hand, consumption_unequiv_floor * equivalence_scale)`. For subjects -with cash below the floor (HRS bottom-coded `assets=-$1{,}000{,}000$`, -moderate-negative-asset retirees etc.) this RHS collapses to exactly -`consumption_unequiv_floor` for singles. The constraint is feasible iff the -lowest consumption_unequiv gridpoint is `<= consumption_unequiv_floor`. +`max(cash_on_hand, consumption_unequiv_floor)`. For subjects with cash +below the floor (HRS bottom-coded `assets=-$1{,}000{,}000$`, +moderate-negative-asset retirees, etc.) this RHS collapses to exactly +`consumption_unequiv_floor`. The constraint is feasible iff the +relevant household-floor gridpoint is `<= consumption_unequiv_floor`. + +For singles (`equivalence_scale = 1`) that floor is +`consumption_equiv_floor`; for married households +(`equivalence_scale = 2 ** exponent`) it is +`consumption_equiv_floor * 2 ** exponent`. Both must land **exactly** +on the consumption_unequiv grid. `jnp.geomspace(start, stop, num=n)` returns `start * r^i` with `r = (stop/start)^(1/(n-1))`; mathematically `r^0 == 1` so the first point equals `start`, but XLA backends can drift by sub-ULP for some `(start, stop, n)` combinations (observed: CUDA, n=70, drift +2.27e-13). -A positive drift above `consumption_unequiv_floor` flips the kink-boundary `<=` -and rejects every action for those subjects. +A positive drift above the floor flips the kink-boundary `<=` and +rejects every action for the affected subjects. -`_compute_consumption_unequiv_points` therefore pins the first point back to -`consumption_unequiv_floor` after `geomspace`. Test that invariant directly. +`_compute_consumption_unequiv_points` therefore prepends the singles' +floor as `pts[0]`, runs `geomspace` from the married floor up to +`MAX_CONSUMPTION_UNEQUIV` for the rest, and pins the geomspace start +back to the married floor exactly. Test those invariants directly. """ import jax.numpy as jnp @@ -24,25 +32,56 @@ from aca_model.consumption_unequiv_grid import _compute_consumption_unequiv_points +EXPONENT = 0.7 # production value (env_constants["exponent"]) +SINGLE_FLOOR = 1597.0921419521899 # production value +MARRIED_SCALE = 2.0**EXPONENT + + +@pytest.mark.parametrize("n_points", [5, 16, 64, 70, 100]) +def test_compute_consumption_unequiv_points_first_equals_singles_floor( + n_points: int, +) -> None: + """`pts[0]` equals the singles' floor exactly under any `n_points`.""" + pts = _compute_consumption_unequiv_points( + consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), + exponent=jnp.asarray(EXPONENT), + n_points=n_points, + ) + assert float(pts[0]) == SINGLE_FLOOR + @pytest.mark.parametrize("n_points", [5, 16, 64, 70, 100]) -def test_compute_consumption_unequiv_points_first_equals_floor_exactly(n_points: int) -> None: - """The first gridpoint equals `consumption_unequiv_floor` exactly under any `n_points`.""" - consumption_unequiv_floor = 1597.0921419521899 # production value +def test_compute_consumption_unequiv_points_second_equals_married_floor( + n_points: int, +) -> None: + """`pts[1]` equals `consumption_equiv_floor * 2 ** exponent` exactly.""" pts = _compute_consumption_unequiv_points( - consumption_unequiv_floor=consumption_unequiv_floor, - max_consumption_unequiv=300_000.0, + consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), + exponent=jnp.asarray(EXPONENT), n_points=n_points, ) - assert float(pts[0]) == consumption_unequiv_floor + expected = float(jnp.asarray(SINGLE_FLOOR) * jnp.asarray(2.0) ** EXPONENT) + assert float(pts[1]) == expected def test_compute_consumption_unequiv_points_strictly_increasing() -> None: """Gridpoints are strictly increasing — no kink-pinning ties.""" pts = _compute_consumption_unequiv_points( - consumption_unequiv_floor=1597.0921419521899, - max_consumption_unequiv=300_000.0, + consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), + exponent=jnp.asarray(EXPONENT), n_points=70, ) diffs = jnp.diff(pts) assert bool((diffs > 0).all()) + + +def test_compute_consumption_unequiv_points_last_equals_max() -> None: + """The final point is the configured upper bound.""" + from aca_model.baseline.regimes._common import MAX_CONSUMPTION_UNEQUIV + + pts = _compute_consumption_unequiv_points( + consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), + exponent=jnp.asarray(EXPONENT), + n_points=70, + ) + assert float(pts[-1]) == pytest.approx(MAX_CONSUMPTION_UNEQUIV) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index 1238171..d121539 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -1,8 +1,8 @@ """Subjects at extreme negative assets must clear `validate_initial_conditions`. The transfer system (`agent.assets_and_income.transfers`) tops cash-on-hand -to `consumption_unequiv_floor * equivalence_scale` at any starting state, so the -lowest consumption_unequiv-grid point is always a feasible action regardless of +to the household-$ consumption floor at any starting state, so the lowest +consumption_unequiv-grid point is always a feasible action regardless of how negative starting assets are. The model's constraints — and pylcm's `validate_initial_conditions` pass — must reflect this. """ @@ -22,16 +22,30 @@ def test_borrowing_constraint_admits_consumption_unequiv_at_floor() -> None: """`consumption_unequiv == consumption_unequiv_floor` at the kink is feasible by equality.""" - consumption_unequiv_floor = 5_000.0 - equivalence_scale = jnp.asarray(1.0) + consumption_unequiv_floor = jnp.asarray(5_000.0) cash_on_hand = jnp.asarray(-50_000.0) # below floor — RHS = floor admitted = bool( borrowing_constraint( - consumption_unequiv=jnp.asarray(consumption_unequiv_floor), + consumption_unequiv=consumption_unequiv_floor, cash_on_hand=cash_on_hand, consumption_unequiv_floor=consumption_unequiv_floor, - equivalence_scale=equivalence_scale, + ) + ) + assert admitted + + +def test_borrowing_constraint_admits_consumption_unequiv_at_married_floor() -> None: + """At a married household's higher floor, the equivalence-scale-lifted floor is feasible.""" + consumption_equiv_floor = jnp.asarray(5_000.0) + married_floor = consumption_equiv_floor * jnp.asarray(2.0) ** 0.7 + cash_on_hand = jnp.asarray(-50_000.0) + + admitted = bool( + borrowing_constraint( + consumption_unequiv=married_floor, + cash_on_hand=cash_on_hand, + consumption_unequiv_floor=married_floor, ) ) assert admitted @@ -41,17 +55,15 @@ def test_borrowing_constraint_rejects_consumption_unequiv_above_post_transfer_re None ): """`consumption_unequiv > max(cash_on_hand, floor)` is rejected.""" - consumption_unequiv_floor = 5_000.0 - equivalence_scale = jnp.asarray(1.0) + consumption_unequiv_floor = jnp.asarray(5_000.0) cash_on_hand = jnp.asarray(-50_000.0) - consumption_unequiv = jnp.asarray(consumption_unequiv_floor + 1.0) + consumption_unequiv = consumption_unequiv_floor + 1.0 admitted = bool( borrowing_constraint( consumption_unequiv=consumption_unequiv, cash_on_hand=cash_on_hand, consumption_unequiv_floor=consumption_unequiv_floor, - equivalence_scale=equivalence_scale, ) ) assert not admitted @@ -66,17 +78,15 @@ def test_borrowing_constraint_admits_floor_at_million_dollar_negative_cash() -> for the lowest consumption_unequiv gridpoint. The `max(cash_on_hand, floor)` form returns `floor` exactly. """ - consumption_unequiv_floor = 1597.0921419521899 # production value - equivalence_scale = jnp.asarray(1.0) + consumption_unequiv_floor = jnp.asarray(1597.0921419521899) # production value cash_on_hand = jnp.asarray(-1_000_000.0) - consumption_unequiv = jnp.asarray(consumption_unequiv_floor) # lowest grid point + consumption_unequiv = consumption_unequiv_floor # lowest grid point admitted = bool( borrowing_constraint( consumption_unequiv=consumption_unequiv, cash_on_hand=cash_on_hand, consumption_unequiv_floor=consumption_unequiv_floor, - equivalence_scale=equivalence_scale, ) ) assert admitted @@ -86,7 +96,7 @@ def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. HRS bottom-codes very-large-negative net wealth at exactly $-1{,}000{,}000$. - Such subjects should remain in the simulated population: the consumption_unequiv + Such subjects should remain in the simulated population: the consumption floor / transfer system absorbs them, with `c = c_floor` always feasible. """ n_subjects = 1 diff --git a/tests/test_model_components.py b/tests/test_model_components.py index 3153195..9c14aee 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -78,12 +78,11 @@ def test_leisure_bad_health() -> None: def test_utility_positive_leisure() -> None: result = preferences.utility( - consumption_unequiv=jnp.array(10000.0), + consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), pref_type=jnp.array(0), - consumption_weight=jnp.array([0.4, 0.4, 0.4]), - coefficient_rra=jnp.array([2.0, 2.0, 2.0]), - equivalence_scale=jnp.array(1.0), + consumption_weights=jnp.array([0.4, 0.4, 0.4]), + coefficients_rra=jnp.array([2.0, 2.0, 2.0]), utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) @@ -91,12 +90,11 @@ def test_utility_positive_leisure() -> None: def test_utility_log_case() -> None: result = preferences.utility( - consumption_unequiv=jnp.array(10000.0), + consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), pref_type=jnp.array(0), - consumption_weight=jnp.array([0.4, 0.4, 0.4]), - coefficient_rra=jnp.array([1.0, 1.0, 1.0]), - equivalence_scale=jnp.array(1.0), + consumption_weights=jnp.array([0.4, 0.4, 0.4]), + coefficients_rra=jnp.array([1.0, 1.0, 1.0]), utility_scale_factor=jnp.array(1.0), ) composite = 10000.0**0.4 * 3000.0**0.6 @@ -110,8 +108,8 @@ def test_bequest_positive_assets() -> None: pref_type=jnp.array(0), bequest_shifter=5000.0, scaled_bequest_weight=0.5, - consumption_weight=jnp.array([0.4, 0.4, 0.4]), - coefficient_rra=jnp.array([2.0, 2.0, 2.0]), + consumption_weights=jnp.array([0.4, 0.4, 0.4]), + coefficients_rra=jnp.array([2.0, 2.0, 2.0]), utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) @@ -123,8 +121,8 @@ def test_bequest_zero_assets() -> None: pref_type=jnp.array(0), bequest_shifter=5000.0, scaled_bequest_weight=0.5, - consumption_weight=jnp.array([0.4, 0.4, 0.4]), - coefficient_rra=jnp.array([2.0, 2.0, 2.0]), + consumption_weights=jnp.array([0.4, 0.4, 0.4]), + coefficients_rra=jnp.array([2.0, 2.0, 2.0]), utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index 0cf9655..4e31881 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -10,7 +10,7 @@ from aca_model.aca.regimes import build_all_regimes as _build_aca_regimes from aca_model.baseline.regimes import REGIME_SPECS, RegimeId from aca_model.baseline.regimes import build_regime as _build_regime -from aca_model.baseline.regimes._common import MAX_CONSUMPTION_UNEQUIV, build_grids +from aca_model.baseline.regimes._common import build_grids from aca_model.config import GRID_CONFIG @@ -143,7 +143,7 @@ def test_pre65_regimes_use_health_with_disability() -> None: if spec["mc"] in ("nomc", "dimc"): regime = build_regime(name) grid = regime.states["health"] - assert len(grid.categories) == 3, f"{name} should use HealthWithDisability" # ty: ignore[unresolved-attribute] + assert len(grid.categories) == 3, f"{name} should use HealthWithDisability" def test_post65_regimes_use_health() -> None: @@ -151,7 +151,7 @@ def test_post65_regimes_use_health() -> None: if spec["mc"] == "oamc": regime = build_regime(name) grid = regime.states["health"] - assert len(grid.categories) == 2, f"{name} should use Health" # ty: ignore[unresolved-attribute] + assert len(grid.categories) == 2, f"{name} should use Health" def test_all_regimes_have_aime() -> None: @@ -270,6 +270,3 @@ def test_baseline_model_creates() -> None: assert len(model.regimes) == 19 -def test_max_consumption_unequiv_attached_from_canonical_constant() -> None: - model = make_baseline_model(n_subjects=1) - assert model.max_consumption_unequiv == MAX_CONSUMPTION_UNEQUIV diff --git a/tests/test_preferences.py b/tests/test_preferences.py index 2afe6b8..d67e0cd 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -35,8 +35,8 @@ def test_utility_scale_factor_crra() -> None: result = preferences.utility_scale_factor( pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, + consumption_weights=WEIGHT_BY_TYPE, + coefficients_rra=RRA_5_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, @@ -51,8 +51,8 @@ def test_utility_scale_factor_log() -> None: result = preferences.utility_scale_factor( pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_1_BY_TYPE, + consumption_weights=WEIGHT_BY_TYPE, + coefficients_rra=RRA_1_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, @@ -109,8 +109,8 @@ def test_utility_log_regression() -> None: scale = preferences.utility_scale_factor( pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_1_BY_TYPE, + consumption_weights=WEIGHT_BY_TYPE, + coefficients_rra=RRA_1_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, @@ -119,12 +119,11 @@ def test_utility_log_regression() -> None: scale_reference_age=SCALE_REFERENCE_AGE, ) result = preferences.utility( - consumption_unequiv=jnp.array(50000.0), + consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), pref_type=jnp.array(0), - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_1_BY_TYPE, - equivalence_scale=jnp.array(1.0), + consumption_weights=WEIGHT_BY_TYPE, + coefficients_rra=RRA_1_BY_TYPE, utility_scale_factor=scale, ) assert jnp.isclose(result, 1.005_046_313_660_588_5, rtol=1e-5) @@ -134,8 +133,8 @@ def test_utility_crra_regression() -> None: scale = preferences.utility_scale_factor( pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, + consumption_weights=WEIGHT_BY_TYPE, + coefficients_rra=RRA_5_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, @@ -144,12 +143,11 @@ def test_utility_crra_regression() -> None: scale_reference_age=SCALE_REFERENCE_AGE, ) result = preferences.utility( - consumption_unequiv=jnp.array(50000.0), + consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), pref_type=jnp.array(0), - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, - equivalence_scale=jnp.array(1.0), + consumption_weights=WEIGHT_BY_TYPE, + coefficients_rra=RRA_5_BY_TYPE, utility_scale_factor=scale, ) assert jnp.isclose(result, -0.836_511_642_073_019_1, rtol=1e-5) @@ -160,8 +158,8 @@ def test_utility_married_equivalence() -> None: scale = preferences.utility_scale_factor( pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, + consumption_weights=WEIGHT_BY_TYPE, + coefficients_rra=RRA_5_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, @@ -170,21 +168,19 @@ def test_utility_married_equivalence() -> None: scale_reference_age=SCALE_REFERENCE_AGE, ) single = preferences.utility( - consumption_unequiv=jnp.array(50000.0), + consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), pref_type=jnp.array(0), - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, - equivalence_scale=jnp.array(1.0), + consumption_weights=WEIGHT_BY_TYPE, + coefficients_rra=RRA_5_BY_TYPE, utility_scale_factor=scale, ) married = preferences.utility( - consumption_unequiv=jnp.array(50000.0 * 2**0.7), + consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), pref_type=jnp.array(0), - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, - equivalence_scale=jnp.array(2**0.7), + consumption_weights=WEIGHT_BY_TYPE, + coefficients_rra=RRA_5_BY_TYPE, utility_scale_factor=scale, ) assert jnp.isclose(single, married, rtol=1e-5) @@ -197,8 +193,8 @@ def test_bequest_log_regression() -> None: scale = preferences.utility_scale_factor( pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_1_BY_TYPE, + consumption_weights=WEIGHT_BY_TYPE, + coefficients_rra=RRA_1_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, @@ -219,8 +215,8 @@ def test_bequest_log_regression() -> None: pref_type=jnp.array(0), bequest_shifter=BEQUEST_SHIFTER, scaled_bequest_weight=bwt.item(), - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_1_BY_TYPE, + consumption_weights=WEIGHT_BY_TYPE, + coefficients_rra=RRA_1_BY_TYPE, utility_scale_factor=scale, ) assert jnp.isclose(result, 86.539_249_963_643_88, rtol=1e-5) @@ -230,8 +226,8 @@ def test_bequest_crra_regression() -> None: scale = preferences.utility_scale_factor( pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, + consumption_weights=WEIGHT_BY_TYPE, + coefficients_rra=RRA_5_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, @@ -252,8 +248,8 @@ def test_bequest_crra_regression() -> None: pref_type=jnp.array(0), bequest_shifter=BEQUEST_SHIFTER, scaled_bequest_weight=bwt.item(), - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, + consumption_weights=WEIGHT_BY_TYPE, + coefficients_rra=RRA_5_BY_TYPE, utility_scale_factor=scale, ) assert jnp.isclose(result, -37.932_748_117_035_63, rtol=1e-5) From 9ac20430f499a8b1cdb056af85bc2a26e850bad2 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 10 May 2026 20:07:35 +0200 Subject: [PATCH 45/54] tests: wrap Python-int kwargs in jnp.int32 to satisfy strict pylcm types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pylcm's typing aliases (Period, IntND) are now strict — Python ints no longer satisfy them. Wrap kwarg literals in test_pensions, test_pension_integration, test_social_security, test_ss_benefit_integration, and test_model_components. Hoist top-level integer constants (PERIOD, his/old_his/new_his) to jnp.int32. Convert _RATIO_NP to pd.Series at the compute_di_dropout_scale call site. aca-model: drop the leftover MAX_CONSUMPTION_UNEQUIV import + the test_max_consumption_unequiv_attached_from_canonical_constant test (the dynamic Model attribute is gone). Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_model_components.py | 4 +- tests/test_pension_integration.py | 12 ++--- tests/test_pensions.py | 60 ++++++++++++------------ tests/test_social_security.py | 70 +++++++++++++++------------- tests/test_ss_benefit_integration.py | 34 +++++++------- 5 files changed, 93 insertions(+), 87 deletions(-) diff --git a/tests/test_model_components.py b/tests/test_model_components.py index 9c14aee..1375a8a 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -164,8 +164,8 @@ def test_next_aime_accrual() -> None: result = social_security.next_aime( aime=jnp.array(1000.0), labor_income=jnp.array(50000.0), - period=55, - age=55, + period=jnp.int32(55), + age=jnp.int32(55), benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), earnings_test_repealed_age=70, diff --git a/tests/test_pension_integration.py b/tests/test_pension_integration.py index 9a9176d..0fabb35 100644 --- a/tests/test_pension_integration.py +++ b/tests/test_pension_integration.py @@ -18,7 +18,7 @@ # HIS 0 (retiree): intercept = -50, HIS 1 (nongroup): intercept = -80. N_PERIODS = 30 N_HIS = 2 -PERIOD = 20 +PERIOD = jnp.int32(20) _intercept = jnp.zeros((N_PERIODS, N_HIS)) _intercept = _intercept.at[PERIOD, 0].set(-50.0) @@ -62,7 +62,7 @@ def test_benefit_wealth_dag() -> None: result = combined( pia=jnp.array(500.0), period=PERIOD, - his=0, + his=jnp.int32(0), epdv_constant_pension=EPDV, **IMP_KWARGS, ) @@ -80,7 +80,7 @@ def test_total_to_pia_inverts_benefit_via_dag() -> None: recovered = combined( pia=jnp.array(8000.0), period=PERIOD, - his=0, + his=jnp.int32(0), marginal_tax_rate=jnp.array(0.2), **IMP_KWARGS, ) @@ -103,7 +103,7 @@ def test_next_assets_includes_pension_adjustment() -> None: def test_zero_adjustment_when_his_unchanged() -> None: """Pension adjustment is zero when HIS doesn't change.""" - his = 0 + his = jnp.int32(0) pia = jnp.array(8000.0) labor_income = jnp.array(30_000.0) mtr = jnp.array(0.2) @@ -149,8 +149,8 @@ def test_rebalancing_preserves_total_wealth_across_his_change() -> None: from HIS 0 (retiree) to HIS 1 (nongroup), the pension imputation changes. The assets_adjustment compensates so total wealth is preserved. """ - old_his = 0 - new_his = 1 + old_his = jnp.int32(0) + new_his = jnp.int32(1) pia = jnp.array(8000.0) labor_income = jnp.array(30_000.0) mtr = jnp.array(0.0) diff --git a/tests/test_pensions.py b/tests/test_pensions.py index beff910..514ab8c 100644 --- a/tests/test_pensions.py +++ b/tests/test_pensions.py @@ -42,8 +42,8 @@ def test_pension_benefit_zero_pia() -> None: result = pensions.benefit( pia=jnp.array(0.0), - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -59,8 +59,8 @@ def test_pension_benefit_zero_pia() -> None: def test_pension_benefit_below_kink_0() -> None: result = pensions.benefit( pia=jnp.array(500.0), - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -75,8 +75,8 @@ def test_pension_benefit_below_kink_0() -> None: def test_pension_benefit_between_kinks() -> None: result = pensions.benefit( pia=jnp.array(12000.0), - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -92,8 +92,8 @@ def test_pension_benefit_between_kinks() -> None: def test_pension_benefit_above_kink_1() -> None: result = pensions.benefit( pia=jnp.array(20000.0), - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -109,8 +109,8 @@ def test_pension_benefit_above_kink_1() -> None: def test_pension_accrual_no_income() -> None: result = pensions.accrual( labor_income=jnp.array(-1000.0), - period=20, - his=0, + period=jnp.int32(20), + his=jnp.int32(0), accrual_intercept=ACCRUAL_INTERCEPT, accrual_log_earnings=ACCRUAL_LOG_EARNINGS, accrual_prob_intercept=ACCRUAL_PROB_INTERCEPT, @@ -123,8 +123,8 @@ def test_pension_accrual_no_income() -> None: def test_pension_accrual_positive() -> None: result = pensions.accrual( labor_income=jnp.array(10000.0), - period=20, - his=0, + period=jnp.int32(20), + his=jnp.int32(0), accrual_intercept=ACCRUAL_INTERCEPT, accrual_log_earnings=ACCRUAL_LOG_EARNINGS, accrual_prob_intercept=ACCRUAL_PROB_INTERCEPT, @@ -148,7 +148,7 @@ def test_pension_wealth_next_accrual_only() -> None: pension_accrual=jnp.array(accrual), rate_of_return=r, unconditional_survival_prob=SURVIVAL_PROBS, - period=28, + period=jnp.int32(28), ) assert jnp.isclose(result, accrual / 0.99, atol=ATOL) @@ -164,7 +164,7 @@ def test_pension_wealth_next_with_benefit() -> None: pension_accrual=jnp.array(accrual), rate_of_return=r, unconditional_survival_prob=SURVIVAL_PROBS, - period=29, + period=jnp.int32(29), ) expected = ((1 + r) * 3000 + accrual - 2000) / 0.98 assert jnp.isclose(result, expected, atol=ATOL) @@ -177,8 +177,8 @@ def test_convert_total_ben_to_pia_below_kink_0() -> None: pb = pensions.benefit( pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -190,8 +190,8 @@ def test_convert_total_ben_to_pia_below_kink_0() -> None: recovered = pensions.total_to_pia( pension_benefit=pb, pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), marginal_tax_rate=mtr, imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, @@ -210,8 +210,8 @@ def test_convert_total_ben_to_pia_between_kinks() -> None: pb = pensions.benefit( pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -223,8 +223,8 @@ def test_convert_total_ben_to_pia_between_kinks() -> None: recovered = pensions.total_to_pia( pension_benefit=pb, pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), marginal_tax_rate=mtr, imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, @@ -243,8 +243,8 @@ def test_convert_total_ben_to_pia_above_kink_1() -> None: pb = pensions.benefit( pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -256,8 +256,8 @@ def test_convert_total_ben_to_pia_above_kink_1() -> None: recovered = pensions.total_to_pia( pension_benefit=pb, pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), marginal_tax_rate=mtr, imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, @@ -276,8 +276,8 @@ def test_convert_total_ben_to_pia_zero_mtr() -> None: pb = pensions.benefit( pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -289,8 +289,8 @@ def test_convert_total_ben_to_pia_zero_mtr() -> None: recovered = pensions.total_to_pia( pension_benefit=pb, pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), marginal_tax_rate=mtr, imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, diff --git a/tests/test_social_security.py b/tests/test_social_security.py index 90c5128..835b554 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import numpy as np +import pandas as pd from helpers.social_security import compute_di_dropout_scale, compute_pia_table from aca_model.agent.labor_market import LaborSupply @@ -56,7 +57,12 @@ RATIO = jnp.array(_RATIO_NP) DI_SCALE = jnp.array( - compute_di_dropout_scale(_RATIO_NP, AIME_ACCRUAL_FACTOR, start_age=0, n_periods=100) + compute_di_dropout_scale( + pd.Series(_RATIO_NP), + AIME_ACCRUAL_FACTOR, + start_age=jnp.int32(0), + n_periods=100, + ) ) # Pre-computed PIA lookup table (4-point exact grid) @@ -119,13 +125,13 @@ def test_next_aime_indexing_high_income() -> None: result = social_security.next_aime( benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, aime=jnp.array(1000.0), labor_income=jnp.array(20000.0), - period=58, - age=58, + period=jnp.int32(58), + age=jnp.int32(58), aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, @@ -140,13 +146,13 @@ def test_next_aime_indexing_low_income() -> None: result = social_security.next_aime( benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, aime=jnp.array(10000.0), labor_income=jnp.array(510.0), - period=58, - age=58, + period=jnp.int32(58), + age=jnp.int32(58), aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, @@ -160,13 +166,13 @@ def test_next_aime_no_indexing_high_income() -> None: result = social_security.next_aime( benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, aime=jnp.array(1000.0), labor_income=jnp.array(20000.0), - period=62, - age=62, + period=jnp.int32(62), + age=jnp.int32(62), aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, @@ -181,13 +187,13 @@ def test_next_aime_no_indexing_low_income() -> None: result = social_security.next_aime( benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, aime=jnp.array(1000.0), labor_income=jnp.array(99.0), - period=62, - age=62, + period=jnp.int32(62), + age=jnp.int32(62), aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, @@ -201,13 +207,13 @@ def test_next_aime_cap_high_aime_high_income() -> None: result = social_security.next_aime( benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, aime=jnp.array(40000.0), labor_income=jnp.array(20000.0), - period=62, - age=62, + period=jnp.int32(62), + age=jnp.int32(62), aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, @@ -221,13 +227,13 @@ def test_next_aime_cap_high_aime_low_income() -> None: result = social_security.next_aime( benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, aime=jnp.array(40000.0), labor_income=jnp.array(3500.0), - period=62, - age=62, + period=jnp.int32(62), + age=jnp.int32(62), aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, @@ -256,7 +262,7 @@ def test_pia_lookup_matches_formula() -> None: def test_ssdi_pia_matches_dropout_adjusted() -> None: """ssdi_pia lookup matches aime_to_pia(aime * di_dropout_scale[period]).""" aime = jnp.array(5000.0) - period = 55 + period = jnp.int32(55) adjusted_aime = aime * DI_SCALE[period] lookup = social_security.ssdi_pia( @@ -292,17 +298,17 @@ def test_benefit_choose_post65_below_et_threshold() -> None: ) result = social_security.benefit_choose_post65( pia=pia_val, - age=67, - period=0, + age=jnp.int32(67), + period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), labor_supply=jnp.array(LaborSupply.h2000), labor_income=jnp.array(4000.0), early_ret_adjustment=jnp.array([1.0]), - normal_retirement_age=66, + normal_retirement_age=jnp.int32(66), earnings_test_threshold=jnp.array([10000.0]), earnings_test_fraction=jnp.array([0.0]), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), ) assert jnp.isclose(result, pia_val, atol=ATOL) @@ -316,17 +322,17 @@ def test_benefit_choose_post65_partially_reduced() -> None: ) result = social_security.benefit_choose_post65( pia=pia_val, - age=60, - period=0, + age=jnp.int32(60), + period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), labor_supply=jnp.array(LaborSupply.h2000), labor_income=jnp.array(6000.0), early_ret_adjustment=jnp.array([1.0]), - normal_retirement_age=66, + normal_retirement_age=jnp.int32(66), earnings_test_threshold=jnp.array([2000.0]), earnings_test_fraction=jnp.array([0.2]), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), ) expected = pia_val - (6000 - 2000) * 0.2 assert jnp.isclose(result, expected, atol=ATOL) @@ -336,7 +342,7 @@ def test_benefit_inelig_pre65_disabled_below_sga() -> None: """Disabled agent below SGA: benefit = ssdi_pia.""" ssdi_val = social_security.ssdi_pia( aime=jnp.array(5000.0), - period=55, + period=jnp.int32(55), di_dropout_scale=DI_SCALE, pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, @@ -354,7 +360,7 @@ def test_benefit_inelig_pre65_disabled_above_sga() -> None: """Disabled agent above SGA: benefit = 0.""" ssdi_val = social_security.ssdi_pia( aime=jnp.array(5000.0), - period=55, + period=jnp.int32(55), di_dropout_scale=DI_SCALE, pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, @@ -385,12 +391,12 @@ def test_benefit_inelig_pre65_not_disabled() -> None: def test_di_dropout_round_trip_zero_years() -> None: aime = jnp.array(10000.0) scaled = aime * DI_SCALE[52] - round_tripped = social_security.adjust_aime_di_dropout_inv(52, scaled, DI_SCALE) + round_tripped = social_security.adjust_aime_di_dropout_inv(jnp.int32(52), scaled, DI_SCALE) assert jnp.isclose(aime, round_tripped, atol=ATOL) def test_di_dropout_round_trip_positive_years() -> None: aime = jnp.array(10000.0) scaled = aime * DI_SCALE[62] - round_tripped = social_security.adjust_aime_di_dropout_inv(62, scaled, DI_SCALE) + round_tripped = social_security.adjust_aime_di_dropout_inv(jnp.int32(62), scaled, DI_SCALE) assert jnp.isclose(aime, round_tripped, rtol=0.0002) diff --git a/tests/test_ss_benefit_integration.py b/tests/test_ss_benefit_integration.py index ef09775..5e74e9a 100644 --- a/tests/test_ss_benefit_integration.py +++ b/tests/test_ss_benefit_integration.py @@ -43,7 +43,7 @@ def test_earnings_test_reduces_benefit_before_fra() -> None: n_periods = 45 ssdi_pia_val = social_security.ssdi_pia( aime=jnp.array(3000.0), - period=12, + period=jnp.int32(12), di_dropout_scale=jnp.ones(n_periods + 1), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, @@ -52,36 +52,36 @@ def test_earnings_test_reduces_benefit_before_fra() -> None: benefit_working = social_security.benefit_choose_pre65( pia=pia_val, ssdi_pia=ssdi_pia_val, - age=63, - period=0, + age=jnp.int32(63), + period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), health=jnp.array(2), labor_supply=jnp.array(LaborSupply.h2000), labor_income=jnp.array(30000.0), early_ret_adjustment=jnp.array([0.75]), - normal_retirement_age=66, + normal_retirement_age=jnp.int32(66), earnings_test_threshold=jnp.array([17640.0]), earnings_test_fraction=jnp.array([0.5]), - earnings_test_repealed_age=66, + earnings_test_repealed_age=jnp.int32(66), ssdi_substantial_gainful_activity=13560.0, ) benefit_not_working = social_security.benefit_choose_pre65( pia=pia_val, ssdi_pia=ssdi_pia_val, - age=63, - period=0, + age=jnp.int32(63), + period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), health=jnp.array(2), labor_supply=jnp.array(LaborSupply.do_not_work), labor_income=jnp.array(0.0), early_ret_adjustment=jnp.array([0.75]), - normal_retirement_age=66, + normal_retirement_age=jnp.int32(66), earnings_test_threshold=jnp.array([17640.0]), earnings_test_fraction=jnp.array([0.5]), - earnings_test_repealed_age=66, + earnings_test_repealed_age=jnp.int32(66), ssdi_substantial_gainful_activity=13560.0, ) @@ -99,32 +99,32 @@ def test_earnings_test_not_applied_after_fra() -> None: benefit_post65 = social_security.benefit_choose_post65( pia=pia_val, - age=67, - period=0, + age=jnp.int32(67), + period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), labor_supply=jnp.array(LaborSupply.h2000), labor_income=jnp.array(50000.0), early_ret_adjustment=jnp.array([1.0]), - normal_retirement_age=66, + normal_retirement_age=jnp.int32(66), earnings_test_threshold=jnp.array([17640.0]), earnings_test_fraction=jnp.array([0.5]), - earnings_test_repealed_age=66, + earnings_test_repealed_age=jnp.int32(66), ) benefit_not_working = social_security.benefit_choose_post65( pia=pia_val, - age=67, - period=0, + age=jnp.int32(67), + period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), labor_supply=jnp.array(LaborSupply.do_not_work), labor_income=jnp.array(0.0), early_ret_adjustment=jnp.array([1.0]), - normal_retirement_age=66, + normal_retirement_age=jnp.int32(66), earnings_test_threshold=jnp.array([17640.0]), earnings_test_fraction=jnp.array([0.5]), - earnings_test_repealed_age=66, + earnings_test_repealed_age=jnp.int32(66), ) assert jnp.isclose(benefit_post65, benefit_not_working, atol=ATOL) From 987e86ec8821a9fe52015b1e8cf8a9ae7c26f66d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 11 May 2026 09:01:15 +0200 Subject: [PATCH 46/54] Phase 1: docstring + naming cleanups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - `next_assets_terminal` → `next_assets_when_dead` (definition + register site in `_build_per_target_next_assets`). - `consumption_unequiv_floor`, `transfers`, `consumption_equiv`, `utility`: drop stale formula / cross-reference paragraphs. - `borrowing_constraint` + tests: HRS phrasing → "large negative values of `assets`"; drop `pension_assets_adjustment` paragraph. - `inject_consumption_unequiv_points`: walk the canonical `consumption_unequiv` action on each non-terminal regime, raise if it is missing or not a runtime-points IrregSpacedGrid. - `_compute_consumption_unequiv_points` docstring: drop the jax-arithmetic paragraph. - `baseline/model.py`: replace the verbose `target_his` block with one line. - Pre-existing PLC0415 in-function imports moved to module-level. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/agent/assets_and_income.py | 34 ++++--------- src/aca_model/agent/preferences.py | 19 ++----- src/aca_model/baseline/model.py | 6 +-- src/aca_model/baseline/regimes/_common.py | 4 +- src/aca_model/consumption_unequiv_grid.py | 50 ++++++++++++------- tests/test_benchmark.py | 7 +-- tests/test_consumption_unequiv_grid.py | 7 ++- .../test_initial_conditions_extreme_assets.py | 16 +++--- tests/test_model_creation.py | 2 - tests/test_social_security.py | 8 ++- 10 files changed, 67 insertions(+), 86 deletions(-) diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index 8adfbd5..844369c 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -39,15 +39,7 @@ def consumption_unequiv_floor( consumption_equiv_floor: float, equivalence_scale: FloatND, ) -> FloatND: - """Per-household $-floor on consumption. - - Lifts the per-equivalent floor parameter to the household-$ level - by scaling with `equivalence_scale`. Singles keep - `consumption_equiv_floor`, married households face - `consumption_equiv_floor * 2 ** exponent` — the same two values - that get pinned exactly on the runtime consumption_unequiv grid - (see `aca_model.consumption_unequiv_grid`). - """ + """Per-household $-floor on consumption.""" return consumption_equiv_floor * equivalence_scale @@ -55,10 +47,7 @@ def transfers( cash_on_hand: FloatND, consumption_unequiv_floor: FloatND, ) -> FloatND: - """Government transfers to enforce the consumption floor. - - tr = max{0, consumption_unequiv_floor - cash_on_hand} - """ + """Government transfers to enforce the consumption floor.""" return jnp.maximum(0.0, consumption_unequiv_floor - cash_on_hand) @@ -75,11 +64,15 @@ def next_assets( consumption choice does not condition on the HCC shock realization. """ return ( - cash_on_hand + transfers + pension_assets_adjustment - consumption_unequiv - oop_costs + cash_on_hand + + transfers + + pension_assets_adjustment + - consumption_unequiv + - oop_costs ) -def next_assets_terminal( +def next_assets_when_dead( cash_on_hand: FloatND, transfers: FloatND, consumption_unequiv: ContinuousAction, @@ -109,14 +102,7 @@ def borrowing_constraint( `cash_on_hand + transfers == max(cash_on_hand, floor)`; the `max` form is preferred because the additive form rounds to `floor + ε` (with `|ε| ~ ULP(|cash_on_hand|)`) at extreme cash, which flips - the kink-boundary comparison for HRS-bottom-coded subjects at - `assets=-$1{,}000{,}000`. The `max` form returns `floor` exactly. - - `pension_assets_adjustment` is excluded from the constraint: it can - be negative (e.g. when the imputation overstates next-period pension - wealth at a cross-HIS transition), and including it here can leave - no feasible action at low-asset / mid-AIME corners. The correction - enters `next_assets` instead — a post-decision shift that does not - gate the current consumption choice. + the kink-boundary comparison at large negative values of `assets`. + The `max` form returns `floor` exactly. """ return consumption_unequiv <= jnp.maximum(cash_on_hand, consumption_unequiv_floor) diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 212802b..ce77014 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -125,7 +125,7 @@ def consumption_equiv( consumption_unequiv: ContinuousAction, equivalence_scale: FloatND, ) -> FloatND: - """Per-equivalent consumption: total $ divided by household equivalence scale.""" + """Utility-equivalized consumption.""" return consumption_unequiv / equivalence_scale @@ -137,22 +137,11 @@ def utility( coefficients_rra: FloatND, utility_scale_factor: FloatND, ) -> FloatND: - """Within-period utility: CES aggregator over consumption and leisure. - - u = utility_scale_factor * - (consumption_equiv^consumption_weight * leisure^(1 - consumption_weight))^(1 - coefficient_rra) - / (1 - coefficient_rra) - with log case for coefficient_rra=1. `consumption_weights` and - `coefficients_rra` are pref-type-indexed Series sourced directly - from params; `utility_scale_factor` is a regime-function output - (already a per-cell scalar — must NOT be re-indexed by pref_type, - see `aca_model.agent.preferences.utility_scale_factor` for why). - """ + """Within-period utility: CES aggregator over consumption and leisure.""" consumption_weight = consumption_weights[pref_type] coefficient_rra = coefficients_rra[pref_type] - composite = ( - consumption_equiv**consumption_weight - * leisure ** (1.0 - consumption_weight) + composite = consumption_equiv**consumption_weight * leisure ** ( + 1.0 - consumption_weight ) one_minus_rra = jnp.where( diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index f1f822f..a278bce 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -70,11 +70,7 @@ def create_model( pref_type_grid=pref_type_grid, ) - # `target_his` is a DAG output of `health_insurance.target_his` (set on - # nongroup/tied/retiree regimes). The pension imputation correction - # (`imputed_pension_wealth_next_period`) indexes shifted arrays by - # `arr[period, target_his]`; pylcm needs the categorical declared so - # `pd.Series` fixed_params with a `target_his` index level resolve. + # `target_his` is a state subsumed into regimes. base_derived: dict[str, DiscreteGrid] = { "target_his": DiscreteGrid(HealthInsuranceState), } diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index f5174e6..16b04bd 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -666,7 +666,7 @@ def _build_per_target_next_assets( ) -> dict[RegimeName, Callable[..., FloatND]]: """Build per-target assets transitions. - The `dead` target uses `next_assets_terminal` (no + The `dead` target uses `next_assets_when_dead` (no `pension_assets_adjustment`), so the dead per-target DAG does not pull in the `next_aime`-dependent imputation chain — `dead` has no `aime` state and pylcm cannot resolve `next_aime` there. Non-dead @@ -687,7 +687,7 @@ def _build_per_target_next_assets( continue result[target_name] = assets_and_income.next_assets - result["dead"] = assets_and_income.next_assets_terminal + result["dead"] = assets_and_income.next_assets_when_dead return result diff --git a/src/aca_model/consumption_unequiv_grid.py b/src/aca_model/consumption_unequiv_grid.py index 4fcd225..ba1b74c 100644 --- a/src/aca_model/consumption_unequiv_grid.py +++ b/src/aca_model/consumption_unequiv_grid.py @@ -36,33 +36,49 @@ def inject_consumption_unequiv_points( ) -> dict[str, Any]: """Inject consumption_unequiv gridpoints into per-regime params. - Walks every regime, finds the action whose grid is an - `IrregSpacedGrid` with runtime-supplied points, and writes - `params[regime_name]["consumption_unequiv"] = {"points": }`. + Walks every regime, reads its `consumption_unequiv` action grid, + and writes `params[regime_name]["consumption_unequiv"] = {"points": }`. The lower two gridpoints are the single and married unequiv - transfer floors (`consumption_equiv_floor` and - `consumption_equiv_floor * 2 ** exponent`); the rest are - geomspaced from the married floor up to `MAX_CONSUMPTION_UNEQUIV`. + transfer floors; the rest are geomspaced from the married floor up + to `MAX_CONSUMPTION_UNEQUIV`. Args: params: Existing params mapping with `consumption_equiv_floor` (per-equivalent floor, varies per iteration). Returned as a new dict; the input is not mutated. - model: Model whose regime specs determine which regimes need points - and whose `fixed_params["exponent"]` sets the married + model: Model whose regimes carry the runtime-points grid and + whose `fixed_params["exponent"]` sets the married equivalence-scale exponent. Returns: New params dict with consumption_unequiv points injected. + + Raises: + ValueError: If a regime is missing the `consumption_unequiv` + action, or its grid is not an `IrregSpacedGrid` with + `pass_points_at_runtime=True`. """ consumption_equiv_floor = jnp.asarray(params["consumption_equiv_floor"]) exponent = jnp.asarray(model.fixed_params["exponent"]) out: dict[str, Any] = dict(params) for regime_name, regime in model.regimes.items(): + if regime.terminal: + continue grid = regime.actions.get("consumption_unequiv") + if grid is None: + msg = ( + f"Regime {regime_name!r} is missing the `consumption_unequiv` " + f"action — the runtime-points grid must be on every regime." + ) + raise ValueError(msg) if not (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime): - continue + msg = ( + f"Regime {regime_name!r} has a `consumption_unequiv` action " + f"whose grid is not an `IrregSpacedGrid(pass_points_at_runtime=True)`; " + f"got {type(grid).__name__}." + ) + raise ValueError(msg) # Runtime-points grids always have `n_points` set (the constructor # rejects the (points=None, n_points=None) combo); narrow for ty. assert grid.n_points is not None @@ -86,18 +102,14 @@ def _compute_consumption_unequiv_points( """Return log-spaced consumption_unequiv gridpoints with both floors pinned. Single and married households face different unequiv (in-$) floors - (`consumption_equiv_floor` and `consumption_equiv_floor * - 2 ** exponent` respectively). Both must land exactly on the action - grid so the borrowing constraint's `max(cash_on_hand, floor)` kink - boundary is a feasible action; otherwise sub-ULP drift can flip - the `<=` comparison for subjects with very negative cash. The - geomspace tail starts at the married floor and runs to + (`consumption_equiv_floor` and the married-scaled twin + respectively). Both must land exactly on the action grid so the + borrowing constraint's `max(cash_on_hand, floor)` kink boundary is + a feasible action; otherwise sub-ULP drift can flip the `<=` + comparison for subjects with very negative cash. The geomspace + tail starts at the married floor and runs to `MAX_CONSUMPTION_UNEQUIV` so the two pinned points stay strictly increasing. - - All arithmetic stays in jax — multiplying `consumption_equiv_floor` - by `2 ** exponent` in jnp keeps both pinned floors at the canonical - float dtype the model uses everywhere else. """ married_unequiv_floor = consumption_equiv_floor * jnp.asarray(2.0) ** exponent tail = jnp.geomspace( diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index b626b50..649442e 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -1,5 +1,6 @@ """Integration test: the benchmark-sized baseline solves + simulates end-to-end.""" +import numpy as np import pytest from lcm import DiscreteGrid @@ -53,8 +54,6 @@ def test_benchmark_simulate_obeys_borrowing_constraint() -> None: sub-ULP at extreme `|cash_on_hand|`, so the post-hoc check would also flip on the same kink. """ - import numpy as np - n_subjects = 4 model = create_benchmark_model( n_subjects=n_subjects, @@ -73,9 +72,7 @@ def test_benchmark_simulate_obeys_borrowing_constraint() -> None: check_initial_conditions=False, ) - df = result.to_dataframe( - additional_targets=["cash_on_hand", "equivalence_scale"] - ) + df = result.to_dataframe(additional_targets=["cash_on_hand", "equivalence_scale"]) alive = df.loc[df["regime"] != "dead"].copy() consumption_unequiv_floor = float(params["consumption_unequiv_floor"]) floor = consumption_unequiv_floor * alive["equivalence_scale"].to_numpy() diff --git a/tests/test_consumption_unequiv_grid.py b/tests/test_consumption_unequiv_grid.py index 89e8bb9..92593b4 100644 --- a/tests/test_consumption_unequiv_grid.py +++ b/tests/test_consumption_unequiv_grid.py @@ -3,8 +3,8 @@ The borrowing constraint in `agent.assets_and_income.borrowing_constraint` compares the lowest consumption_unequiv action against `max(cash_on_hand, consumption_unequiv_floor)`. For subjects with cash -below the floor (HRS bottom-coded `assets=-$1{,}000{,}000$`, -moderate-negative-asset retirees, etc.) this RHS collapses to exactly +below the floor (large-negative-asset subjects, moderate-negative-asset +retirees, etc.) this RHS collapses to exactly `consumption_unequiv_floor`. The constraint is feasible iff the relevant household-floor gridpoint is `<= consumption_unequiv_floor`. @@ -30,6 +30,7 @@ import jax.numpy as jnp import pytest +from aca_model.baseline.regimes._common import MAX_CONSUMPTION_UNEQUIV from aca_model.consumption_unequiv_grid import _compute_consumption_unequiv_points EXPONENT = 0.7 # production value (env_constants["exponent"]) @@ -77,8 +78,6 @@ def test_compute_consumption_unequiv_points_strictly_increasing() -> None: def test_compute_consumption_unequiv_points_last_equals_max() -> None: """The final point is the configured upper bound.""" - from aca_model.baseline.regimes._common import MAX_CONSUMPTION_UNEQUIV - pts = _compute_consumption_unequiv_points( consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), exponent=jnp.asarray(EXPONENT), diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index d121539..7e583e3 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -72,11 +72,11 @@ def test_borrowing_constraint_rejects_consumption_unequiv_above_post_transfer_re def test_borrowing_constraint_admits_floor_at_million_dollar_negative_cash() -> None: """The kink-boundary check survives sub-ULP rounding at `|cash_on_hand| ~ 1e6`. - Reproduces the production failure mode at `assets=-$1{,}000{,}000$` (HRS - bottom-code): the algebraically equivalent `cash_on_hand + transfers` - form rounds to `floor - 5.7e-11` at fp64, flipping `consumption_unequiv <= ...` - for the lowest consumption_unequiv gridpoint. The `max(cash_on_hand, floor)` - form returns `floor` exactly. + At large negative `assets`, the algebraically equivalent + `cash_on_hand + transfers` form rounds to `floor - 5.7e-11` at fp64, + flipping `consumption_unequiv <= ...` for the lowest + consumption_unequiv gridpoint. The `max(cash_on_hand, floor)` form + returns `floor` exactly. """ consumption_unequiv_floor = jnp.asarray(1597.0921419521899) # production value cash_on_hand = jnp.asarray(-1_000_000.0) @@ -95,9 +95,9 @@ def test_borrowing_constraint_admits_floor_at_million_dollar_negative_cash() -> def test_extreme_negative_assets_subject_passes_validation() -> None: """A subject placed at `assets = -1_000_000` clears initial-conditions validation. - HRS bottom-codes very-large-negative net wealth at exactly $-1{,}000{,}000$. - Such subjects should remain in the simulated population: the consumption - floor / transfer system absorbs them, with `c = c_floor` always feasible. + A large-but-reasonable negative value (very bad draws for both HCC shocks) + should remain in the simulated population: the consumption floor / + transfer system absorbs them, with `c = c_floor` always feasible. """ n_subjects = 1 model = create_benchmark_model( diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index 4e31881..042e8a3 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -268,5 +268,3 @@ def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" model = make_baseline_model(n_subjects=1) assert len(model.regimes) == 19 - - diff --git a/tests/test_social_security.py b/tests/test_social_security.py index 835b554..b8ac44a 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -391,12 +391,16 @@ def test_benefit_inelig_pre65_not_disabled() -> None: def test_di_dropout_round_trip_zero_years() -> None: aime = jnp.array(10000.0) scaled = aime * DI_SCALE[52] - round_tripped = social_security.adjust_aime_di_dropout_inv(jnp.int32(52), scaled, DI_SCALE) + round_tripped = social_security.adjust_aime_di_dropout_inv( + jnp.int32(52), scaled, DI_SCALE + ) assert jnp.isclose(aime, round_tripped, atol=ATOL) def test_di_dropout_round_trip_positive_years() -> None: aime = jnp.array(10000.0) scaled = aime * DI_SCALE[62] - round_tripped = social_security.adjust_aime_di_dropout_inv(jnp.int32(62), scaled, DI_SCALE) + round_tripped = social_security.adjust_aime_di_dropout_inv( + jnp.int32(62), scaled, DI_SCALE + ) assert jnp.isclose(aime, round_tripped, rtol=0.0002) From b36375b422ee09c7b24f554cc9959f7965a7be3a Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 11 May 2026 09:09:08 +0200 Subject: [PATCH 47/54] =?UTF-8?q?Phase=202:=20collapse=20scale=5Freference?= =?UTF-8?q?=5Fage=20into=20reference=5Fage;=20rename=20scale=5Freference?= =?UTF-8?q?=5Fhours=20=E2=86=92=20reference=5Fhours?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `utility_scale_factor` is a multiplicative constant — shifting the calibration target age rescales every utility uniformly, so the distinction between `scale_reference_age` and `reference_age` is unobservable. Drop the redundant parameter and the `fixed_cost_of_work_age_trend * age_offset` term it gated. Rename `scale_reference_hours` → `reference_hours` for consistency with `reference_age`. Regenerated benchmark snapshot and updated test_preferences.py regression values to reflect the new `average_leisure = 4500` (was 4000 under the old `age_offset=10` shift). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../_benchmark_data/benchmark_params.pkl | Bin 65428 -> 65398 bytes src/aca_model/agent/preferences.py | 12 +--- tests/test_preferences.py | 52 +++++------------- 3 files changed, 16 insertions(+), 48 deletions(-) diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index 29e65625274786cade19dce6e926e95d09500617..5297faa4ccb61966e2b4f91eeef3c4d1b206db92 100644 GIT binary patch delta 11069 zcmd5?33yXg7Vb+*7obgO(=F*jQz!{g3Mw*+(v*s51J%LL8OM)OS{nL#YZuZ4T!6-r zrDe)ePJmAtSz5#vqq=e^d^<=fGaBX^jE1r~hCH_1Y^=0S)m7;N=nt+0 zU1S>ISVGjqrkbqSBNMCUZrXBfzm0~4y`Et;Sr=K^=9)^Ab%wPztfDwSFTc3dkXK|Z zF0`@@UHpAAgE5dD z;hDkH5n7aAAK@lnCmKG*tNfLU@i_dS3b)~Z77d^2b$E01T?5bL=UUmRn)NILE#b4c z?3f54knyrI5wA+GypF~;pmCHcV{ma-#I-6NiyK~G%%4+YFjeN6D=G{|y25;?rOH;O z86Hu^I)_N7@=B%YG|Lg=HSXHNi9~;638m5o6RCo=+=bfaJx{l zK}h&UK-eNAj7m&K?$Z;g258lR)T$8##UR~tEjqgy!=^;FPvM^>R;G%@YS@u%nm%Bu zl}%r6)#w?$R4-X(Wt+g1bPvjFPD+St@Y?5V+J{RVTE$2POJGY=Y<8q1AfP4}XB&+# zB!(8TSQt#!0t6$srgFKpfYlKdwwagVmdnRhq-fYe5`=9oGfo9YeLph^GSulz0+g#q zOVkqhW8;WE5Hv+3C>sSSmE=C!Xe9U3hhW!{N!-koH4;gw1j!j!^n&1q&UygykDx+8 XP;-$1Z#!&9vCfnY5dwOe<%}*u^ldaA1sP zP$?z6L5mjd>X-wAt4#PleI0|#@(A@z)a;|lYHXyIREJN*n2ly*DRfG*5gQOy%EgAN zi?W5MP}Xjgw|Pp7gJK{b;45W2RWq5p`@W4-H%mqgw`kOQe6?visujBk)H0A0*pNd` z@*`~u*ry?Zt8oad#A;whA|+ft3-DdQXz;ZkJlfnPTM%FE3yEmC^u3z z&M7uZA+yN7oeu$%Gs?FX36PgJgjopl@>0{5A|u;2uf`WrLyHwyP`9_t=I}gg+6`2O zKYefDm~F_@1)?`g6Crg1hx7qbM=2X8QkN-G%H)Hz(MUd|58+Negbg|7f?5U4TdaWS zoaoyL$Tyzo8ITEqeF3B_in4J6c@OA{RB>K$gln_zxSg;Yo8L3+3i}4TnIh~gey}qY zr8Cr9C?HoJ47y<<+@T`EFx#NbsTRK)JIUv@W3C-^p&MCLgtEsL{xaIAz%YihMkbf^iW559Ffqd+o3G@Ad5ovdgo;B;rA;JUGOV8|C zd^dp63=n#hOn<->FP`MbN+Rv9&a-A+2WFO~XBOjAd^dn86allyA29XH`~qY5W4!=| ztE!&SqjTJ+=x%_rMFh&b{y_O`6}l)ce067UCp#m&=Qql_ljd0=!!#Hig5 zqwz2a8JVCnBl$p#@LgxCEXbcLFI$)Hdn?@95bu=@J=-r@I_=&E!^tA_rUdzMu;0FX zTZJYmTY3gwMc;tOiomP&2VT*3J%4jhJ8uroO9P>I2*J$(e%n3{qn1Y-HO!l^xhXE& z#}HwT#d+lr4+$0jO@5StY z7l{uHhR+*)gb>`_J6i2QnM55-(?7DI!9oY4h3A`n!MEtO`Ok- zm$#ZzT^DQa>*kFTNntRwOdcLlihOLB&I)`X^RR%Ieoe&auKnOCZuc7u!>oYUx%1i; zNYNkhcJ-b`2b*Z|0=MVQT|u)uU%l;Frd;fVeufTkOHj<=DU^#FPaT|{wkn(9YXLmmN4vR!ie4LY-YJXoj=L7xYYhnp(<=q|BXG zI-S;v*{AdYdb6Hcju~BZz8U}KRKQjc@A=UGeyUg9#MPdvgU1FcUJh&~9kk~78x~kj zPKws8{bp*`w)EdfAc}C2>dvW)bPF17T}cQfWgHG#`Ocn0Lr5sC>6+FY!QaPZpl}!& zWb0W!;60|HES(Idwag7Cmryn&3~Mzc97{bzWcVO*T2#Ww6Jzaa*6)p&@ZFvZozzT# zohKusbuYFYe{yfJzO$-TI_o=g>Cdj}N9to2W@pyeRoypsTF+IL2vtRWoe-I8RYrrU z*k~~16~fGuF}E4g)jZMxLiM5Kls}}^AsKs5kJHp0TL;OP6t@_1yUtyKybN}p?hWJx z{b$?`$l#h{otd|M)Y#$74-kkgOg`dZ!;X-HW&Q-xj!7cys!ELP`7;r z1J5}CtzCR^^wPR{og6Nn3v~}Davs0Ft+TVu)yb`2kS zEOaQ~EAklf^$mOE;q2}w@>q+93Y`*$4sNmU2A<*}@_2=(2z7Qv9t&Z`#l<~S-xhiN zjWTye9tUZSk31H3J@VM`U2iO`s{bs^?)s{6R_5CHvhuPbW4XB^uROnWPTm~0yu8wo zHzyy?e?QRfan`?7K6}ZH|9RRV>s8M_-tk0M=M92eIG5i9eBkt^>9JK;rW1aHIK!dt z1y15ETn(RJ>J7jJ{Rgo9^s9zzKWLds82aOAJCZ$dHMhnPNE=#%Gp-|*t{Tkz$pLdu z3<1N^{x$^E)rb0Pe|DTpUeOP6IC|LnAH}V(i#2WWGB?0sZ`&RI_FIU~zTc>xw+PsF zmnDc{#y=(Dz2(B?4r!AfKU7htFu~CFOWc%j{FcI&Q5k|SBmxWf50|2`&F1C>Ni_!B zTn;w@bl2?(iSe-gmu#MON*7o|+Es0zD0H7lp&V356?Z3XGm^XLLzp}m53*|yc1y!dpn9yRDDO;@UDmn drOB>o4U~`bFh50W{9*p2`(f?}IJ0@o{{d?p%hdn? delta 11242 zcmdT~3tUuH8s`pg2?v>xff*PWKpYfBEqiS#h%0N8nmw%b7vi8-W-c(`FeB9uR9Zzz zH0T?ut*lg1w9>Zg{&KbOm;F#MGj~@P(Jf)CUANj@ANg#T_B;39d0gfxVapCb&b{}X z^S%D(JKy>4H*>i!*!Dwk$jZ>>`rm}S7ZyIUa*@82E`;~PlnNPtZ>E>fX4C zs4Y>z{#mIK!v$!Mnn0aqFGoE|Q9V!*t+3C(=A7S0=F1J|3}(Z5gSEzHsKWmz*H$9+ z>F8{vjv6kaPL#$_K|v0BW+ddwRWeWHm!)@5-R$_7y;N|I?pDymjYRQOY@&3uh(R`F z#d>BiKQjyq`D7%2L*9@4F%fxzSMtWV+lHOi7uXE>+UE@vY#P3R z%~Hg=0+}c+7xA)q<@HYd7ODm2C>7KWC{5^8+E;scky*c}gtk-`S}Q7OGcLI>saV-b z?ShG_cn4XLiu7b9COUfbiKJ(2(&&k+QH{)6+iYApC9@tM&=r_xiEi#L$GhpTk~Hap z63mOfCU&??+~jLwv|5AOWUJ%7ZjA-pjFswCl@EPkO0oius~;aO){b-|cJq?&LLrUG zNeBYpbCC4Bo3etUqG6>bK`WO81=Saz`3k&Y;z^T$4Q;U%Avy*1m1}H8h8!jlKGeu! z!g0ju$)hE(zcm53P8u~8G^yE?8s?^ENzx<`vu*T1*v}BL*Ae#0gnLP+nVCW^F58Wn z%1%jZlt@exm|LQbDHV}EDS;Iwyk=|IpJjw`({}ZyKt;;K@Sb)twHbcjn+g@JA@Vno zkh6a$-Y~mJ^A0;N^R-~=Fnc-cZHhV!E#tLOCNjX1<^bC@{=h9Otav}GnZkH9hr}-8 z_i;=kb`n>54zGx*ACaflXW~Po@T;X2>O}V#6It&I)j|{%Zjg`HK6Lk()IGbTtuf~Zu^(M@n_5k;s zA$6dy4n}BGplh0Dq|?g0upsVZV-7=fo^m>|P8Vi}9AUr-6#2smq9KoVrx6j8pd{}; zZtMm0mul9$s1kp*aNOeJCCuy%Ve z%&p_Mda&g_p9Wa01B;ShtbXaj7WP9xH z0P#|X3<_RxgpC8QyF>iob#lQ_!YiJY>F){kqGa#n}|u_j1@7DP^}8 zGZgNIvPPx7QyL|05pjGAmY1k}IBb(KgWAUKF;=@XY7;T~h%nN4^?ck4Bml-iGQ`2L z5;?UE=D`?NTKb?n;}asrr~GJmrtA*iTS`q;!yHor7h!K8?p*}^+IaRvd7v9a2(%Yb zdG34?1={=o0xh9*%M!j7VG#xo)cIPigk}D)u*W^mY6fy1@VJCO5hyM~p!F3O4lhK*P+^!NJQY!fU-hyk>6I2?>9yBjI=C`V2|^T@rryZ0kf! zQ>5Fp)Ni47d%}G_NLQ+ePAaKf?N;aMu738qU>pyyj$V{Eki``;w%EbCn&9h{2d}KM|zxt$V3a?8dk3 zvB=Km9ub>de{5dZnc?zAjKvMMd*@n-gbv3yMopY#wT&?$L+-6{Dc&$6bX#gDAAkB|iLRC_mSw)xWF-a>0j)0LQ!27WlA3$<+fp;b>bL^%8rb z%@VfaHJs!`z)!VhkjxY+4_G2P0$;Ts9XyMVi7Y;W7ON5tkv=oiK`yQ@*PZO8j$cT` zegwb#*vuEnqU}C`tBO@jQIRZFTo_Bn=)|OmLR1>;>`ir6)yDZHaIHl_+mWZSu$lWw zJ=e9waL((Y{+Ys$c&MR`=SN)i zcCM-wI*(Kh+yqiU%99fT4(FXso{WZ4U69VIqt;-N)~dJS_fkcM3dU;@4LXnOqWf^& z^K;;#Ve*Y3JE5X$N*rI`!!-**rujL~{WU+YCHs4*J(Ie^F%-cKRkd;*OeA@l+7Y&X z&)!2Lm|@)5_QoUl`GkogPY$|UxS5+n->0^ivYBY|l>0*0Dl*Ot$ESKmimi@fV(`L@ zAdLkZfdM>vayX>_DPB-|^S;O=H@o;mfqOG2RD^2b}r=LU}$C*IHX#5>a)R?{<*;&Yp1 zdm16mZ8DO?X^OX$CPV9Mo@+(``m-d5`Q~gRw$*=c-|HgQ-WWMB91yH}>uP{7dBvS6%ZwQi9?SEXxD@a{&!=3wEe~zD zIE&%;gCmi;v6M@Ekn##U0%kcJ{r1K-@K1)YaF7y(8xksLQ>C%oRH(0nzx_i=SwQm7 zEFsPhUB_T9pzcCyi(gAH7MOK19`66w(Dx8`&6#r(-SKw%0{y z+{=Wkquf%eFDf?|8cmDJ%V?YPy~JfDxP{nwev)90%Qo$Aj<8KVnq#{O`*b{Kj*#&Z zzh|zAfb4I?p$mQZE2>v7kCH@0LajzVlT?1B`1=j6@)L_Uj0~Ipzm`&9A6JTTDFaHd z&l#)L@WCa%RB}74{q8=2{pbMMgY;77VCi+QNUu4lmnz{l(r0ESlZ$IX<90ZIX~vD~ zs(^KuV+Tv$5|O?&MBiku`r=&tju%3%X2RwxGhy6SEqrrD2?bX)N@)?j)Jj)c3M&o9 zB^LZGVScXMcga-30v%O diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index ce77014..2bdadec 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -175,10 +175,7 @@ def utility_scale_factor( coefficients_rra: FloatND, time_endowment: float, fixed_cost_of_work_intercept: float, - fixed_cost_of_work_age_trend: float, - scale_reference_hours: float, - reference_age: int, - scale_reference_age: int, + reference_hours: float, ) -> FloatND: """Compute the scale factor so utility is approximately 1 at typical values. @@ -191,12 +188,7 @@ def utility_scale_factor( """ consumption_weight = consumption_weights[pref_type] coefficient_rra = coefficients_rra[pref_type] - age_offset = scale_reference_age - reference_age - average_leisure = ( - time_endowment - - scale_reference_hours - - (fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * age_offset) - ) + average_leisure = time_endowment - reference_hours - fixed_cost_of_work_intercept u_cons = average_consumption_unequiv**consumption_weight u_leisure = average_leisure ** (1.0 - consumption_weight) diff --git a/tests/test_preferences.py b/tests/test_preferences.py index d67e0cd..ad4a9c5 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -12,14 +12,11 @@ TIME_DISCOUNT_FACTOR = 0.85 TIME_ENDOWMENT = 5000.0 FIXED_COST_INTERCEPT = 0.0 -FIXED_COST_AGE_TREND = 50.0 AVERAGE_CONSUMPTION = 10000.0 RATE_OF_RETURN = 0.01 BEQUEST_WEIGHT = 0.02 BEQUEST_SHIFTER = 500_000.0 -SCALE_REFERENCE_HOURS = 500.0 -REFERENCE_AGE = 50 -SCALE_REFERENCE_AGE = 60 +REFERENCE_HOURS = 500.0 # Pref-type-indexed params: three identical entries so pref_type=0 selects # the struct-ret scalar value used by the regression tests. @@ -39,12 +36,9 @@ def test_utility_scale_factor_crra() -> None: coefficients_rra=RRA_5_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) - assert jnp.isclose(result, 9_233_279_397_806_166.0, rtol=1e-6) + assert jnp.isclose(result, 1.114_807_837_680_009_4e16, rtol=1e-6) def test_utility_scale_factor_log() -> None: @@ -55,12 +49,9 @@ def test_utility_scale_factor_log() -> None: coefficients_rra=RRA_1_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) - assert jnp.isclose(result, 0.113_073_257_794_546_72, rtol=1e-6) + assert jnp.isclose(result, 0.112_474_080_852_230_33, rtol=1e-6) # --- scaled_bequest_weight --- @@ -113,10 +104,7 @@ def test_utility_log_regression() -> None: coefficients_rra=RRA_1_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) result = preferences.utility( consumption_equiv=jnp.array(50000.0), @@ -126,7 +114,7 @@ def test_utility_log_regression() -> None: coefficients_rra=RRA_1_BY_TYPE, utility_scale_factor=scale, ) - assert jnp.isclose(result, 1.005_046_313_660_588_5, rtol=1e-5) + assert jnp.isclose(result, 0.999_720_557_696_258_7, rtol=1e-5) def test_utility_crra_regression() -> None: @@ -137,10 +125,7 @@ def test_utility_crra_regression() -> None: coefficients_rra=RRA_5_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) result = preferences.utility( consumption_equiv=jnp.array(50000.0), @@ -150,7 +135,7 @@ def test_utility_crra_regression() -> None: coefficients_rra=RRA_5_BY_TYPE, utility_scale_factor=scale, ) - assert jnp.isclose(result, -0.836_511_642_073_019_1, rtol=1e-5) + assert jnp.isclose(result, -1.009_987_562_073_720_9, rtol=1e-5) def test_utility_married_equivalence() -> None: @@ -162,10 +147,7 @@ def test_utility_married_equivalence() -> None: coefficients_rra=RRA_5_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) single = preferences.utility( consumption_equiv=jnp.array(50000.0), @@ -197,10 +179,7 @@ def test_bequest_log_regression() -> None: coefficients_rra=RRA_1_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) bwt = preferences.scaled_bequest_weight( bequest_weight=BEQUEST_WEIGHT, @@ -219,7 +198,7 @@ def test_bequest_log_regression() -> None: coefficients_rra=RRA_1_BY_TYPE, utility_scale_factor=scale, ) - assert jnp.isclose(result, 86.539_249_963_643_88, rtol=1e-5) + assert jnp.isclose(result, 86.080_677_139_309_2, rtol=1e-5) def test_bequest_crra_regression() -> None: @@ -230,10 +209,7 @@ def test_bequest_crra_regression() -> None: coefficients_rra=RRA_5_BY_TYPE, time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) bwt = preferences.scaled_bequest_weight( bequest_weight=BEQUEST_WEIGHT, @@ -252,4 +228,4 @@ def test_bequest_crra_regression() -> None: coefficients_rra=RRA_5_BY_TYPE, utility_scale_factor=scale, ) - assert jnp.isclose(result, -37.932_748_117_035_63, rtol=1e-5) + assert jnp.isclose(result, -45.799_247_573_576_66, rtol=1e-5) From e879e05c760672c659acc88a0a228698047b9341 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 11 May 2026 09:13:05 +0200 Subject: [PATCH 48/54] Phase 3: merge agent/utility.py into preferences.py; adopt u_X naming MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - `preferences.utility` → `preferences.u_working_life` (CRRA + leisure aggregator; `leisure` stays a DAG input). - `utility.retired` → `preferences.u_retired` (forcedout — computes `leisure_retired` inline). - `utility.dead` → `preferences.u_dead` (terminal bequest wrapper around `preferences.bequest`). - Delete `agent/utility.py`; update `select_utility` dispatch. The `u_*` prefix disambiguates regime-utility entry points from auxiliary scalars (`utility_scale_factor`, `discount_factor`, `scaled_bequest_weight`). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/agent/preferences.py | 51 +++++++++++++++++- src/aca_model/agent/utility.py | 64 ----------------------- src/aca_model/baseline/regimes/_common.py | 7 ++- tests/test_model_components.py | 4 +- tests/test_preferences.py | 8 +-- 5 files changed, 58 insertions(+), 76 deletions(-) delete mode 100644 src/aca_model/agent/utility.py diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 2bdadec..d630588 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -129,7 +129,7 @@ def consumption_equiv( return consumption_unequiv / equivalence_scale -def utility( +def u_working_life( consumption_equiv: FloatND, leisure: FloatND, pref_type: DiscreteState, @@ -137,7 +137,7 @@ def utility( coefficients_rra: FloatND, utility_scale_factor: FloatND, ) -> FloatND: - """Within-period utility: CES aggregator over consumption and leisure.""" + """Within-period utility for canwork regimes: CES over consumption and leisure.""" consumption_weight = consumption_weights[pref_type] coefficient_rra = coefficients_rra[pref_type] composite = consumption_equiv**consumption_weight * leisure ** ( @@ -155,6 +155,53 @@ def utility( return u * utility_scale_factor +def u_retired( + consumption_equiv: FloatND, + good_health: IntND, + pref_type: DiscreteState, + consumption_weights: FloatND, + coefficients_rra: FloatND, + utility_scale_factor: FloatND, + time_endowment: float, + leisure_cost_of_bad_health: float, +) -> FloatND: + """Within-period utility for forcedout regimes (no work, retired leisure).""" + leisure = leisure_retired( + good_health=good_health, + time_endowment=time_endowment, + leisure_cost_of_bad_health=leisure_cost_of_bad_health, + ) + return u_working_life( + consumption_equiv=consumption_equiv, + leisure=leisure, + pref_type=pref_type, + consumption_weights=consumption_weights, + coefficients_rra=coefficients_rra, + utility_scale_factor=utility_scale_factor, + ) + + +def u_dead( + assets: ContinuousState, + pref_type: DiscreteState, + bequest_shifter: float, + scaled_bequest_weight: float, + consumption_weights: FloatND, + coefficients_rra: FloatND, + utility_scale_factor: FloatND, +) -> FloatND: + """Terminal bequest utility for the dead regime.""" + return bequest( + assets=assets, + pref_type=pref_type, + bequest_shifter=bequest_shifter, + scaled_bequest_weight=scaled_bequest_weight, + consumption_weights=consumption_weights, + coefficients_rra=coefficients_rra, + utility_scale_factor=utility_scale_factor, + ) + + def discount_factor( pref_type: DiscreteState, discount_factor_by_type: FloatND, diff --git a/src/aca_model/agent/utility.py b/src/aca_model/agent/utility.py deleted file mode 100644 index d7817a1..0000000 --- a/src/aca_model/agent/utility.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Utility function variants for different regime types. - -- retired: forcedout regimes (no work, computes leisure_retired internally) -- dead: terminal bequest - -Canwork regimes use `preferences.utility` directly, with `leisure` computed -as a separate DAG function (`preferences.leisure` / `preferences.leisure_tied`). -""" - -from lcm.typing import ( - ContinuousState, - DiscreteState, - FloatND, - IntND, -) - -from aca_model.agent import preferences - - -def retired( - consumption_equiv: FloatND, - good_health: IntND, - pref_type: DiscreteState, - consumption_weights: FloatND, - coefficients_rra: FloatND, - utility_scale_factor: FloatND, - time_endowment: float, - leisure_cost_of_bad_health: float, -) -> FloatND: - """Utility for forcedout regimes (no work).""" - leisure = preferences.leisure_retired( - good_health=good_health, - time_endowment=time_endowment, - leisure_cost_of_bad_health=leisure_cost_of_bad_health, - ) - return preferences.utility( - consumption_equiv=consumption_equiv, - leisure=leisure, - pref_type=pref_type, - consumption_weights=consumption_weights, - coefficients_rra=coefficients_rra, - utility_scale_factor=utility_scale_factor, - ) - - -def dead( - assets: ContinuousState, - pref_type: DiscreteState, - bequest_shifter: float, - scaled_bequest_weight: float, - consumption_weights: FloatND, - coefficients_rra: FloatND, - utility_scale_factor: FloatND, -) -> FloatND: - """Terminal bequest utility for dead regime.""" - return preferences.bequest( - assets=assets, - pref_type=pref_type, - bequest_shifter=bequest_shifter, - scaled_bequest_weight=scaled_bequest_weight, - consumption_weights=consumption_weights, - coefficients_rra=coefficients_rra, - utility_scale_factor=utility_scale_factor, - ) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 16b04bd..718c37f 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -29,7 +29,6 @@ health, labor_market, preferences, - utility, ) from aca_model.agent.health import Health, HealthWithDisability from aca_model.agent.labor_market import LaborSupply, LaggedLaborSupply, SpousalIncome @@ -444,7 +443,7 @@ def build_dead_regime(grids: Grids) -> Regime: return Regime( transition=None, functions={ - "utility": utility.dead, + "utility": preferences.u_dead, "utility_scale_factor": preferences.utility_scale_factor, }, states={ @@ -471,8 +470,8 @@ def select_ss_benefit(spec: dict[str, str]) -> Callable[..., Any]: def select_utility(spec: dict[str, str]) -> Callable[..., Any]: """Select the utility function for a regime.""" if spec["canwork"] != "canwork": - return utility.retired - return preferences.utility + return preferences.u_retired + return preferences.u_working_life def _select_leisure(spec: dict[str, str]) -> Callable[..., Any]: diff --git a/tests/test_model_components.py b/tests/test_model_components.py index 1375a8a..79ec540 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -77,7 +77,7 @@ def test_leisure_bad_health() -> None: def test_utility_positive_leisure() -> None: - result = preferences.utility( + result = preferences.u_working_life( consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), pref_type=jnp.array(0), @@ -89,7 +89,7 @@ def test_utility_positive_leisure() -> None: def test_utility_log_case() -> None: - result = preferences.utility( + result = preferences.u_working_life( consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), pref_type=jnp.array(0), diff --git a/tests/test_preferences.py b/tests/test_preferences.py index ad4a9c5..e412a16 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -106,7 +106,7 @@ def test_utility_log_regression() -> None: fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - result = preferences.utility( + result = preferences.u_working_life( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), pref_type=jnp.array(0), @@ -127,7 +127,7 @@ def test_utility_crra_regression() -> None: fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - result = preferences.utility( + result = preferences.u_working_life( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), pref_type=jnp.array(0), @@ -149,7 +149,7 @@ def test_utility_married_equivalence() -> None: fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - single = preferences.utility( + single = preferences.u_working_life( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), pref_type=jnp.array(0), @@ -157,7 +157,7 @@ def test_utility_married_equivalence() -> None: coefficients_rra=RRA_5_BY_TYPE, utility_scale_factor=scale, ) - married = preferences.utility( + married = preferences.u_working_life( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), pref_type=jnp.array(0), From f6ed413055312159c5b5045c55d280b6975806be Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 11 May 2026 09:19:46 +0200 Subject: [PATCH 49/54] Phase 4: DAG functions for pref-type indexing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `consumption_weight(consumption_weights, pref_type)` and `coefficient_rra(coefficients_rra, pref_type)` mirroring `discount_factor`. Register both on every regime in `build_common_functions` and on the dead regime. Drop `pref_type`, `consumption_weights`, `coefficients_rra` from `u_working_life`, `u_retired`, `u_dead`, `bequest`, `utility_scale_factor`: each now takes the scalar `consumption_weight` / `coefficient_rra` as a DAG-resolved input. `scaled_bequest_weight` already consumed scalars — no change. aca-data / aca-estimation param shapes are unchanged: pref-type- indexed Series for `consumption_weights` / `coefficients_rra` still flow through fixed_params; the new DAG functions resolve them per-cell at runtime. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/agent/preferences.py | 79 +++++++++++------------ src/aca_model/baseline/regimes/_common.py | 17 +++-- tests/test_model_components.py | 20 +++--- tests/test_preferences.py | 71 ++++++++------------ 4 files changed, 84 insertions(+), 103 deletions(-) diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index d630588..e7a356d 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -132,14 +132,11 @@ def consumption_equiv( def u_working_life( consumption_equiv: FloatND, leisure: FloatND, - pref_type: DiscreteState, - consumption_weights: FloatND, - coefficients_rra: FloatND, + consumption_weight: FloatND, + coefficient_rra: FloatND, utility_scale_factor: FloatND, ) -> FloatND: """Within-period utility for canwork regimes: CES over consumption and leisure.""" - consumption_weight = consumption_weights[pref_type] - coefficient_rra = coefficients_rra[pref_type] composite = consumption_equiv**consumption_weight * leisure ** ( 1.0 - consumption_weight ) @@ -158,9 +155,8 @@ def u_working_life( def u_retired( consumption_equiv: FloatND, good_health: IntND, - pref_type: DiscreteState, - consumption_weights: FloatND, - coefficients_rra: FloatND, + consumption_weight: FloatND, + coefficient_rra: FloatND, utility_scale_factor: FloatND, time_endowment: float, leisure_cost_of_bad_health: float, @@ -174,34 +170,55 @@ def u_retired( return u_working_life( consumption_equiv=consumption_equiv, leisure=leisure, - pref_type=pref_type, - consumption_weights=consumption_weights, - coefficients_rra=coefficients_rra, + consumption_weight=consumption_weight, + coefficient_rra=coefficient_rra, utility_scale_factor=utility_scale_factor, ) def u_dead( assets: ContinuousState, - pref_type: DiscreteState, bequest_shifter: float, scaled_bequest_weight: float, - consumption_weights: FloatND, - coefficients_rra: FloatND, + consumption_weight: FloatND, + coefficient_rra: FloatND, utility_scale_factor: FloatND, ) -> FloatND: """Terminal bequest utility for the dead regime.""" return bequest( assets=assets, - pref_type=pref_type, bequest_shifter=bequest_shifter, scaled_bequest_weight=scaled_bequest_weight, - consumption_weights=consumption_weights, - coefficients_rra=coefficients_rra, + consumption_weight=consumption_weight, + coefficient_rra=coefficient_rra, utility_scale_factor=utility_scale_factor, ) +def consumption_weight( + consumption_weights: FloatND, + pref_type: DiscreteState, +) -> FloatND: + """Per-type consumption weight indexed by preference type. + + Wired as a DAG function so pylcm broadcasts the scalar to every cell; + mirrors `discount_factor`. + """ + return consumption_weights[pref_type] + + +def coefficient_rra( + coefficients_rra: FloatND, + pref_type: DiscreteState, +) -> FloatND: + """Per-type CRRA coefficient indexed by preference type. + + Wired as a DAG function so pylcm broadcasts the scalar to every cell; + mirrors `discount_factor`. + """ + return coefficients_rra[pref_type] + + def discount_factor( pref_type: DiscreteState, discount_factor_by_type: FloatND, @@ -216,25 +233,14 @@ def discount_factor( def utility_scale_factor( - pref_type: DiscreteState, average_consumption_unequiv: float, - consumption_weights: FloatND, - coefficients_rra: FloatND, + consumption_weight: FloatND, + coefficient_rra: FloatND, time_endowment: float, fixed_cost_of_work_intercept: float, reference_hours: float, ) -> FloatND: - """Compute the scale factor so utility is approximately 1 at typical values. - - Returns the scalar for the cell's `pref_type`. Mirrors the `discount_factor` - pattern: take the state as input, return a per-cell scalar. Registering this - as a regime function and then doing `utility_scale_factor[pref_type]` in a - downstream consumer is invalid — pylcm broadcasts function outputs to - per-cell scalars before consumption, and the validator in - `lcm.regime_building.validation` raises on that clash. - """ - consumption_weight = consumption_weights[pref_type] - coefficient_rra = coefficients_rra[pref_type] + """Compute the scale factor so utility is approximately 1 at typical values.""" average_leisure = time_endowment - reference_hours - fixed_cost_of_work_intercept u_cons = average_consumption_unequiv**consumption_weight u_leisure = average_leisure ** (1.0 - consumption_weight) @@ -277,11 +283,10 @@ def scaled_bequest_weight( def bequest( assets: ContinuousState, - pref_type: DiscreteState, bequest_shifter: float, scaled_bequest_weight: float, - consumption_weights: FloatND, - coefficients_rra: FloatND, + consumption_weight: FloatND, + coefficient_rra: FloatND, utility_scale_factor: FloatND, ) -> FloatND: """Bequest function for terminal/dead states. @@ -289,13 +294,7 @@ def bequest( bequest = scale * bwt * (max(0,a) + shifter)^(consumption_weight*(1 - coefficient_rra)) / (1 - coefficient_rra) - `consumption_weights` and `coefficients_rra` are pref-type-indexed - Series from params; `utility_scale_factor` is a regime-function - output (already a per-cell scalar — must NOT be re-indexed by - pref_type). """ - consumption_weight = consumption_weights[pref_type] - coefficient_rra = coefficients_rra[pref_type] assets_shifted = jnp.maximum(0.0, assets) + bequest_shifter one_minus_rra = jnp.where( diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 718c37f..5d52b7b 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -436,14 +436,17 @@ def build_regime_probs(target: FloatND, survival: FloatND) -> FloatND: def build_dead_regime(grids: Grids) -> Regime: """Build the terminal dead regime. - `pref_type` is retained as a state so type-indexed preference params - (`consumption_weight`, `coefficient_rra`, `utility_scale_factor`) can - be indexed by it in the bequest utility. + `pref_type` is retained as a state so the pref-type-indexed DAG + functions (`consumption_weight`, `coefficient_rra`, + `utility_scale_factor`) can resolve their per-cell scalar in the + bequest utility. """ return Regime( transition=None, functions={ "utility": preferences.u_dead, + "consumption_weight": preferences.consumption_weight, + "coefficient_rra": preferences.coefficient_rra, "utility_scale_factor": preferences.utility_scale_factor, }, states={ @@ -510,9 +513,11 @@ def build_common_functions(spec: dict[str, str]) -> dict: functions["is_married"] = labor_market.is_married functions["equivalence_scale"] = preferences.equivalence_scale functions["utility_scale_factor"] = preferences.utility_scale_factor - # `discount_factor` is a DAG function that indexes the per-type - # Series by the pref_type state and returns a scalar. pylcm's - # default H picks the scalar up as a DAG-output H input. + # Pref-type-indexed scalars: DAG functions resolve the per-cell + # value from the params Series so downstream consumers get a + # scalar broadcast to every cell. + functions["consumption_weight"] = preferences.consumption_weight + functions["coefficient_rra"] = preferences.coefficient_rra functions["discount_factor"] = preferences.discount_factor # PIA from pre-computed lookup table diff --git a/tests/test_model_components.py b/tests/test_model_components.py index 79ec540..9de3b60 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -80,9 +80,8 @@ def test_utility_positive_leisure() -> None: result = preferences.u_working_life( consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), - pref_type=jnp.array(0), - consumption_weights=jnp.array([0.4, 0.4, 0.4]), - coefficients_rra=jnp.array([2.0, 2.0, 2.0]), + consumption_weight=jnp.array(0.4), + coefficient_rra=jnp.array(2.0), utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) @@ -92,9 +91,8 @@ def test_utility_log_case() -> None: result = preferences.u_working_life( consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), - pref_type=jnp.array(0), - consumption_weights=jnp.array([0.4, 0.4, 0.4]), - coefficients_rra=jnp.array([1.0, 1.0, 1.0]), + consumption_weight=jnp.array(0.4), + coefficient_rra=jnp.array(1.0), utility_scale_factor=jnp.array(1.0), ) composite = 10000.0**0.4 * 3000.0**0.6 @@ -105,11 +103,10 @@ def test_utility_log_case() -> None: def test_bequest_positive_assets() -> None: result = preferences.bequest( assets=jnp.array(100000.0), - pref_type=jnp.array(0), bequest_shifter=5000.0, scaled_bequest_weight=0.5, - consumption_weights=jnp.array([0.4, 0.4, 0.4]), - coefficients_rra=jnp.array([2.0, 2.0, 2.0]), + consumption_weight=jnp.array(0.4), + coefficient_rra=jnp.array(2.0), utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) @@ -118,11 +115,10 @@ def test_bequest_positive_assets() -> None: def test_bequest_zero_assets() -> None: result = preferences.bequest( assets=jnp.array(0.0), - pref_type=jnp.array(0), bequest_shifter=5000.0, scaled_bequest_weight=0.5, - consumption_weights=jnp.array([0.4, 0.4, 0.4]), - coefficients_rra=jnp.array([2.0, 2.0, 2.0]), + consumption_weight=jnp.array(0.4), + coefficient_rra=jnp.array(2.0), utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) diff --git a/tests/test_preferences.py b/tests/test_preferences.py index e412a16..14ab67d 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -18,22 +18,15 @@ BEQUEST_SHIFTER = 500_000.0 REFERENCE_HOURS = 500.0 -# Pref-type-indexed params: three identical entries so pref_type=0 selects -# the struct-ret scalar value used by the regression tests. -WEIGHT_BY_TYPE = jnp.array([CONSUMPTION_WEIGHT, CONSUMPTION_WEIGHT, CONSUMPTION_WEIGHT]) -RRA_5_BY_TYPE = jnp.array([5.0, 5.0, 5.0]) -RRA_1_BY_TYPE = jnp.array([1.0, 1.0, 1.0]) - # --- utility_scale_factor --- def test_utility_scale_factor_crra() -> None: result = preferences.utility_scale_factor( - pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weights=WEIGHT_BY_TYPE, - coefficients_rra=RRA_5_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -43,10 +36,9 @@ def test_utility_scale_factor_crra() -> None: def test_utility_scale_factor_log() -> None: result = preferences.utility_scale_factor( - pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weights=WEIGHT_BY_TYPE, - coefficients_rra=RRA_1_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -98,10 +90,9 @@ def test_scaled_bequest_weight_zero() -> None: def test_utility_log_regression() -> None: scale = preferences.utility_scale_factor( - pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weights=WEIGHT_BY_TYPE, - coefficients_rra=RRA_1_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -109,9 +100,8 @@ def test_utility_log_regression() -> None: result = preferences.u_working_life( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - pref_type=jnp.array(0), - consumption_weights=WEIGHT_BY_TYPE, - coefficients_rra=RRA_1_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(1.0), utility_scale_factor=scale, ) assert jnp.isclose(result, 0.999_720_557_696_258_7, rtol=1e-5) @@ -119,10 +109,9 @@ def test_utility_log_regression() -> None: def test_utility_crra_regression() -> None: scale = preferences.utility_scale_factor( - pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weights=WEIGHT_BY_TYPE, - coefficients_rra=RRA_5_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -130,9 +119,8 @@ def test_utility_crra_regression() -> None: result = preferences.u_working_life( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - pref_type=jnp.array(0), - consumption_weights=WEIGHT_BY_TYPE, - coefficients_rra=RRA_5_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), utility_scale_factor=scale, ) assert jnp.isclose(result, -1.009_987_562_073_720_9, rtol=1e-5) @@ -141,10 +129,9 @@ def test_utility_crra_regression() -> None: def test_utility_married_equivalence() -> None: """Married with equiv-scaled consumption_unequiv should equal single utility.""" scale = preferences.utility_scale_factor( - pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weights=WEIGHT_BY_TYPE, - coefficients_rra=RRA_5_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -152,17 +139,15 @@ def test_utility_married_equivalence() -> None: single = preferences.u_working_life( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - pref_type=jnp.array(0), - consumption_weights=WEIGHT_BY_TYPE, - coefficients_rra=RRA_5_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), utility_scale_factor=scale, ) married = preferences.u_working_life( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - pref_type=jnp.array(0), - consumption_weights=WEIGHT_BY_TYPE, - coefficients_rra=RRA_5_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), utility_scale_factor=scale, ) assert jnp.isclose(single, married, rtol=1e-5) @@ -173,10 +158,9 @@ def test_utility_married_equivalence() -> None: def test_bequest_log_regression() -> None: scale = preferences.utility_scale_factor( - pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weights=WEIGHT_BY_TYPE, - coefficients_rra=RRA_1_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -191,11 +175,10 @@ def test_bequest_log_regression() -> None: ) result = preferences.bequest( assets=jnp.array(10000.0), - pref_type=jnp.array(0), bequest_shifter=BEQUEST_SHIFTER, scaled_bequest_weight=bwt.item(), - consumption_weights=WEIGHT_BY_TYPE, - coefficients_rra=RRA_1_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(1.0), utility_scale_factor=scale, ) assert jnp.isclose(result, 86.080_677_139_309_2, rtol=1e-5) @@ -203,10 +186,9 @@ def test_bequest_log_regression() -> None: def test_bequest_crra_regression() -> None: scale = preferences.utility_scale_factor( - pref_type=jnp.array(0), average_consumption_unequiv=AVERAGE_CONSUMPTION, - consumption_weights=WEIGHT_BY_TYPE, - coefficients_rra=RRA_5_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, @@ -221,11 +203,10 @@ def test_bequest_crra_regression() -> None: ) result = preferences.bequest( assets=jnp.array(10000.0), - pref_type=jnp.array(0), bequest_shifter=BEQUEST_SHIFTER, scaled_bequest_weight=bwt.item(), - consumption_weights=WEIGHT_BY_TYPE, - coefficients_rra=RRA_5_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), utility_scale_factor=scale, ) assert jnp.isclose(result, -45.799_247_573_576_66, rtol=1e-5) From 6da86ec75937b225d77017414f1292073212ecb2 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 11 May 2026 09:49:33 +0200 Subject: [PATCH 50/54] Phase 5: require all params explicitly; consolidate base derived categoricals MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `create_model` (baseline + ACA) now requires `fixed_params`, `wage_params`, `derived_categoricals`, and `pref_type_grid`. No `None` defaults, no `or {}` shortcuts. Same for `build_all_regimes` and `build_grids`; the AIME piecewise-fallback and the static assets floor are gone. Add `aca_model.baseline.derived_categoricals.BASE_DERIVED_CATEGORICALS` constant — `target_his` lives there now and both `create_model` factories merge it into the caller's `derived_categoricals` so callers no longer maintain a per-file `base_derived = {"target_his": ...}` block. Expose `get_hcc_persistent_shock(grid_config)` / `get_hcc_persistent_grid_points(grid_config)` so callers (aca-data `task_predicted_hcc_insurer`, aca-estimation `_assemble_params`) can derive the shock without a bare model. Benchmark snapshot now bundles `wage_params` alongside `fixed_params` and `params`. `get_benchmark_params` returns a 3-tuple. Test helpers in `tests/helpers/model.py` thread the snapshot through with the benchmark `derived_categoricals` set. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../_benchmark_data/benchmark_params.pkl | Bin 65398 -> 68327 bytes src/aca_model/aca/model.py | 42 +++----- src/aca_model/aca/regimes/__init__.py | 9 +- .../baseline/derived_categoricals.py | 12 +++ src/aca_model/baseline/model.py | 46 ++++----- src/aca_model/baseline/regimes/__init__.py | 16 ++- src/aca_model/baseline/regimes/_common.py | 97 ++++++++---------- src/aca_model/benchmark.py | 12 +-- tests/helpers/model.py | 50 +++++---- tests/test_benchmark.py | 4 +- .../test_initial_conditions_extreme_assets.py | 2 +- tests/test_model_creation.py | 22 ++-- 12 files changed, 157 insertions(+), 155 deletions(-) create mode 100644 src/aca_model/baseline/derived_categoricals.py diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index 5297faa4ccb61966e2b4f91eeef3c4d1b206db92..719e2e0e295da0f8448860e8a09ae52d253300eb 100644 GIT binary patch delta 2478 zcmc(g2~bl<7{`-v$e{=ntAI!WtvD1I$UA~o{cH(TFrIvp;lh^S*7ZN(NrE>WQyVG1(S5!=rAX5Q}G@7r(pzx%%5 zZttMhqa1DBtvWeu7WNu>|&%@%vtGQ-e*u;NYnRD9})>h5fcqo7iGq!Zt)|zsGHRs3= zhS^*Qrehqo_;4Zaf)l5-h6^qWMT}|}7gF}qy`lYFh_SuVoluYJD91XOSguX3ouxkyZVXSY7Cle0}L-IG4S`exV+&o2F>QBR*JnC%sbXLh5rx(*F)9U zLG*P~hNQ{CFV&l=IpA-{Nt&XD{&R3?ZP7HF;*? zK^%m2PmMBm5Rh~Si;t+pVNSVTdf9OT8pUMmieGVvtM@v&ubP0CA$e=49sxxai(XPU z2ng^n*qrY~0Qq&9d)5O2sOAw)uU9C%bQ6^^N2NGuujbbyHJU6 zK>`UyFIWY{VICCJbZi(v0cpSV7{;{oKqGWgFbPy@<;UUfZg7S{%J?MAdcG~`(P=kG zY43@Q*I$%FD?P7Fplp&hzsO(rz`!ekNNEbk!ttiaXW9~pd9!I{Ppk)h&{Xz^?0 zf+b2Ry@WPXJ0kBmv`jA_-D&spC|co0#y{1!b75#z?XI0OxX@&z7BuT&u+DIakyjuW zv^9=7801HV8h$fGtR#BD~85X^qE6Ys44`S}Thi5Ms-Of%n;o82I{} zKcUW|Fo=(eZI&Ov;0sI7lUL7RaI7mPxv>m`w3ML6OV2PEm|Gg?+lqnn@yPv+I1V~T zv*)v4p@BB~x5TP&KqddFH4cC373w$_ql~wEcFo!yhdrC1g=xp3v3g)}Y8VcFtNnXL zHUxAv2!f5XaM1sTQ11%#Bf#q9kaFJz91?fe?NKHX@W4eYr}HijPp*E6G!k&2mTbL^ z(7`)+e9r~~zKb~JZ0bsY&B`imc@F`5v!=!qs|e^le??4BBcWiYUGx2I1PuJqUggFm z$D=idEcN?9YYbWH_kq?JQsn=N%`pQ^G@hYcQKnA>s7WJ0Wrvy>X;_kqG&hq>MoDP5 zAxmpNjZHzV&{V2KMiz_SxHKg-f*GA;7w9pu%1L3UeQ2h15o1psKp)gWR7j(l zU^J^JU75-Rqcb5N*@o{(WJMl(*?KAi*+!)o(&arCbfv(KT}gNLcyNY?pF6?f>8U-_ z>5N|9NZ5Z+_2V0=ZX#8d=H(ijf~wF|V>h&tPV4r1ccsFS5U&X?4oy3L@4DYmscS%2 z^jKK`qfFEp9j|m^pEoV)v!UDjlIfa0J36x8g)`THKGMHTi+alut~;D@CFciz=jcp0 mPsO7L|M&$sqFfsN`5u~97?q$|sYInj$s>!R-m)B(%IPnatkOmR delta 679 zcmaDpmF3$%X0`^_soDP-HnLs#wE6Epe#XfQl-QWqIVU$NicDX?$jIW#$H~CZ89PNI zOAyG&5&{ybB~y}Wr)0!-7Pn7H%5caM1B#WlO$nNk2owd%Xk?^hF#^RhQe&rpco|u; zQ>z$l8BM3lGBJ7~8?H=*;nSBeG5RBGBGW}SC@%5?x=2bY%NQhBlw|^o0Yd|hl!0l3?V!zU)>DQPUm8SPHGkQ(0WdX9jGc&eM zpU=W*$TXL8`Xv@dpUKlX#U~5Ia!%g=P+)R_g2v{$&t8mdM}WG$qb5I0lVDp8WM=K1 z{N{_q_GC83i%gRP*i|OSe&L$j(IU^ZoO7~ao-F4Ykjosh_Hs^Mm?| Model: """Create an ACA policy variant model. @@ -31,22 +32,18 @@ def create_model( n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. policy: Which ACA policy combination to apply (e.g. `PolicyVariant.ACA`). - fixed_params: Parameters to fix at model creation time, or `None` - to skip. Fixed params are partialled into compiled functions - and removed from the params template. Pass data-derived - constants here; only estimation parameters should go through - `model.simulate(params=...)`. + fixed_params: Parameters to fix at model creation time. Pass + data-derived constants here; only estimation parameters + should go through `model.simulate(params=...)`. wage_params: Data-derived wage profile dict (`log_ft_wage_mean`, `log_ft_wage_std`, `adj_wage_hours_*`) used only at grid-build time to size the assets-floor to `-max_annual_labor_income`. - Not routed to the pylcm Model. `None` skips the floor sizing. + Not routed to the pylcm Model. derived_categoricals: Extra categorical mappings for derived - variables not in the model's state/action grids, or `None`. - Needed when `fixed_params` contains `pd.Series` indexed by DAG - function outputs. - grid_config: Continuous-grid point counts. Pass `GRID_CONFIG` for - production values or `BENCHMARK_GRID_CONFIG` for the - fast-but-structurally-faithful benchmark. + variables not in the model's state/action grids. `target_his` + is added automatically via `BASE_DERIVED_CATEGORICALS`. + grid_config: Continuous-grid point counts. + pref_type_grid: Pref-type `DiscreteGrid`. Returns: pylcm Model with ACA-specific function overrides. @@ -62,22 +59,15 @@ def create_model( grid_config=grid_config, fixed_params=fixed_params, wage_params=wage_params, + pref_type_grid=pref_type_grid, ) - # See `baseline.model.create_model` for why `target_his` is declared - # as a base-layer derived categorical. - base_derived: dict[str, DiscreteGrid] = { - "target_his": DiscreteGrid(HealthInsuranceState), - } - if derived_categoricals is not None: - base_derived.update(derived_categoricals) - return Model( regimes=regimes, ages=ages, regime_id_class=RegimeId, description=f"Structural retirement model ({policy.name})", - fixed_params=fixed_params or {}, - derived_categoricals=base_derived, + fixed_params=fixed_params, + derived_categoricals={**BASE_DERIVED_CATEGORICALS, **derived_categoricals}, n_subjects=n_subjects, ) diff --git a/src/aca_model/aca/regimes/__init__.py b/src/aca_model/aca/regimes/__init__.py index 5b9f4bf..ca26ca2 100644 --- a/src/aca_model/aca/regimes/__init__.py +++ b/src/aca_model/aca/regimes/__init__.py @@ -4,7 +4,7 @@ from collections.abc import Mapping from typing import Any -from lcm import Regime +from lcm import DiscreteGrid, Regime from aca_model.aca.health_insurance import PolicyVariant from aca_model.aca.regimes._overrides import apply_aca_overrides @@ -17,15 +17,16 @@ def build_all_regimes( *, policy: PolicyVariant, grid_config: GridConfig, - fixed_params: Mapping[str, Any] | None, - wage_params: Mapping[str, Any] | None, + fixed_params: Mapping[str, Any], + wage_params: Mapping[str, Any], + pref_type_grid: DiscreteGrid, ) -> dict[str, Regime]: """Build all 19 regimes with ACA policy overrides.""" regimes = baseline_build_all_regimes( grid_config=grid_config, fixed_params=fixed_params, wage_params=wage_params, - pref_type_grid=None, + pref_type_grid=pref_type_grid, ) result = {} for name, regime in regimes.items(): diff --git a/src/aca_model/baseline/derived_categoricals.py b/src/aca_model/baseline/derived_categoricals.py new file mode 100644 index 0000000..d9745d5 --- /dev/null +++ b/src/aca_model/baseline/derived_categoricals.py @@ -0,0 +1,12 @@ +"""Base-layer derived categoricals shared by baseline + ACA model factories.""" + +from types import MappingProxyType + +from lcm import DiscreteGrid + +from aca_model.baseline.health_insurance import HealthInsuranceState + +# `target_his` is a state subsumed into regimes. +BASE_DERIVED_CATEGORICALS = MappingProxyType( + {"target_his": DiscreteGrid(HealthInsuranceState)} +) diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index a278bce..5e3f934 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -5,7 +5,7 @@ Usage: from aca_model.baseline.model import create_model - model = create_model(n_subjects=...) + model = create_model(n_subjects=..., fixed_params=..., wage_params=..., ...) params = get_default_params() V = model.solve(params) """ @@ -15,7 +15,7 @@ from lcm import AgeGrid, DiscreteGrid, Model -from aca_model.baseline.health_insurance import HealthInsuranceState +from aca_model.baseline.derived_categoricals import BASE_DERIVED_CATEGORICALS from aca_model.baseline.regimes import RegimeId, build_all_regimes from aca_model.config import MODEL_CONFIG, GridConfig @@ -23,35 +23,36 @@ def create_model( *, n_subjects: int, - fixed_params: Mapping[str, Any] | None, - wage_params: Mapping[str, Any] | None, - derived_categoricals: Mapping[str, DiscreteGrid] | None, + fixed_params: Mapping[str, Any], + wage_params: Mapping[str, Any], + derived_categoricals: Mapping[str, DiscreteGrid], grid_config: GridConfig, - pref_type_grid: DiscreteGrid | None, + pref_type_grid: DiscreteGrid, ) -> Model: """Create the baseline structural retirement model. Args: n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. - fixed_params: Parameters to fix at model creation time, or `None` - to skip. Fixed params are partialled into compiled functions - and removed from the params template. Pass data-derived - constants here; only estimation parameters should go through + fixed_params: Parameters to fix at model creation time. Fixed + params are partialled into compiled functions and removed + from the params template. Pass data-derived constants here; + only estimation parameters should go through `model.simulate(params=...)`. wage_params: Data-derived wage profile dict (`log_ft_wage_mean`, `log_ft_wage_std`, `adj_wage_hours_*`) used only at grid-build time to size the assets-floor to `-max_annual_labor_income`. - Not routed to the pylcm Model. `None` skips the floor sizing. + Not routed to the pylcm Model. derived_categoricals: Extra categorical mappings for derived - variables not in the model's state/action grids, or `None`. - Needed when `fixed_params` contains `pd.Series` indexed by DAG - function outputs. + variables not in the model's state/action grids. Needed when + `fixed_params` contains `pd.Series` indexed by DAG function + outputs. `target_his` is added automatically via + `BASE_DERIVED_CATEGORICALS`. grid_config: Continuous-grid point counts. Pass `GRID_CONFIG` for production values or `BENCHMARK_GRID_CONFIG` for the fast-but-structurally-faithful benchmark. - pref_type_grid: Pref-type `DiscreteGrid`, or `None` to use - `DiscreteGrid(PrefType)`. Pass a custom grid to substitute - the production layout (e.g. the 2-type benchmark variant). + pref_type_grid: Pref-type `DiscreteGrid`. Pass + `DiscreteGrid(PrefType)` for the production 3-type layout, + or a compact variant (e.g. `DiscreteGrid(BenchmarkPrefType)`). Returns: A pylcm Model with 19 regimes (18 non-terminal + dead) spanning @@ -70,19 +71,12 @@ def create_model( pref_type_grid=pref_type_grid, ) - # `target_his` is a state subsumed into regimes. - base_derived: dict[str, DiscreteGrid] = { - "target_his": DiscreteGrid(HealthInsuranceState), - } - if derived_categoricals is not None: - base_derived.update(derived_categoricals) - return Model( regimes=regimes, ages=ages, regime_id_class=RegimeId, description="Baseline structural retirement model (pre-ACA)", - fixed_params=fixed_params or {}, - derived_categoricals=base_derived, + fixed_params=fixed_params, + derived_categoricals={**BASE_DERIVED_CATEGORICALS, **derived_categoricals}, n_subjects=n_subjects, ) diff --git a/src/aca_model/baseline/regimes/__init__.py b/src/aca_model/baseline/regimes/__init__.py index 2473489..5384900 100644 --- a/src/aca_model/baseline/regimes/__init__.py +++ b/src/aca_model/baseline/regimes/__init__.py @@ -60,18 +60,16 @@ def build_regime(name: str, grids: Grids) -> Regime: def build_all_regimes( *, grid_config: GridConfig, - fixed_params: Mapping[str, Any] | None, - wage_params: Mapping[str, Any] | None, - pref_type_grid: DiscreteGrid | None, + fixed_params: Mapping[str, Any], + wage_params: Mapping[str, Any], + pref_type_grid: DiscreteGrid, ) -> dict[str, Regime]: """Build all 19 baseline regimes (18 non-terminal + dead). - `fixed_params` is forwarded to `build_grids` for data-driven AIME - breakpoints; `wage_params` for the data-driven assets floor; - either being `None` keeps the corresponding static fallback. - `pref_type_grid` lets callers inject a compact `DiscreteGrid(...)` - (e.g. the benchmark's 2-type `BenchmarkPrefType`); `None` falls - back to `DiscreteGrid(PrefType)`. + `fixed_params` carries the PIA bends for the AIME piecewise grid; + `wage_params` sizes the assets-floor to `-max_annual_labor_income`; + `pref_type_grid` selects the pref-type cardinality (production + `DiscreteGrid(PrefType)` or the benchmark's 2-type variant). """ grids = build_grids( grid_config=grid_config, diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 5d52b7b..cc3c829 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -32,7 +32,6 @@ ) from aca_model.agent.health import Health, HealthWithDisability from aca_model.agent.labor_market import LaborSupply, LaggedLaborSupply, SpousalIncome -from aca_model.agent.preferences import PrefType from aca_model.baseline import health_insurance from aca_model.baseline.health_insurance import BuyPrivate from aca_model.config import MODEL_CONFIG, GridConfig @@ -207,30 +206,22 @@ class Grids: def build_grids( *, grid_config: GridConfig, - fixed_params: Mapping[str, Any] | None, - wage_params: Mapping[str, Any] | None, - pref_type_grid: DiscreteGrid | None, + fixed_params: Mapping[str, Any], + wage_params: Mapping[str, Any], + pref_type_grid: DiscreteGrid, ) -> Grids: """Build continuous-state/action grids from a `GridConfig`. - When `fixed_params` carries `pia_aime_grid`, the AIME grid becomes - a `PiecewiseLinSpacedGrid` breakpointed at the PIA bends (total 32 - points). When `wage_params` provides `log_ft_wage_mean` and friends - (as produced by `aca_data.task_wages`), the assets grid's lower - bound is set to `-max_annual_labor_income` so that the worst shock - lands on a gridpoint inside the support. Without `fixed_params` / - `wage_params` (bare model for tests / compile-only paths), both - grids fall back to their historical static shapes. + The AIME grid is `PiecewiseLinSpacedGrid` breakpointed at the PIA + bends from `fixed_params["pia_aime_grid"]` (total 32 points). The + assets grid's lower bound is `-max_annual_labor_income` computed + from `wage_params` (`log_ft_wage_mean`, `log_ft_wage_std`, + `adj_wage_hours_*`). `wage_params` is passed separately rather than via `fixed_params` because `log_ft_wage_mean` is a per-iteration param at estimation time (reconstructed from `wage_bias_coeffs_*`), not a fixed one; the grid floor must still be known at build time. - - `pref_type_grid` lets callers (e.g. the benchmark) substitute a - compact `DiscreteGrid(...)` for the production - `DiscreteGrid(PrefType)`. When `None`, defaults to the production - 3-type grid. """ # Unit-variance standardised shocks: the total_costs / wage # formulas rescale these by fixed_params-level std parameters @@ -245,13 +236,7 @@ def build_grids( sigma=(1.0 - _WAGE_RHO**2) ** 0.5, mu=0.0, ) - _HCC_RHO = 0.925 - hcc_persistent = lcm.shocks.ar1.Rouwenhorst( - n_points=grid_config.n_hcc_persistent_gridpoints, - rho=_HCC_RHO, - sigma=(1.0 - _HCC_RHO**2) ** 0.5, - mu=0.0, - ) + hcc_persistent = get_hcc_persistent_shock(grid_config=grid_config) hcc_transitory = lcm.shocks.iid.Normal( n_points=grid_config.n_hcc_transitory_gridpoints, gauss_hermite=True, @@ -259,11 +244,9 @@ def build_grids( sigma=1.0, ) - assets_start = 0.0 - if wage_params is not None and _has_required_wage_keys(wage_params=wage_params): - assets_start = -_compute_max_annual_labor_income( - wage_params=wage_params, wage_res_grid=wage_res - ) + assets_start = -_compute_max_annual_labor_income( + wage_params=wage_params, wage_res_grid=wage_res + ) return Grids( assets=LinSpacedGrid( @@ -279,28 +262,44 @@ def build_grids( wage_res=wage_res, hcc_persistent=hcc_persistent, hcc_transitory=hcc_transitory, - pref_type=pref_type_grid or DiscreteGrid(PrefType), + pref_type=pref_type_grid, ) +_HCC_RHO = 0.925 + + +def get_hcc_persistent_shock(*, grid_config: GridConfig) -> lcm.shocks.ar1.Rouwenhorst: + """Return the persistent-HCC AR(1) shock grid for a given `grid_config`. + + Exposed so callers that need the shock's gridpoints / transition + probs (e.g. `assemble_fixed_params`, the HCC insurer predictor) + can derive them from `grid_config` alone without instantiating a + full `Model`. + """ + return lcm.shocks.ar1.Rouwenhorst( + n_points=grid_config.n_hcc_persistent_gridpoints, + rho=_HCC_RHO, + sigma=(1.0 - _HCC_RHO**2) ** 0.5, + mu=0.0, + ) + + +def get_hcc_persistent_grid_points(*, grid_config: GridConfig) -> FloatND: + """Materialise the persistent-HCC shock gridpoints for `grid_config`.""" + return get_hcc_persistent_shock(grid_config=grid_config).to_jax() + + def _build_aime_grid( - *, grid_config: GridConfig, fixed_params: Mapping[str, Any] | None + *, grid_config: GridConfig, fixed_params: Mapping[str, Any] ) -> ContinuousGrid: """Return the AIME grid. - With `pia_aime_grid` available, the grid is piecewise-linspaced with - breakpoints at the PIA bends and `_AIME_PIECE_N_POINTS` in each - segment. `n_aime_gridpoints` from `grid_config` is ignored on this - path; the total is fixed by the PIA structure (32 points). Without - the fixed params, falls back to the historical `LinSpacedGrid`. + The grid is piecewise-linspaced with breakpoints at the PIA bends + in `fixed_params["pia_aime_grid"]` and `_AIME_PIECE_N_POINTS` in + each segment. `n_aime_gridpoints` from `grid_config` is ignored on + this path; the total is fixed by the PIA structure (32 points). """ - if fixed_params is None or "pia_aime_grid" not in fixed_params: - return LinSpacedGrid( - start=0.0, - stop=8_000.0, - n_points=grid_config.n_aime_gridpoints, - batch_size=grid_config.n_aime_batch_size, - ) kinks = [float(k) for k in np.asarray(fixed_params["pia_aime_grid"])] pieces = ( Piece(interval=f"[{kinks[0]}, {kinks[1]})", n_points=_AIME_PIECE_N_POINTS[0]), @@ -312,18 +311,6 @@ def _build_aime_grid( ) -def _has_required_wage_keys(*, wage_params: Mapping[str, Any]) -> bool: - return all( - key in wage_params - for key in ( - "log_ft_wage_mean", - "log_ft_wage_std", - "adj_wage_hours_exp", - "adj_wage_hours_int", - ) - ) - - def _compute_max_annual_labor_income( *, wage_params: Mapping[str, Any], diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index d215b8b..57522f5 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -50,7 +50,6 @@ "good_health": DiscreteGrid(GoodHealth), "is_married": DiscreteGrid(IsMarried), "his": DiscreteGrid(HealthInsuranceState), - "target_his": DiscreteGrid(HealthInsuranceState), "pref_type": DiscreteGrid(BenchmarkPrefType), } @@ -84,11 +83,11 @@ def create_benchmark_model( functions for that batch shape. pref_type_grid: Pref-type grid; pass `DiscreteGrid(BenchmarkPrefType)`. """ - fixed_params, _ = get_benchmark_params(model=None) + fixed_params, wage_params, _ = get_benchmark_params(model=None) return create_model( grid_config=BENCHMARK_GRID_CONFIG, fixed_params=fixed_params, - wage_params=None, + wage_params=wage_params, derived_categoricals=_DERIVED_CATEGORICALS, pref_type_grid=pref_type_grid, n_subjects=n_subjects, @@ -97,8 +96,8 @@ def create_benchmark_model( def get_benchmark_params( *, model: Model | None -) -> tuple[dict[str, Any], dict[str, Any]]: - """Load the frozen `(fixed_params, params)` snapshot. +) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + """Load the frozen `(fixed_params, wage_params, params)` snapshot. When `model` is provided, consumption_unequiv gridpoints are injected into `params` for each regime that declares `consumption_unequiv` as @@ -109,10 +108,11 @@ def get_benchmark_params( with _PARAMS_FILE.open("rb") as fh: data = cloudpickle.load(fh) fixed_params = data["fixed_params"] + wage_params = data["wage_params"] params = data["params"] if model is not None: params = inject_consumption_unequiv_points(params=params, model=model) - return fixed_params, params + return fixed_params, wage_params, params def get_benchmark_initial_conditions( diff --git a/tests/helpers/model.py b/tests/helpers/model.py index 930c33e..dc7c407 100644 --- a/tests/helpers/model.py +++ b/tests/helpers/model.py @@ -1,38 +1,52 @@ -"""Tiny factories that wrap `create_model` with `None` for every optional input. +"""Tiny factories that wrap `create_model` with the benchmark snapshot. -Used by tests that don't need fixed params, wage params, or a custom pref-type -grid. These helpers exist so production `create_model` factories can stay -default-free without forcing every test call site to spell out -`fixed_params=None, wage_params=None, ...` six times. +Used by tests that need a structurally faithful model without spelling +out fixed_params, wage_params, and a pref-type grid at every call site. +Production callers (aca-estimation, scripts) assemble these explicitly. """ -from lcm import Model +from lcm import DiscreteGrid, Model from aca_model.aca.health_insurance import PolicyVariant from aca_model.aca.model import create_model as _create_aca_model +from aca_model.agent.health import GoodHealth +from aca_model.agent.labor_market import IsMarried +from aca_model.agent.preferences import BenchmarkPrefType +from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.model import create_model as _create_baseline_model -from aca_model.config import GRID_CONFIG +from aca_model.benchmark import get_benchmark_params +from aca_model.config import BENCHMARK_GRID_CONFIG + +_DERIVED_CATEGORICALS = { + "good_health": DiscreteGrid(GoodHealth), + "is_married": DiscreteGrid(IsMarried), + "his": DiscreteGrid(HealthInsuranceState), + "pref_type": DiscreteGrid(BenchmarkPrefType), +} def make_baseline_model(*, n_subjects: int) -> Model: - """Baseline model with `GRID_CONFIG` and no fixed/wage/derived params.""" + """Baseline model on `BENCHMARK_GRID_CONFIG` with the benchmark snapshot params.""" + fixed_params, wage_params, _ = get_benchmark_params(model=None) return _create_baseline_model( n_subjects=n_subjects, - fixed_params=None, - wage_params=None, - derived_categoricals=None, - grid_config=GRID_CONFIG, - pref_type_grid=None, + fixed_params=fixed_params, + wage_params=wage_params, + derived_categoricals=_DERIVED_CATEGORICALS, + grid_config=BENCHMARK_GRID_CONFIG, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), ) def make_aca_model(*, n_subjects: int, policy: PolicyVariant) -> Model: - """ACA model with `GRID_CONFIG` and no fixed/wage/derived params.""" + """ACA model on `BENCHMARK_GRID_CONFIG` with the benchmark snapshot params.""" + fixed_params, wage_params, _ = get_benchmark_params(model=None) return _create_aca_model( n_subjects=n_subjects, policy=policy, - fixed_params=None, - wage_params=None, - derived_categoricals=None, - grid_config=GRID_CONFIG, + fixed_params=fixed_params, + wage_params=wage_params, + derived_categoricals=_DERIVED_CATEGORICALS, + grid_config=BENCHMARK_GRID_CONFIG, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), ) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 649442e..d3d83f4 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -19,7 +19,7 @@ def test_benchmark_model_simulates_end_to_end() -> None: n_subjects=n_subjects, pref_type_grid=DiscreteGrid(BenchmarkPrefType), ) - _, params = get_benchmark_params(model=model) + _, _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 ) @@ -59,7 +59,7 @@ def test_benchmark_simulate_obeys_borrowing_constraint() -> None: n_subjects=n_subjects, pref_type_grid=DiscreteGrid(BenchmarkPrefType), ) - _, params = get_benchmark_params(model=model) + _, _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 ) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index 7e583e3..ef6e1e9 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -104,7 +104,7 @@ def test_extreme_negative_assets_subject_passes_validation() -> None: n_subjects=n_subjects, pref_type_grid=DiscreteGrid(BenchmarkPrefType), ) - _, params = get_benchmark_params(model=model) + _, _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index 042e8a3..37984da 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -4,30 +4,36 @@ import pytest from helpers.model import make_aca_model, make_baseline_model +from lcm import DiscreteGrid from aca_model.aca import health_insurance as aca_hi from aca_model.aca.health_insurance import PolicyVariant from aca_model.aca.regimes import build_all_regimes as _build_aca_regimes +from aca_model.agent.preferences import BenchmarkPrefType from aca_model.baseline.regimes import REGIME_SPECS, RegimeId from aca_model.baseline.regimes import build_regime as _build_regime from aca_model.baseline.regimes._common import build_grids -from aca_model.config import GRID_CONFIG +from aca_model.benchmark import get_benchmark_params +from aca_model.config import BENCHMARK_GRID_CONFIG + +_FIXED_PARAMS, _WAGE_PARAMS, _ = get_benchmark_params(model=None) def build_aca_regimes(policy: PolicyVariant) -> dict: return _build_aca_regimes( policy=policy, - grid_config=GRID_CONFIG, - fixed_params=None, - wage_params=None, + grid_config=BENCHMARK_GRID_CONFIG, + fixed_params=_FIXED_PARAMS, + wage_params=_WAGE_PARAMS, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), ) _GRIDS = build_grids( - grid_config=GRID_CONFIG, - fixed_params=None, - wage_params=None, - pref_type_grid=None, + grid_config=BENCHMARK_GRID_CONFIG, + fixed_params=_FIXED_PARAMS, + wage_params=_WAGE_PARAMS, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), ) From e2861d042567009a3a839e9cfabb4205dcdd4930 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 11 May 2026 10:01:17 +0200 Subject: [PATCH 51/54] =?UTF-8?q?Phase=206:=20precompute=5Ftargets=20?= =?UTF-8?q?=E2=86=92=20precompute=5Ftarget=5Fregimes;=20introduce=20Regime?= =?UTF-8?q?Spec=20TypedDict?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `precompute_targets(spec: Mapping[str, str]) -> dict[str, int]` becomes `precompute_target_regimes(spec: RegimeSpec) -> MappingProxyType[str, int]`. The four-axis `RegimeSpec` `TypedDict` makes the valid keys (`his`, `mc`, `ss`, `canwork`) explicit and propagates through every spec-consuming helper. Helper renames make the dispatch verb read as "per target *regime*", matching the new function name: - `_build_per_target_next_assets` → `_build_per_target_regime_next_assets` - `_build_per_target_health` → `_build_per_target_regime_health` - `_build_per_target_claimed_ss` → `_build_per_target_regime_claimed_ss` - `_build_per_target_lagged_labor_supply` → `_build_per_target_regime_lagged_labor_supply` Every `targets` local that's the result of the precompute call is now `target_regimes`. Spec parameters on the HIS-specific builders (`_retiree`, `_tied`, `_nongroup`) and ACA overrides also adopt the new `RegimeSpec` type. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/aca_model/aca/regimes/_overrides.py | 3 +- src/aca_model/baseline/regimes/_common.py | 112 +++++++++++--------- src/aca_model/baseline/regimes/_nongroup.py | 3 +- src/aca_model/baseline/regimes/_retiree.py | 3 +- src/aca_model/baseline/regimes/_tied.py | 3 +- 5 files changed, 70 insertions(+), 54 deletions(-) diff --git a/src/aca_model/aca/regimes/_overrides.py b/src/aca_model/aca/regimes/_overrides.py index 79bd9e4..4ab590e 100644 --- a/src/aca_model/aca/regimes/_overrides.py +++ b/src/aca_model/aca/regimes/_overrides.py @@ -8,11 +8,12 @@ from aca_model.aca import health_insurance as aca_hi from aca_model.aca.health_insurance import PolicyVariant +from aca_model.baseline.regimes._common import RegimeSpec def apply_aca_overrides( functions: dict, - spec: dict[str, str], + spec: RegimeSpec, policy: PolicyVariant, ) -> None: """Override baseline functions with ACA versions in-place. diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index cc3c829..a148d4b 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -6,7 +6,8 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass -from typing import Any +from types import MappingProxyType +from typing import Any, TypedDict import jax.numpy as jnp import lcm.shocks.ar1 @@ -62,8 +63,17 @@ class RegimeId: dead: int +class RegimeSpec(TypedDict): + """Structural decomposition of a regime: (HIS, Medicare, SS, work) axes.""" + + his: str + mc: str + ss: str + canwork: str + + # {his}_{mc}_{ss}_{canwork} -REGIME_SPECS: dict[str, dict[str, str]] = { +REGIME_SPECS: dict[str, RegimeSpec] = { "retiree_nomc_inelig_canwork": { "his": "retiree", "mc": "nomc", @@ -367,7 +377,7 @@ def _compute_max_annual_labor_income( } -def make_active_func(spec: dict[str, str]) -> Callable[..., Any]: +def make_active_func(spec: RegimeSpec) -> Callable[..., Any]: """Return the age predicate for a regime spec.""" key = (spec["mc"], spec["ss"], spec["canwork"]) predicate = _ACTIVE_PREDICATES.get(key) @@ -377,7 +387,7 @@ def make_active_func(spec: dict[str, str]) -> Callable[..., Any]: return predicate -def build_states(spec: dict[str, str], grids: Grids) -> dict: +def build_states(spec: RegimeSpec, grids: Grids) -> dict: """Build the state dict for a non-dead regime.""" can_work = spec["canwork"] == "canwork" @@ -400,7 +410,7 @@ def build_states(spec: dict[str, str], grids: Grids) -> dict: return states -def build_actions(spec: dict[str, str], grids: Grids) -> dict: +def build_actions(spec: RegimeSpec, grids: Grids) -> dict: """Build the action dict for a non-dead regime.""" actions: dict = {} if spec["ss"] == "choose": @@ -444,7 +454,7 @@ def build_dead_regime(grids: Grids) -> Regime: ) -def select_ss_benefit(spec: dict[str, str]) -> Callable[..., Any]: +def select_ss_benefit(spec: RegimeSpec) -> Callable[..., Any]: """Select the appropriate SS benefit function for a regime.""" ss = spec["ss"] @@ -457,21 +467,21 @@ def select_ss_benefit(spec: dict[str, str]) -> Callable[..., Any]: return social_security.benefit_inelig_pre65 -def select_utility(spec: dict[str, str]) -> Callable[..., Any]: +def select_utility(spec: RegimeSpec) -> Callable[..., Any]: """Select the utility function for a regime.""" if spec["canwork"] != "canwork": return preferences.u_retired return preferences.u_working_life -def _select_leisure(spec: dict[str, str]) -> Callable[..., Any]: +def _select_leisure(spec: RegimeSpec) -> Callable[..., Any]: """Select the leisure function for a canwork regime.""" if spec["his"] == "tied": return preferences.leisure_tied return preferences.leisure -def build_common_functions(spec: dict[str, str]) -> dict: +def build_common_functions(spec: RegimeSpec) -> dict: """Build the shared functions dict for a non-dead regime. Contains all functions common to every HIS type. Per-HIS modules add @@ -543,7 +553,7 @@ def build_common_functions(spec: dict[str, str]) -> dict: return functions -def precompute_targets(spec: Mapping[str, str]) -> dict[str, int]: +def precompute_target_regimes(spec: RegimeSpec) -> MappingProxyType[str, int]: """Pre-compute target regime IDs for each next-age bracket.""" def _resolve(his_val: str, mc_val: str, ss_val: str, canwork_val: str) -> int: @@ -559,22 +569,24 @@ def _resolve(his_val: str, mc_val: str, ss_val: str, canwork_val: str) -> int: ng_his = "nongroup" if spec["his"] == "tied" else spec["his"] - return { - "forcedout": _resolve(ng_his, "oamc", "forced", "forcedout"), - "forcedout_ng": _resolve("nongroup", "oamc", "forced", "forcedout"), - "forced_forced": _resolve(spec["his"], "oamc", "forced", "canwork"), - "forced_forced_ng": _resolve("nongroup", "oamc", "forced", "canwork"), - "forced_choose": _resolve(spec["his"], "oamc", "choose", "canwork"), - "forced_choose_ng": _resolve("nongroup", "oamc", "choose", "canwork"), - "dimc_choose": _resolve(spec["his"], "dimc", "choose", "canwork"), - "dimc_choose_ng": _resolve("nongroup", "dimc", "choose", "canwork"), - "nomc_choose": _resolve(spec["his"], "nomc", "choose", "canwork"), - "nomc_choose_ng": _resolve("nongroup", "nomc", "choose", "canwork"), - "dimc_inelig": _resolve(spec["his"], "dimc", "inelig", "canwork"), - "dimc_inelig_ng": _resolve("nongroup", "dimc", "inelig", "canwork"), - "nomc_inelig": _resolve(spec["his"], "nomc", "inelig", "canwork"), - "nomc_inelig_ng": _resolve("nongroup", "nomc", "inelig", "canwork"), - } + return MappingProxyType( + { + "forcedout": _resolve(ng_his, "oamc", "forced", "forcedout"), + "forcedout_ng": _resolve("nongroup", "oamc", "forced", "forcedout"), + "forced_forced": _resolve(spec["his"], "oamc", "forced", "canwork"), + "forced_forced_ng": _resolve("nongroup", "oamc", "forced", "canwork"), + "forced_choose": _resolve(spec["his"], "oamc", "choose", "canwork"), + "forced_choose_ng": _resolve("nongroup", "oamc", "choose", "canwork"), + "dimc_choose": _resolve(spec["his"], "dimc", "choose", "canwork"), + "dimc_choose_ng": _resolve("nongroup", "dimc", "choose", "canwork"), + "nomc_choose": _resolve(spec["his"], "nomc", "choose", "canwork"), + "nomc_choose_ng": _resolve("nongroup", "nomc", "choose", "canwork"), + "dimc_inelig": _resolve(spec["his"], "dimc", "inelig", "canwork"), + "dimc_inelig_ng": _resolve("nongroup", "dimc", "inelig", "canwork"), + "nomc_inelig": _resolve(spec["his"], "nomc", "inelig", "canwork"), + "nomc_inelig_ng": _resolve("nongroup", "nomc", "inelig", "canwork"), + } + ) _TARGET_KEYS = ( @@ -590,9 +602,9 @@ def _resolve(his_val: str, mc_val: str, ss_val: str, canwork_val: str) -> int: def make_targets(name: str) -> tuple[dict[str, int], dict[str, int]]: """Build own and nongroup target subsets for a regime name.""" - tgts = precompute_targets(REGIME_SPECS[name]) - own = {k: tgts[k] for k in _TARGET_KEYS} - ng = {k: tgts[k + "_ng"] for k in _TARGET_KEYS} + target_regimes = precompute_target_regimes(REGIME_SPECS[name]) + own = {k: target_regimes[k] for k in _TARGET_KEYS} + ng = {k: target_regimes[k + "_ng"] for k in _TARGET_KEYS} return own, ng @@ -631,11 +643,11 @@ def select_target_for_age( ) -def build_state_transitions(spec: dict[str, str]) -> dict: +def build_state_transitions(spec: RegimeSpec) -> dict: """Build the state transitions dict for a non-dead regime.""" transitions: dict = {} - transitions["health"] = _build_per_target_health(spec) - transitions["assets"] = _build_per_target_next_assets(spec) + transitions["health"] = _build_per_target_regime_health(spec) + transitions["assets"] = _build_per_target_regime_next_assets(spec) transitions["pref_type"] = None transitions["aime"] = ( social_security.next_aime @@ -643,17 +655,17 @@ def build_state_transitions(spec: dict[str, str]) -> dict: else social_security.next_aime_disabled ) transitions["spousal_income"] = MarkovTransition(labor_market.next_spousal_income) - lagged_supply_transition = _build_per_target_lagged_labor_supply(spec) + lagged_supply_transition = _build_per_target_regime_lagged_labor_supply(spec) if lagged_supply_transition: transitions["lagged_labor_supply"] = lagged_supply_transition - claimed_ss_transition = _build_per_target_claimed_ss(spec) + claimed_ss_transition = _build_per_target_regime_claimed_ss(spec) if claimed_ss_transition: transitions["claimed_ss"] = claimed_ss_transition return transitions -def _build_per_target_next_assets( - spec: Mapping[str, str], +def _build_per_target_regime_next_assets( + spec: RegimeSpec, ) -> dict[RegimeName, Callable[..., FloatND]]: """Build per-target assets transitions. @@ -663,13 +675,13 @@ def _build_per_target_next_assets( `aime` state and pylcm cannot resolve `next_aime` there. Non-dead targets use the full `next_assets` with the pension correction. """ - targets = precompute_targets(spec) + target_regimes = precompute_target_regimes(spec) id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} result: dict[RegimeName, Callable[..., FloatND]] = {} seen_ids: set[int] = set() - for target_id in targets.values(): + for target_id in target_regimes.values(): if target_id in seen_ids: continue seen_ids.add(target_id) @@ -682,21 +694,21 @@ def _build_per_target_next_assets( return result -def _build_per_target_health( - spec: Mapping[str, str], +def _build_per_target_regime_health( + spec: RegimeSpec, ) -> dict[RegimeName, MarkovTransition]: """Build per-target health transitions. Pre-65 regimes use HealthWithDisability (3-state), post-65 use Health (2-state). Cross-grid transitions (3->2) happen at the age-65 boundary. """ - targets = precompute_targets(spec) + target_regimes = precompute_target_regimes(spec) id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} result: dict[RegimeName, MarkovTransition] = {} seen_ids: set[int] = set() - for target_id in targets.values(): + for target_id in target_regimes.values(): if target_id == RegimeId.dead or target_id in seen_ids: continue seen_ids.add(target_id) @@ -714,8 +726,8 @@ def _build_per_target_health( return result -def _build_per_target_claimed_ss( - spec: Mapping[str, str], +def _build_per_target_regime_claimed_ss( + spec: RegimeSpec, ) -> dict[RegimeName, Callable[..., BoolND]]: """Build per-target claimed_ss transitions. @@ -726,13 +738,13 @@ def _build_per_target_claimed_ss( if spec["ss"] in ("forced", "forcedout"): return {} - targets = precompute_targets(spec) + target_regimes = precompute_target_regimes(spec) id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} result: dict[RegimeName, Callable[..., BoolND]] = {} seen_ids: set[int] = set() - for target_id in targets.values(): + for target_id in target_regimes.values(): if target_id == RegimeId.dead or target_id in seen_ids: continue seen_ids.add(target_id) @@ -751,8 +763,8 @@ def _build_per_target_claimed_ss( return result -def _build_per_target_lagged_labor_supply( - spec: Mapping[str, str], +def _build_per_target_regime_lagged_labor_supply( + spec: RegimeSpec, ) -> dict[RegimeName, Callable[..., BoolND]]: """Build per-target lagged_labor_supply transitions. @@ -767,13 +779,13 @@ def _build_per_target_lagged_labor_supply( if spec["canwork"] != "canwork": return {} - targets = precompute_targets(spec) + target_regimes = precompute_target_regimes(spec) id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} result: dict[RegimeName, Callable[..., BoolND]] = {} seen_ids: set[int] = set() - for target_id in targets.values(): + for target_id in target_regimes.values(): if target_id == RegimeId.dead or target_id in seen_ids: continue seen_ids.add(target_id) diff --git a/src/aca_model/baseline/regimes/_nongroup.py b/src/aca_model/baseline/regimes/_nongroup.py index 7ee82ff..a723b44 100644 --- a/src/aca_model/baseline/regimes/_nongroup.py +++ b/src/aca_model/baseline/regimes/_nongroup.py @@ -15,6 +15,7 @@ from aca_model.baseline.regimes._common import ( REGIME_SPECS, Grids, + RegimeSpec, build_actions, build_common_functions, build_regime_probs, @@ -74,7 +75,7 @@ def transition( return transition -def _build_functions(spec: dict[str, str]) -> dict: +def _build_functions(spec: RegimeSpec) -> dict: """Build functions dict for a nongroup regime.""" can_work = spec["canwork"] == "canwork" functions = build_common_functions(spec) diff --git a/src/aca_model/baseline/regimes/_retiree.py b/src/aca_model/baseline/regimes/_retiree.py index a941fa9..4f16faa 100644 --- a/src/aca_model/baseline/regimes/_retiree.py +++ b/src/aca_model/baseline/regimes/_retiree.py @@ -16,6 +16,7 @@ from aca_model.baseline.regimes._common import ( REGIME_SPECS, Grids, + RegimeSpec, build_actions, build_common_functions, build_regime_probs, @@ -86,7 +87,7 @@ def transition( return transition -def _build_functions(spec: dict[str, str]) -> dict: +def _build_functions(spec: RegimeSpec) -> dict: """Build functions dict for a retiree regime.""" can_work = spec["canwork"] == "canwork" functions = build_common_functions(spec) diff --git a/src/aca_model/baseline/regimes/_tied.py b/src/aca_model/baseline/regimes/_tied.py index 4351cf5..df76fa4 100644 --- a/src/aca_model/baseline/regimes/_tied.py +++ b/src/aca_model/baseline/regimes/_tied.py @@ -17,6 +17,7 @@ from aca_model.baseline.regimes._common import ( REGIME_SPECS, Grids, + RegimeSpec, build_actions, build_common_functions, build_regime_probs, @@ -65,7 +66,7 @@ def transition( return transition -def _build_functions(spec: dict[str, str]) -> dict: +def _build_functions(spec: RegimeSpec) -> dict: """Build functions dict for a tied regime.""" functions = build_common_functions(spec) From 71a8535069db79dea00439629da291dc950f9a06 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 11 May 2026 10:40:52 +0200 Subject: [PATCH 52/54] ci: bump pylcm pin to 99a5e31 (post-#345 main) `ca66ba9` was an interim commit on the pre-squash `distributed` lineage; the squash-force-push to that branch (#346) made it unreachable. Bump to pylcm main HEAD, which carries the canonical-float-dtype work this branch depends on. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ec60050..c9ecc47 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,7 +29,7 @@ jobs: - name: Install pylcm (unreleased feature branch required) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@ca66ba9" + git+https://github.com/OpenSourceEconomics/pylcm.git@99a5e31" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest From 2779011c211fcd210c574634dad03681eab64079 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 11 May 2026 10:44:22 +0200 Subject: [PATCH 53/54] ci: step name reflects that pylcm pin is on main, not a feature branch Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c9ecc47..cb0fdab 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,7 +26,7 @@ jobs: - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - - name: Install pylcm (unreleased feature branch required) + - name: Install pylcm from main (PyPI release lags) run: >- pip install "pylcm @ git+https://github.com/OpenSourceEconomics/pylcm.git@99a5e31" From fedd7565da2ba1b2afce429d30a8453c36506221 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 11 May 2026 11:36:41 +0200 Subject: [PATCH 54/54] =?UTF-8?q?Rename=20consumption=5Funequiv=20?= =?UTF-8?q?=E2=86=92=20consumption=5Fdollars;=20tighten=20types;=20drop=20?= =?UTF-8?q?factory=20magic?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - consumption_unequiv → consumption_dollars (state, action, params, module, test module). The dollar / dollars suffix replaces the unequiv / equiv pair on the in-$ side; `consumption_equiv` (the utility-equivalized form) keeps its name. - consumption_dollar_grid.py → consumption_dollars_grid.py. - `_DERIVED_CATEGORICALS` dicts now include `target_his` alongside `his`, `good_health`, `is_married`, `pref_type`. The previous `BASE_DERIVED_CATEGORICALS` constant + per-factory merge is gone — callers pass the full dict. - `RegimeSpec` TypedDict fields tightened to `Literal[...]` for the four axes. - `fixed_params` annotations switched to `lcm.typing.UserParams` (the pylcm-side alias) across `create_model`, `build_all_regimes`, `build_grids`. - `u_working_life` → `u_can_work`; `u_retired` → `u_cannot_work`. - `_build_per_target_regime_next_assets` → `_build_per_target_regime_assets`; ordering of dispatch site and definition aligned (assets → health → claimed_ss → lagged_labor_supply). `lagged_supply_transition` → `lagged_labor_supply_transition`. - `MAX_CONSUMPTION_DOLLARS` docstring: drop the in-line rationale, add a TODO referencing pylcm#348. - `aca.create_model` return doc → "pylcm Model" (no "ACA-specific" qualifier). - CI: pylcm pin uses `@main` (PyPI release lags so the git pin stays). Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/main.yml | 4 +- .../_benchmark_data/benchmark_params.pkl | Bin 68327 -> 68327 bytes src/aca_model/aca/model.py | 15 +++-- src/aca_model/aca/regimes/__init__.py | 3 +- src/aca_model/agent/assets_and_income.py | 22 +++---- src/aca_model/agent/preferences.py | 14 ++-- .../baseline/derived_categoricals.py | 12 ---- src/aca_model/baseline/model.py | 15 ++--- src/aca_model/baseline/regimes/__init__.py | 3 +- src/aca_model/baseline/regimes/_common.py | 62 +++++++++--------- src/aca_model/benchmark.py | 11 ++-- src/aca_model/config.py | 4 +- ...iv_grid.py => consumption_dollars_grid.py} | 42 ++++++------ tests/helpers/model.py | 1 + tests/test_benchmark.py | 10 +-- tests/test_budget_chain_integration.py | 6 +- ...id.py => test_consumption_dollars_grid.py} | 36 +++++----- .../test_initial_conditions_extreme_assets.py | 42 ++++++------ tests/test_model_components.py | 4 +- tests/test_model_creation.py | 2 +- tests/test_pension_integration.py | 2 +- tests/test_preferences.py | 38 +++++------ 22 files changed, 169 insertions(+), 179 deletions(-) delete mode 100644 src/aca_model/baseline/derived_categoricals.py rename src/aca_model/{consumption_unequiv_grid.py => consumption_dollars_grid.py} (76%) rename tests/{test_consumption_unequiv_grid.py => test_consumption_dollars_grid.py} (69%) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index cb0fdab..67c82fa 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,10 +26,10 @@ jobs: - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - - name: Install pylcm from main (PyPI release lags) + - name: Install pylcm run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@99a5e31" + git+https://github.com/OpenSourceEconomics/pylcm.git@main" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index 719e2e0e295da0f8448860e8a09ae52d253300eb..d0d9c1dac7cf1641c10007edf3ad3865dc0d0436 100644 GIT binary patch delta 166 zcmV;X09pU%l?3ON1O$Kul?1T_$_4?Clg|bm2V`$-Y+-V9v+@QP_?5o_K9g=BDwA&i zY?o320Wt!W&y@}Vv9}}w0UrUE+5rJdlL`MjmHh!#liL3>lb|97mHq)nm(T$L3YS9z z0ZNzG0|6nIu>%1n84h%5ZDn6&Ze(wFZDnqBlt)0Wt>sH7f5Mkx1OXhE5(EJ#myH7f UB$pBd0Th=`1OXJceFOn&22=n#H2?qr delta 167 zcmV;Y09gO$l?3ON1O$Kul>xB?$_4>{lg|bm2X$^`adl~Sv+@QP_?IgJ0Xvg!AS#o7 z^DL8{|3H_!0s%4tmC}_F0kM+_|2~(H0RcystpWiKlimL^lb|97muCV2MwizC0aUfZ z0UrUELjwUymoWqZB$uuO0VbD?0|6l!4s>a4WnX1(WN&wEWo~qoM?kKv dict[str, Regime]: diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index 844369c..92e9abb 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -35,7 +35,7 @@ def cash_on_hand( return assets + after_tax_income + ssi_benefit - hic_premium -def consumption_unequiv_floor( +def consumption_dollars_floor( consumption_equiv_floor: float, equivalence_scale: FloatND, ) -> FloatND: @@ -45,17 +45,17 @@ def consumption_unequiv_floor( def transfers( cash_on_hand: FloatND, - consumption_unequiv_floor: FloatND, + consumption_dollars_floor: FloatND, ) -> FloatND: """Government transfers to enforce the consumption floor.""" - return jnp.maximum(0.0, consumption_unequiv_floor - cash_on_hand) + return jnp.maximum(0.0, consumption_dollars_floor - cash_on_hand) def next_assets( cash_on_hand: FloatND, transfers: FloatND, pension_assets_adjustment: FloatND, - consumption_unequiv: ContinuousAction, + consumption_dollars: ContinuousAction, oop_costs: FloatND, ) -> ContinuousState: """Compute beginning-of-next-period assets for non-terminal targets. @@ -67,7 +67,7 @@ def next_assets( cash_on_hand + transfers + pension_assets_adjustment - - consumption_unequiv + - consumption_dollars - oop_costs ) @@ -75,7 +75,7 @@ def next_assets( def next_assets_when_dead( cash_on_hand: FloatND, transfers: FloatND, - consumption_unequiv: ContinuousAction, + consumption_dollars: ContinuousAction, oop_costs: FloatND, ) -> ContinuousState: """Compute beginning-of-next-period assets for the dead/terminal target. @@ -86,17 +86,17 @@ def next_assets_when_dead( (which would otherwise need to come from a transition `dead` does not have, since `aime` is not a state in the terminal regime). """ - return cash_on_hand + transfers - consumption_unequiv - oop_costs + return cash_on_hand + transfers - consumption_dollars - oop_costs def borrowing_constraint( - consumption_unequiv: ContinuousAction, + consumption_dollars: ContinuousAction, cash_on_hand: FloatND, - consumption_unequiv_floor: FloatND, + consumption_dollars_floor: FloatND, ) -> BoolND: """Consumption cannot exceed post-transfer resources. - Post-transfer resources are `max(cash_on_hand, consumption_unequiv_floor)`: + Post-transfer resources are `max(cash_on_hand, consumption_dollars_floor)`: the transfer system tops `cash_on_hand` to the floor when below, otherwise resources are unchanged. The algebraic identity is `cash_on_hand + transfers == max(cash_on_hand, floor)`; the `max` @@ -105,4 +105,4 @@ def borrowing_constraint( the kink-boundary comparison at large negative values of `assets`. The `max` form returns `floor` exactly. """ - return consumption_unequiv <= jnp.maximum(cash_on_hand, consumption_unequiv_floor) + return consumption_dollars <= jnp.maximum(cash_on_hand, consumption_dollars_floor) diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index e7a356d..5c08541 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -122,14 +122,14 @@ def leisure_retired( def consumption_equiv( - consumption_unequiv: ContinuousAction, + consumption_dollars: ContinuousAction, equivalence_scale: FloatND, ) -> FloatND: """Utility-equivalized consumption.""" - return consumption_unequiv / equivalence_scale + return consumption_dollars / equivalence_scale -def u_working_life( +def u_can_work( consumption_equiv: FloatND, leisure: FloatND, consumption_weight: FloatND, @@ -152,7 +152,7 @@ def u_working_life( return u * utility_scale_factor -def u_retired( +def u_cannot_work( consumption_equiv: FloatND, good_health: IntND, consumption_weight: FloatND, @@ -167,7 +167,7 @@ def u_retired( time_endowment=time_endowment, leisure_cost_of_bad_health=leisure_cost_of_bad_health, ) - return u_working_life( + return u_can_work( consumption_equiv=consumption_equiv, leisure=leisure, consumption_weight=consumption_weight, @@ -233,7 +233,7 @@ def discount_factor( def utility_scale_factor( - average_consumption_unequiv: float, + average_consumption_dollars: float, consumption_weight: FloatND, coefficient_rra: FloatND, time_endowment: float, @@ -242,7 +242,7 @@ def utility_scale_factor( ) -> FloatND: """Compute the scale factor so utility is approximately 1 at typical values.""" average_leisure = time_endowment - reference_hours - fixed_cost_of_work_intercept - u_cons = average_consumption_unequiv**consumption_weight + u_cons = average_consumption_dollars**consumption_weight u_leisure = average_leisure ** (1.0 - consumption_weight) one_minus_rra = jnp.where( diff --git a/src/aca_model/baseline/derived_categoricals.py b/src/aca_model/baseline/derived_categoricals.py deleted file mode 100644 index d9745d5..0000000 --- a/src/aca_model/baseline/derived_categoricals.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Base-layer derived categoricals shared by baseline + ACA model factories.""" - -from types import MappingProxyType - -from lcm import DiscreteGrid - -from aca_model.baseline.health_insurance import HealthInsuranceState - -# `target_his` is a state subsumed into regimes. -BASE_DERIVED_CATEGORICALS = MappingProxyType( - {"target_his": DiscreteGrid(HealthInsuranceState)} -) diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index 5e3f934..98416ce 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -14,8 +14,8 @@ from typing import Any from lcm import AgeGrid, DiscreteGrid, Model +from lcm.typing import UserParams -from aca_model.baseline.derived_categoricals import BASE_DERIVED_CATEGORICALS from aca_model.baseline.regimes import RegimeId, build_all_regimes from aca_model.config import MODEL_CONFIG, GridConfig @@ -23,7 +23,7 @@ def create_model( *, n_subjects: int, - fixed_params: Mapping[str, Any], + fixed_params: UserParams, wage_params: Mapping[str, Any], derived_categoricals: Mapping[str, DiscreteGrid], grid_config: GridConfig, @@ -42,11 +42,10 @@ def create_model( `log_ft_wage_std`, `adj_wage_hours_*`) used only at grid-build time to size the assets-floor to `-max_annual_labor_income`. Not routed to the pylcm Model. - derived_categoricals: Extra categorical mappings for derived - variables not in the model's state/action grids. Needed when - `fixed_params` contains `pd.Series` indexed by DAG function - outputs. `target_his` is added automatically via - `BASE_DERIVED_CATEGORICALS`. + derived_categoricals: Categorical mappings for `pd.Series` + fixed_params index levels that aren't model state/action + grids — `target_his`, `his`, `good_health`, `is_married`, + `pref_type`. grid_config: Continuous-grid point counts. Pass `GRID_CONFIG` for production values or `BENCHMARK_GRID_CONFIG` for the fast-but-structurally-faithful benchmark. @@ -77,6 +76,6 @@ def create_model( regime_id_class=RegimeId, description="Baseline structural retirement model (pre-ACA)", fixed_params=fixed_params, - derived_categoricals={**BASE_DERIVED_CATEGORICALS, **derived_categoricals}, + derived_categoricals=derived_categoricals, n_subjects=n_subjects, ) diff --git a/src/aca_model/baseline/regimes/__init__.py b/src/aca_model/baseline/regimes/__init__.py index 5384900..bd4c564 100644 --- a/src/aca_model/baseline/regimes/__init__.py +++ b/src/aca_model/baseline/regimes/__init__.py @@ -14,6 +14,7 @@ from typing import Any from lcm import DiscreteGrid, Regime +from lcm.typing import UserParams from aca_model.baseline.regimes import _nongroup as nongroup from aca_model.baseline.regimes import _retiree as retiree @@ -60,7 +61,7 @@ def build_regime(name: str, grids: Grids) -> Regime: def build_all_regimes( *, grid_config: GridConfig, - fixed_params: Mapping[str, Any], + fixed_params: UserParams, wage_params: Mapping[str, Any], pref_type_grid: DiscreteGrid, ) -> dict[str, Regime]: diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index a148d4b..a2e3a13 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -7,7 +7,7 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass from types import MappingProxyType -from typing import Any, TypedDict +from typing import Any, Literal, TypedDict import jax.numpy as jnp import lcm.shocks.ar1 @@ -23,7 +23,7 @@ ) from lcm.grids.continuous import ContinuousGrid from lcm.grids.piecewise import Piece, PiecewiseLinSpacedGrid -from lcm.typing import BoolND, FloatND, RegimeName +from lcm.typing import BoolND, FloatND, RegimeName, UserParams from aca_model.agent import ( assets_and_income, @@ -34,7 +34,7 @@ from aca_model.agent.health import Health, HealthWithDisability from aca_model.agent.labor_market import LaborSupply, LaggedLaborSupply, SpousalIncome from aca_model.baseline import health_insurance -from aca_model.baseline.health_insurance import BuyPrivate +from aca_model.baseline.health_insurance import BuyPrivate, HealthInsuranceState from aca_model.config import MODEL_CONFIG, GridConfig from aca_model.environment import social_security, taxes from aca_model.environment.social_security import ClaimedSS @@ -66,10 +66,10 @@ class RegimeId: class RegimeSpec(TypedDict): """Structural decomposition of a regime: (HIS, Medicare, SS, work) axes.""" - his: str - mc: str - ss: str - canwork: str + his: Literal["retiree", "tied", "nongroup"] + mc: Literal["nomc", "dimc", "oamc"] + ss: Literal["inelig", "choose", "forced"] + canwork: Literal["canwork", "forcedout"] # {his}_{mc}_{ss}_{canwork} @@ -191,7 +191,7 @@ class RegimeSpec(TypedDict): class Grids: assets: LinSpacedGrid aime: ContinuousGrid - consumption_unequiv: ContinuousGrid + consumption_dollars: ContinuousGrid wage_res: Any hcc_persistent: Any hcc_transitory: Any @@ -203,20 +203,21 @@ class Grids: _AIME_PIECE_N_POINTS: tuple[int, int, int] = (10, 11, 11) -MAX_CONSUMPTION_UNEQUIV: float = 300_000.0 -"""Upper bound of the runtime consumption_unequiv grid in $/year. +MAX_CONSUMPTION_DOLLARS: float = 300_000.0 +"""Upper bound of the runtime consumption_dollars grid in $/year. Lives here next to the other grid bounds (assets `stop=500_000.0`, -AIME `stop=8_000.0`). `inject_consumption_unequiv_points` imports it -directly — pylcm rejects `fixed_params` entries no DAG function -consumes, so this stays a module constant. +AIME `stop=8_000.0`). + +TODO: route through `fixed_params` once pylcm#348 lands (so the bound +can vary across optimizer iterations without re-importing this module). """ def build_grids( *, grid_config: GridConfig, - fixed_params: Mapping[str, Any], + fixed_params: UserParams, wage_params: Mapping[str, Any], pref_type_grid: DiscreteGrid, ) -> Grids: @@ -266,8 +267,8 @@ def build_grids( batch_size=grid_config.n_assets_batch_size, ), aime=_build_aime_grid(grid_config=grid_config, fixed_params=fixed_params), - consumption_unequiv=IrregSpacedGrid( - n_points=grid_config.n_consumption_unequiv_gridpoints, + consumption_dollars=IrregSpacedGrid( + n_points=grid_config.n_consumption_dollars_gridpoints, ), wage_res=wage_res, hcc_persistent=hcc_persistent, @@ -301,7 +302,7 @@ def get_hcc_persistent_grid_points(*, grid_config: GridConfig) -> FloatND: def _build_aime_grid( - *, grid_config: GridConfig, fixed_params: Mapping[str, Any] + *, grid_config: GridConfig, fixed_params: UserParams ) -> ContinuousGrid: """Return the AIME grid. @@ -419,7 +420,7 @@ def build_actions(spec: RegimeSpec, grids: Grids) -> dict: actions["labor_supply"] = DiscreteGrid(LaborSupply) if spec["his"] == "nongroup" and spec["mc"] == "nomc": actions["buy_private"] = DiscreteGrid(BuyPrivate) - actions["consumption_unequiv"] = grids.consumption_unequiv + actions["consumption_dollars"] = grids.consumption_dollars return actions @@ -470,8 +471,8 @@ def select_ss_benefit(spec: RegimeSpec) -> Callable[..., Any]: def select_utility(spec: RegimeSpec) -> Callable[..., Any]: """Select the utility function for a regime.""" if spec["canwork"] != "canwork": - return preferences.u_retired - return preferences.u_working_life + return preferences.u_cannot_work + return preferences.u_can_work def _select_leisure(spec: RegimeSpec) -> Callable[..., Any]: @@ -510,9 +511,6 @@ def build_common_functions(spec: RegimeSpec) -> dict: functions["is_married"] = labor_market.is_married functions["equivalence_scale"] = preferences.equivalence_scale functions["utility_scale_factor"] = preferences.utility_scale_factor - # Pref-type-indexed scalars: DAG functions resolve the per-cell - # value from the params Series so downstream consumers get a - # scalar broadcast to every cell. functions["consumption_weight"] = preferences.consumption_weight functions["coefficient_rra"] = preferences.coefficient_rra functions["discount_factor"] = preferences.discount_factor @@ -546,7 +544,7 @@ def build_common_functions(spec: RegimeSpec) -> dict: # Cash on hand and transfers functions["cash_on_hand"] = assets_and_income.cash_on_hand - functions["consumption_unequiv_floor"] = assets_and_income.consumption_unequiv_floor + functions["consumption_dollars_floor"] = assets_and_income.consumption_dollars_floor functions["transfers"] = assets_and_income.transfers functions["consumption_equiv"] = preferences.consumption_equiv @@ -646,8 +644,14 @@ def select_target_for_age( def build_state_transitions(spec: RegimeSpec) -> dict: """Build the state transitions dict for a non-dead regime.""" transitions: dict = {} + transitions["assets"] = _build_per_target_regime_assets(spec) transitions["health"] = _build_per_target_regime_health(spec) - transitions["assets"] = _build_per_target_regime_next_assets(spec) + claimed_ss_transition = _build_per_target_regime_claimed_ss(spec) + if claimed_ss_transition: + transitions["claimed_ss"] = claimed_ss_transition + lagged_labor_supply_transition = _build_per_target_regime_lagged_labor_supply(spec) + if lagged_labor_supply_transition: + transitions["lagged_labor_supply"] = lagged_labor_supply_transition transitions["pref_type"] = None transitions["aime"] = ( social_security.next_aime @@ -655,16 +659,10 @@ def build_state_transitions(spec: RegimeSpec) -> dict: else social_security.next_aime_disabled ) transitions["spousal_income"] = MarkovTransition(labor_market.next_spousal_income) - lagged_supply_transition = _build_per_target_regime_lagged_labor_supply(spec) - if lagged_supply_transition: - transitions["lagged_labor_supply"] = lagged_supply_transition - claimed_ss_transition = _build_per_target_regime_claimed_ss(spec) - if claimed_ss_transition: - transitions["claimed_ss"] = claimed_ss_transition return transitions -def _build_per_target_regime_next_assets( +def _build_per_target_regime_assets( spec: RegimeSpec, ) -> dict[RegimeName, Callable[..., FloatND]]: """Build per-target assets transitions. diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 57522f5..5b519d5 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -38,7 +38,7 @@ from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.model import create_model from aca_model.config import BENCHMARK_GRID_CONFIG -from aca_model.consumption_unequiv_grid import inject_consumption_unequiv_points +from aca_model.consumption_dollars_grid import inject_consumption_dollars_points _PARAMS_FILE = ( Path(__file__).resolve().parent / "_benchmark_data" / "benchmark_params.pkl" @@ -50,6 +50,7 @@ "good_health": DiscreteGrid(GoodHealth), "is_married": DiscreteGrid(IsMarried), "his": DiscreteGrid(HealthInsuranceState), + "target_his": DiscreteGrid(HealthInsuranceState), "pref_type": DiscreteGrid(BenchmarkPrefType), } @@ -99,10 +100,10 @@ def get_benchmark_params( ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: """Load the frozen `(fixed_params, wage_params, params)` snapshot. - When `model` is provided, consumption_unequiv gridpoints are injected - into `params` for each regime that declares `consumption_unequiv` as + When `model` is provided, consumption_dollars gridpoints are injected + into `params` for each regime that declares `consumption_dollars` as an `IrregSpacedGrid` with runtime-supplied points. The lower bound is - read from `params["consumption_unequiv_floor"]`. Pass `model=None` to + read from `params["consumption_dollars_floor"]`. Pass `model=None` to skip injection (e.g. when constructing the model with `fixed_params`). """ with _PARAMS_FILE.open("rb") as fh: @@ -111,7 +112,7 @@ def get_benchmark_params( wage_params = data["wage_params"] params = data["params"] if model is not None: - params = inject_consumption_unequiv_points(params=params, model=model) + params = inject_consumption_dollars_points(params=params, model=model) return fixed_params, wage_params, params diff --git a/src/aca_model/config.py b/src/aca_model/config.py index cfa132d..101ef2d 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -29,7 +29,7 @@ class ModelConfig: class GridConfig: n_assets_gridpoints: int = 24 n_aime_gridpoints: int = 12 - n_consumption_unequiv_gridpoints: int = 70 + n_consumption_dollars_gridpoints: int = 70 n_wage_res_gridpoints: int = 5 n_hcc_persistent_gridpoints: int = 3 n_hcc_transitory_gridpoints: int = 5 @@ -47,7 +47,7 @@ class GridConfig: BENCHMARK_GRID_CONFIG = GridConfig( n_assets_gridpoints=3, n_aime_gridpoints=3, - n_consumption_unequiv_gridpoints=5, + n_consumption_dollars_gridpoints=5, n_wage_res_gridpoints=3, n_hcc_persistent_gridpoints=3, n_hcc_transitory_gridpoints=3, diff --git a/src/aca_model/consumption_unequiv_grid.py b/src/aca_model/consumption_dollars_grid.py similarity index 76% rename from src/aca_model/consumption_unequiv_grid.py rename to src/aca_model/consumption_dollars_grid.py index ba1b74c..7487fd8 100644 --- a/src/aca_model/consumption_unequiv_grid.py +++ b/src/aca_model/consumption_dollars_grid.py @@ -1,12 +1,12 @@ -"""Runtime-supplied gridpoints for the consumption_unequiv action. +"""Runtime-supplied gridpoints for the consumption_dollars action. Consumption is declared as `IrregSpacedGrid(n_points=N)` in `baseline.regimes._common.build_grids` so the bounds can track runtime parameters: the lower bound from the per-iteration `consumption_equiv_floor` parameter (and its couples-scaled twin), -the upper bound from `MAX_CONSUMPTION_UNEQUIV` in +the upper bound from `MAX_CONSUMPTION_DOLLARS` in `baseline.regimes._common`. Callers must inject the actual gridpoints -into `params` via `inject_consumption_unequiv_points` before calling +into `params` via `inject_consumption_dollars_points` before calling `model.solve()` / `model.simulate()`. The grid pins the two regime-relevant transfer-floor levels exactly @@ -16,7 +16,7 @@ - `pts[0] = consumption_equiv_floor` (single household: equiv_scale=1) - `pts[1] = consumption_equiv_floor * 2 ** exponent` (married) -- `pts[2:] = geomspace(pts[1], MAX_CONSUMPTION_UNEQUIV, n_points - 1)` +- `pts[2:] = geomspace(pts[1], MAX_CONSUMPTION_DOLLARS, n_points - 1)` """ from collections.abc import Mapping @@ -26,22 +26,22 @@ from jax import Array from lcm import IrregSpacedGrid, Model -from aca_model.baseline.regimes._common import MAX_CONSUMPTION_UNEQUIV +from aca_model.baseline.regimes._common import MAX_CONSUMPTION_DOLLARS -def inject_consumption_unequiv_points( +def inject_consumption_dollars_points( *, params: Mapping[str, Any], model: Model, ) -> dict[str, Any]: - """Inject consumption_unequiv gridpoints into per-regime params. + """Inject consumption_dollars gridpoints into per-regime params. - Walks every regime, reads its `consumption_unequiv` action grid, - and writes `params[regime_name]["consumption_unequiv"] = {"points": }`. + Walks every regime, reads its `consumption_dollars` action grid, + and writes `params[regime_name]["consumption_dollars"] = {"points": }`. The lower two gridpoints are the single and married unequiv transfer floors; the rest are geomspaced from the married floor up - to `MAX_CONSUMPTION_UNEQUIV`. + to `MAX_CONSUMPTION_DOLLARS`. Args: params: Existing params mapping with `consumption_equiv_floor` @@ -52,10 +52,10 @@ def inject_consumption_unequiv_points( equivalence-scale exponent. Returns: - New params dict with consumption_unequiv points injected. + New params dict with consumption_dollars points injected. Raises: - ValueError: If a regime is missing the `consumption_unequiv` + ValueError: If a regime is missing the `consumption_dollars` action, or its grid is not an `IrregSpacedGrid` with `pass_points_at_runtime=True`. """ @@ -65,16 +65,16 @@ def inject_consumption_unequiv_points( for regime_name, regime in model.regimes.items(): if regime.terminal: continue - grid = regime.actions.get("consumption_unequiv") + grid = regime.actions.get("consumption_dollars") if grid is None: msg = ( - f"Regime {regime_name!r} is missing the `consumption_unequiv` " + f"Regime {regime_name!r} is missing the `consumption_dollars` " f"action — the runtime-points grid must be on every regime." ) raise ValueError(msg) if not (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime): msg = ( - f"Regime {regime_name!r} has a `consumption_unequiv` action " + f"Regime {regime_name!r} has a `consumption_dollars` action " f"whose grid is not an `IrregSpacedGrid(pass_points_at_runtime=True)`; " f"got {type(grid).__name__}." ) @@ -82,24 +82,24 @@ def inject_consumption_unequiv_points( # Runtime-points grids always have `n_points` set (the constructor # rejects the (points=None, n_points=None) combo); narrow for ty. assert grid.n_points is not None - points = _compute_consumption_unequiv_points( + points = _compute_consumption_dollars_points( consumption_equiv_floor=consumption_equiv_floor, exponent=exponent, n_points=grid.n_points, ) regime_entry = dict(out.get(regime_name, {})) - regime_entry["consumption_unequiv"] = {"points": points} + regime_entry["consumption_dollars"] = {"points": points} out[regime_name] = regime_entry return out -def _compute_consumption_unequiv_points( +def _compute_consumption_dollars_points( *, consumption_equiv_floor: Array, exponent: Array, n_points: int, ) -> Array: - """Return log-spaced consumption_unequiv gridpoints with both floors pinned. + """Return log-spaced consumption_dollars gridpoints with both floors pinned. Single and married households face different unequiv (in-$) floors (`consumption_equiv_floor` and the married-scaled twin @@ -108,12 +108,12 @@ def _compute_consumption_unequiv_points( a feasible action; otherwise sub-ULP drift can flip the `<=` comparison for subjects with very negative cash. The geomspace tail starts at the married floor and runs to - `MAX_CONSUMPTION_UNEQUIV` so the two pinned points stay strictly + `MAX_CONSUMPTION_DOLLARS` so the two pinned points stay strictly increasing. """ married_unequiv_floor = consumption_equiv_floor * jnp.asarray(2.0) ** exponent tail = jnp.geomspace( - married_unequiv_floor, MAX_CONSUMPTION_UNEQUIV, num=n_points - 1 + married_unequiv_floor, MAX_CONSUMPTION_DOLLARS, num=n_points - 1 ) pts = jnp.concatenate([consumption_equiv_floor[None], tail]) # `jnp.geomspace` returns `start * r^0` for the first tail element, diff --git a/tests/helpers/model.py b/tests/helpers/model.py index dc7c407..be778b4 100644 --- a/tests/helpers/model.py +++ b/tests/helpers/model.py @@ -21,6 +21,7 @@ "good_health": DiscreteGrid(GoodHealth), "is_married": DiscreteGrid(IsMarried), "his": DiscreteGrid(HealthInsuranceState), + "target_his": DiscreteGrid(HealthInsuranceState), "pref_type": DiscreteGrid(BenchmarkPrefType), } diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index d3d83f4..b1be815 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -41,13 +41,13 @@ def test_benchmark_model_simulates_end_to_end() -> None: @pytest.mark.long_running def test_benchmark_simulate_obeys_borrowing_constraint() -> None: - """`consumption_unequiv <= max(cash_on_hand, floor)` holds for every alive row. + """`consumption_dollars <= max(cash_on_hand, floor)` holds for every alive row. The simulator only ever picks feasible actions — the borrowing constraint must hold post-hoc on the simulated panel. A regression that drops the constraint from a regime, replaces the floor with something looser, or lets an action grid skip the floor would - surface as a row with `consumption_unequiv > max(cash_on_hand, floor)`. + surface as a row with `consumption_dollars > max(cash_on_hand, floor)`. The constraint's RHS is `max(cash_on_hand, floor)` rather than `cash_on_hand + transfers`: the additive form rounds short by @@ -74,10 +74,10 @@ def test_benchmark_simulate_obeys_borrowing_constraint() -> None: df = result.to_dataframe(additional_targets=["cash_on_hand", "equivalence_scale"]) alive = df.loc[df["regime"] != "dead"].copy() - consumption_unequiv_floor = float(params["consumption_unequiv_floor"]) - floor = consumption_unequiv_floor * alive["equivalence_scale"].to_numpy() + consumption_dollars_floor = float(params["consumption_dollars_floor"]) + floor = consumption_dollars_floor * alive["equivalence_scale"].to_numpy() rhs = np.maximum(alive["cash_on_hand"].to_numpy(), floor) - slack = rhs - alive["consumption_unequiv"].to_numpy() + slack = rhs - alive["consumption_dollars"].to_numpy() assert (slack >= 0).all(), ( f"borrowing_constraint violated on {int((slack < 0).sum())} row(s); " f"min slack = {slack.min():.6g}" diff --git a/tests/test_budget_chain_integration.py b/tests/test_budget_chain_integration.py index 87ab670..f087d16 100644 --- a/tests/test_budget_chain_integration.py +++ b/tests/test_budget_chain_integration.py @@ -108,7 +108,7 @@ def test_retired_agent_with_pension() -> None: def test_transfers_kick_in_below_floor() -> None: - """When cash_on_hand < consumption_unequiv_floor, transfers fill the gap.""" + """When cash_on_hand < consumption_dollars_floor, transfers fill the gap.""" functions = { "cash_on_hand": assets_and_income.cash_on_hand, "transfers": assets_and_income.transfers, @@ -126,9 +126,9 @@ def test_transfers_kick_in_below_floor() -> None: ssi_benefit=jnp.array(0.0), hic_premium=jnp.array(0.0), oop_costs=jnp.array(0.0), - consumption_unequiv_floor=jnp.array(5000.0), + consumption_dollars_floor=jnp.array(5000.0), pension_assets_adjustment=jnp.array(0.0), - consumption_unequiv=jnp.array(4000.0), + consumption_dollars=jnp.array(4000.0), ) # cash_on_hand = 500 + 200 = 700 diff --git a/tests/test_consumption_unequiv_grid.py b/tests/test_consumption_dollars_grid.py similarity index 69% rename from tests/test_consumption_unequiv_grid.py rename to tests/test_consumption_dollars_grid.py index 92593b4..1f42e6f 100644 --- a/tests/test_consumption_unequiv_grid.py +++ b/tests/test_consumption_dollars_grid.py @@ -1,18 +1,18 @@ """Consumption-grid invariants required by the borrowing constraint. The borrowing constraint in `agent.assets_and_income.borrowing_constraint` -compares the lowest consumption_unequiv action against -`max(cash_on_hand, consumption_unequiv_floor)`. For subjects with cash +compares the lowest consumption_dollars action against +`max(cash_on_hand, consumption_dollars_floor)`. For subjects with cash below the floor (large-negative-asset subjects, moderate-negative-asset retirees, etc.) this RHS collapses to exactly -`consumption_unequiv_floor`. The constraint is feasible iff the -relevant household-floor gridpoint is `<= consumption_unequiv_floor`. +`consumption_dollars_floor`. The constraint is feasible iff the +relevant household-floor gridpoint is `<= consumption_dollars_floor`. For singles (`equivalence_scale = 1`) that floor is `consumption_equiv_floor`; for married households (`equivalence_scale = 2 ** exponent`) it is `consumption_equiv_floor * 2 ** exponent`. Both must land **exactly** -on the consumption_unequiv grid. +on the consumption_dollars grid. `jnp.geomspace(start, stop, num=n)` returns `start * r^i` with `r = (stop/start)^(1/(n-1))`; mathematically `r^0 == 1` so the first @@ -21,17 +21,17 @@ A positive drift above the floor flips the kink-boundary `<=` and rejects every action for the affected subjects. -`_compute_consumption_unequiv_points` therefore prepends the singles' +`_compute_consumption_dollars_points` therefore prepends the singles' floor as `pts[0]`, runs `geomspace` from the married floor up to -`MAX_CONSUMPTION_UNEQUIV` for the rest, and pins the geomspace start +`MAX_CONSUMPTION_DOLLARS` for the rest, and pins the geomspace start back to the married floor exactly. Test those invariants directly. """ import jax.numpy as jnp import pytest -from aca_model.baseline.regimes._common import MAX_CONSUMPTION_UNEQUIV -from aca_model.consumption_unequiv_grid import _compute_consumption_unequiv_points +from aca_model.baseline.regimes._common import MAX_CONSUMPTION_DOLLARS +from aca_model.consumption_dollars_grid import _compute_consumption_dollars_points EXPONENT = 0.7 # production value (env_constants["exponent"]) SINGLE_FLOOR = 1597.0921419521899 # production value @@ -39,11 +39,11 @@ @pytest.mark.parametrize("n_points", [5, 16, 64, 70, 100]) -def test_compute_consumption_unequiv_points_first_equals_singles_floor( +def test_compute_consumption_dollars_points_first_equals_singles_floor( n_points: int, ) -> None: """`pts[0]` equals the singles' floor exactly under any `n_points`.""" - pts = _compute_consumption_unequiv_points( + pts = _compute_consumption_dollars_points( consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), exponent=jnp.asarray(EXPONENT), n_points=n_points, @@ -52,11 +52,11 @@ def test_compute_consumption_unequiv_points_first_equals_singles_floor( @pytest.mark.parametrize("n_points", [5, 16, 64, 70, 100]) -def test_compute_consumption_unequiv_points_second_equals_married_floor( +def test_compute_consumption_dollars_points_second_equals_married_floor( n_points: int, ) -> None: """`pts[1]` equals `consumption_equiv_floor * 2 ** exponent` exactly.""" - pts = _compute_consumption_unequiv_points( + pts = _compute_consumption_dollars_points( consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), exponent=jnp.asarray(EXPONENT), n_points=n_points, @@ -65,9 +65,9 @@ def test_compute_consumption_unequiv_points_second_equals_married_floor( assert float(pts[1]) == expected -def test_compute_consumption_unequiv_points_strictly_increasing() -> None: +def test_compute_consumption_dollars_points_strictly_increasing() -> None: """Gridpoints are strictly increasing — no kink-pinning ties.""" - pts = _compute_consumption_unequiv_points( + pts = _compute_consumption_dollars_points( consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), exponent=jnp.asarray(EXPONENT), n_points=70, @@ -76,11 +76,11 @@ def test_compute_consumption_unequiv_points_strictly_increasing() -> None: assert bool((diffs > 0).all()) -def test_compute_consumption_unequiv_points_last_equals_max() -> None: +def test_compute_consumption_dollars_points_last_equals_max() -> None: """The final point is the configured upper bound.""" - pts = _compute_consumption_unequiv_points( + pts = _compute_consumption_dollars_points( consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), exponent=jnp.asarray(EXPONENT), n_points=70, ) - assert float(pts[-1]) == pytest.approx(MAX_CONSUMPTION_UNEQUIV) + assert float(pts[-1]) == pytest.approx(MAX_CONSUMPTION_DOLLARS) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index ef6e1e9..3b16522 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -2,7 +2,7 @@ The transfer system (`agent.assets_and_income.transfers`) tops cash-on-hand to the household-$ consumption floor at any starting state, so the lowest -consumption_unequiv-grid point is always a feasible action regardless of +consumption_dollars-grid point is always a feasible action regardless of how negative starting assets are. The model's constraints — and pylcm's `validate_initial_conditions` pass — must reflect this. """ @@ -20,22 +20,22 @@ ) -def test_borrowing_constraint_admits_consumption_unequiv_at_floor() -> None: - """`consumption_unequiv == consumption_unequiv_floor` at the kink is feasible by equality.""" - consumption_unequiv_floor = jnp.asarray(5_000.0) +def test_borrowing_constraint_admits_consumption_dollars_at_floor() -> None: + """`consumption_dollars == consumption_dollars_floor` at the kink is feasible by equality.""" + consumption_dollars_floor = jnp.asarray(5_000.0) cash_on_hand = jnp.asarray(-50_000.0) # below floor — RHS = floor admitted = bool( borrowing_constraint( - consumption_unequiv=consumption_unequiv_floor, + consumption_dollars=consumption_dollars_floor, cash_on_hand=cash_on_hand, - consumption_unequiv_floor=consumption_unequiv_floor, + consumption_dollars_floor=consumption_dollars_floor, ) ) assert admitted -def test_borrowing_constraint_admits_consumption_unequiv_at_married_floor() -> None: +def test_borrowing_constraint_admits_consumption_dollars_at_married_floor() -> None: """At a married household's higher floor, the equivalence-scale-lifted floor is feasible.""" consumption_equiv_floor = jnp.asarray(5_000.0) married_floor = consumption_equiv_floor * jnp.asarray(2.0) ** 0.7 @@ -43,27 +43,27 @@ def test_borrowing_constraint_admits_consumption_unequiv_at_married_floor() -> N admitted = bool( borrowing_constraint( - consumption_unequiv=married_floor, + consumption_dollars=married_floor, cash_on_hand=cash_on_hand, - consumption_unequiv_floor=married_floor, + consumption_dollars_floor=married_floor, ) ) assert admitted -def test_borrowing_constraint_rejects_consumption_unequiv_above_post_transfer_resources() -> ( +def test_borrowing_constraint_rejects_consumption_dollars_above_post_transfer_resources() -> ( None ): - """`consumption_unequiv > max(cash_on_hand, floor)` is rejected.""" - consumption_unequiv_floor = jnp.asarray(5_000.0) + """`consumption_dollars > max(cash_on_hand, floor)` is rejected.""" + consumption_dollars_floor = jnp.asarray(5_000.0) cash_on_hand = jnp.asarray(-50_000.0) - consumption_unequiv = consumption_unequiv_floor + 1.0 + consumption_dollars = consumption_dollars_floor + 1.0 admitted = bool( borrowing_constraint( - consumption_unequiv=consumption_unequiv, + consumption_dollars=consumption_dollars, cash_on_hand=cash_on_hand, - consumption_unequiv_floor=consumption_unequiv_floor, + consumption_dollars_floor=consumption_dollars_floor, ) ) assert not admitted @@ -74,19 +74,19 @@ def test_borrowing_constraint_admits_floor_at_million_dollar_negative_cash() -> At large negative `assets`, the algebraically equivalent `cash_on_hand + transfers` form rounds to `floor - 5.7e-11` at fp64, - flipping `consumption_unequiv <= ...` for the lowest - consumption_unequiv gridpoint. The `max(cash_on_hand, floor)` form + flipping `consumption_dollars <= ...` for the lowest + consumption_dollars gridpoint. The `max(cash_on_hand, floor)` form returns `floor` exactly. """ - consumption_unequiv_floor = jnp.asarray(1597.0921419521899) # production value + consumption_dollars_floor = jnp.asarray(1597.0921419521899) # production value cash_on_hand = jnp.asarray(-1_000_000.0) - consumption_unequiv = consumption_unequiv_floor # lowest grid point + consumption_dollars = consumption_dollars_floor # lowest grid point admitted = bool( borrowing_constraint( - consumption_unequiv=consumption_unequiv, + consumption_dollars=consumption_dollars, cash_on_hand=cash_on_hand, - consumption_unequiv_floor=consumption_unequiv_floor, + consumption_dollars_floor=consumption_dollars_floor, ) ) assert admitted diff --git a/tests/test_model_components.py b/tests/test_model_components.py index 9de3b60..5b7df6a 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -77,7 +77,7 @@ def test_leisure_bad_health() -> None: def test_utility_positive_leisure() -> None: - result = preferences.u_working_life( + result = preferences.u_can_work( consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), consumption_weight=jnp.array(0.4), @@ -88,7 +88,7 @@ def test_utility_positive_leisure() -> None: def test_utility_log_case() -> None: - result = preferences.u_working_life( + result = preferences.u_can_work( consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), consumption_weight=jnp.array(0.4), diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index 37984da..7ae6e36 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -76,7 +76,7 @@ def test_forcedout_regimes_no_labor_supply(name: str) -> None: regime = build_regime(name) assert "labor_supply" not in regime.actions assert "log_ft_wage_res" not in regime.states - assert "consumption_unequiv" in regime.actions + assert "consumption_dollars" in regime.actions @pytest.mark.parametrize( diff --git a/tests/test_pension_integration.py b/tests/test_pension_integration.py index 0fabb35..0f6c07d 100644 --- a/tests/test_pension_integration.py +++ b/tests/test_pension_integration.py @@ -95,7 +95,7 @@ def test_next_assets_includes_pension_adjustment() -> None: cash_on_hand=jnp.array(100_000.0), transfers=jnp.array(0.0), pension_assets_adjustment=jnp.array(5_000.0), - consumption_unequiv=jnp.array(80_000.0), + consumption_dollars=jnp.array(80_000.0), oop_costs=jnp.array(0.0), ) assert jnp.isclose(result, 25_000.0, atol=ATOL) diff --git a/tests/test_preferences.py b/tests/test_preferences.py index 14ab67d..1b5107f 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -16,7 +16,7 @@ RATE_OF_RETURN = 0.01 BEQUEST_WEIGHT = 0.02 BEQUEST_SHIFTER = 500_000.0 -REFERENCE_HOURS = 500.0 +REFERENCE_HOURS = 1000.0 # --- utility_scale_factor --- @@ -24,26 +24,26 @@ def test_utility_scale_factor_crra() -> None: result = preferences.utility_scale_factor( - average_consumption_unequiv=AVERAGE_CONSUMPTION, + average_consumption_dollars=AVERAGE_CONSUMPTION, consumption_weight=jnp.array(CONSUMPTION_WEIGHT), coefficient_rra=jnp.array(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - assert jnp.isclose(result, 1.114_807_837_680_009_4e16, rtol=1e-6) + assert jnp.isclose(result, 9_233_279_397_806_166.0, rtol=1e-6) def test_utility_scale_factor_log() -> None: result = preferences.utility_scale_factor( - average_consumption_unequiv=AVERAGE_CONSUMPTION, + average_consumption_dollars=AVERAGE_CONSUMPTION, consumption_weight=jnp.array(CONSUMPTION_WEIGHT), coefficient_rra=jnp.array(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - assert jnp.isclose(result, 0.112_474_080_852_230_33, rtol=1e-6) + assert jnp.isclose(result, 0.113_073_257_794_546_72, rtol=1e-6) # --- scaled_bequest_weight --- @@ -90,60 +90,60 @@ def test_scaled_bequest_weight_zero() -> None: def test_utility_log_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption_unequiv=AVERAGE_CONSUMPTION, + average_consumption_dollars=AVERAGE_CONSUMPTION, consumption_weight=jnp.array(CONSUMPTION_WEIGHT), coefficient_rra=jnp.array(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - result = preferences.u_working_life( + result = preferences.u_can_work( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), consumption_weight=jnp.array(CONSUMPTION_WEIGHT), coefficient_rra=jnp.array(1.0), utility_scale_factor=scale, ) - assert jnp.isclose(result, 0.999_720_557_696_258_7, rtol=1e-5) + assert jnp.isclose(result, 1.005_046_313_660_588_5, rtol=1e-5) def test_utility_crra_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption_unequiv=AVERAGE_CONSUMPTION, + average_consumption_dollars=AVERAGE_CONSUMPTION, consumption_weight=jnp.array(CONSUMPTION_WEIGHT), coefficient_rra=jnp.array(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - result = preferences.u_working_life( + result = preferences.u_can_work( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), consumption_weight=jnp.array(CONSUMPTION_WEIGHT), coefficient_rra=jnp.array(5.0), utility_scale_factor=scale, ) - assert jnp.isclose(result, -1.009_987_562_073_720_9, rtol=1e-5) + assert jnp.isclose(result, -0.836_511_642_073_019_1, rtol=1e-5) def test_utility_married_equivalence() -> None: - """Married with equiv-scaled consumption_unequiv should equal single utility.""" + """Married with equiv-scaled consumption_dollars should equal single utility.""" scale = preferences.utility_scale_factor( - average_consumption_unequiv=AVERAGE_CONSUMPTION, + average_consumption_dollars=AVERAGE_CONSUMPTION, consumption_weight=jnp.array(CONSUMPTION_WEIGHT), coefficient_rra=jnp.array(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, reference_hours=REFERENCE_HOURS, ) - single = preferences.u_working_life( + single = preferences.u_can_work( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), consumption_weight=jnp.array(CONSUMPTION_WEIGHT), coefficient_rra=jnp.array(5.0), utility_scale_factor=scale, ) - married = preferences.u_working_life( + married = preferences.u_can_work( consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), consumption_weight=jnp.array(CONSUMPTION_WEIGHT), @@ -158,7 +158,7 @@ def test_utility_married_equivalence() -> None: def test_bequest_log_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption_unequiv=AVERAGE_CONSUMPTION, + average_consumption_dollars=AVERAGE_CONSUMPTION, consumption_weight=jnp.array(CONSUMPTION_WEIGHT), coefficient_rra=jnp.array(1.0), time_endowment=TIME_ENDOWMENT, @@ -181,12 +181,12 @@ def test_bequest_log_regression() -> None: coefficient_rra=jnp.array(1.0), utility_scale_factor=scale, ) - assert jnp.isclose(result, 86.080_677_139_309_2, rtol=1e-5) + assert jnp.isclose(result, 86.539_249_963_643_88, rtol=1e-5) def test_bequest_crra_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption_unequiv=AVERAGE_CONSUMPTION, + average_consumption_dollars=AVERAGE_CONSUMPTION, consumption_weight=jnp.array(CONSUMPTION_WEIGHT), coefficient_rra=jnp.array(5.0), time_endowment=TIME_ENDOWMENT, @@ -209,4 +209,4 @@ def test_bequest_crra_regression() -> None: coefficient_rra=jnp.array(5.0), utility_scale_factor=scale, ) - assert jnp.isclose(result, -45.799_247_573_576_66, rtol=1e-5) + assert jnp.isclose(result, -37.932_748_117_035_63, rtol=1e-5)