From 1b5b2997504f7c0d50982308948eb6a309025595 Mon Sep 17 00:00:00 2001 From: christian harrington Date: Tue, 3 Mar 2026 17:49:37 +0000 Subject: [PATCH 1/6] refactor the input dfs into a dynamic input object for greater container flexibility and less piping changes for pools requiring reserve inputs --- quantammsim/core_simulator/dynamic_inputs.py | 112 +++++++++ quantammsim/core_simulator/forward_pass.py | 102 ++++---- quantammsim/hooks/dynamic_fee_base_hook.py | 19 +- quantammsim/pools/ECLP/gyroscope.py | 11 +- quantammsim/pools/FM_AMM/cow_pool.py | 18 +- quantammsim/pools/G3M/balancer/balancer.py | 22 +- .../pools/G3M/quantamm/TFMM_base_pool.py | 12 +- quantammsim/pools/base_pool.py | 6 +- quantammsim/pools/hodl_pool.py | 12 +- quantammsim/pools/reCLAMM/reclamm.py | 18 +- quantammsim/runners/__init__.py | 4 +- quantammsim/runners/jax_runner_utils.py | 127 ++++++---- quantammsim/runners/jax_runners.py | 127 ++++------ quantammsim/runners/multi_period_sgd.py | 2 + quantammsim/runners/training_evaluator.py | 1 + .../finance/param_financial_calculator.py | 15 +- scripts/demo_run_chunks_from_chain_data.py | 11 +- scripts/demo_run_from_chain_data.py | 13 +- scripts/reclamm/sim_vs_world_comparison.py | 35 ++- tests/integration/test_dynamic_gas_fees.py | 31 ++- .../pools/reCLAMM/test_reclamm_fee_revenue.py | 18 +- tests/scripts/dynamic_gas_test.py | 20 +- tests/unit/test_jax_runner_utils.py | 236 ++++++++++++++++++ tests/unit/test_jax_runners_comprehensive.py | 217 ++++++++++++++++ tests/unit/test_lint_bugs.py | 9 +- 25 files changed, 908 insertions(+), 290 deletions(-) create mode 100644 quantammsim/core_simulator/dynamic_inputs.py diff --git a/quantammsim/core_simulator/dynamic_inputs.py b/quantammsim/core_simulator/dynamic_inputs.py new file mode 100644 index 0000000..b1eedac --- /dev/null +++ b/quantammsim/core_simulator/dynamic_inputs.py @@ -0,0 +1,112 @@ +from dataclasses import dataclass +from typing import Any, NamedTuple, Optional + +import jax.numpy as jnp + + +@dataclass(frozen=True) +class DynamicInputFrames: + """Outer-layer container for optional pandas-backed dynamic inputs.""" + + trades: Optional[Any] = None + fees: Optional[Any] = None + gas_cost: Optional[Any] = None + arb_fees: Optional[Any] = None + lp_supply: Optional[Any] = None + + +class DynamicInputArrays(NamedTuple): + """Fixed-structure JAX pytree for dynamic simulation inputs.""" + + trades: jnp.ndarray + fees: jnp.ndarray + gas_cost: jnp.ndarray + arb_fees: jnp.ndarray + lp_supply: jnp.ndarray + + +def default_dynamic_input_flags() -> dict: + """Static dispatch flags for forward-pass path selection.""" + return { + "use_dynamic_inputs": False, + "has_trades": False, + "has_dynamic_fees": False, + "has_dynamic_gas_cost": False, + "has_dynamic_arb_fees": False, + "has_lp_supply": False, + } + + +def dynamic_input_flags_from_frames(dynamic_input_frames: Optional[DynamicInputFrames]) -> dict: + """Build stable dispatch flags from the outer-layer frame container.""" + if dynamic_input_frames is None: + return default_dynamic_input_flags() + + flags = { + "use_dynamic_inputs": False, + "has_trades": dynamic_input_frames.trades is not None, + "has_dynamic_fees": dynamic_input_frames.fees is not None, + "has_dynamic_gas_cost": dynamic_input_frames.gas_cost is not None, + "has_dynamic_arb_fees": dynamic_input_frames.arb_fees is not None, + "has_lp_supply": dynamic_input_frames.lp_supply is not None, + } + flags["use_dynamic_inputs"] = any(flags.values()) + return flags + + +def resolve_dynamic_input_flags( + dynamic_inputs: Optional[DynamicInputArrays], + dynamic_input_flags: Optional[dict] = None, +) -> dict: + """Return a safe dispatch flag set for the provided hot-path bundle.""" + flags = ( + default_dynamic_input_flags() + if dynamic_input_flags is None + else dict(dynamic_input_flags) + ) + if dynamic_inputs is not None: + flags["use_dynamic_inputs"] = True + return flags + + +def empty_dynamic_input_arrays() -> DynamicInputArrays: + """Create a canonical empty bundle with stable pytree structure.""" + return DynamicInputArrays( + trades=jnp.zeros((1, 3), dtype=jnp.float64), + fees=jnp.zeros((1,), dtype=jnp.float64), + gas_cost=jnp.zeros((1,), dtype=jnp.float64), + arb_fees=jnp.zeros((1,), dtype=jnp.float64), + lp_supply=jnp.ones((1,), dtype=jnp.float64), + ) + + +def resolve_dynamic_input_components( + dynamic_inputs: Optional[DynamicInputArrays], + dynamic_input_flags: dict, + static_dict: dict, +) -> dict: + """Resolve dynamic-input leaves against static scalar defaults.""" + arrays = empty_dynamic_input_arrays() if dynamic_inputs is None else dynamic_inputs + return { + "trades": arrays.trades if dynamic_input_flags["has_trades"] else None, + "fees": ( + arrays.fees + if dynamic_input_flags["has_dynamic_fees"] + else jnp.asarray([static_dict["fees"]], dtype=jnp.float64) + ), + "gas_cost": ( + arrays.gas_cost + if dynamic_input_flags["has_dynamic_gas_cost"] + else jnp.asarray([static_dict["gas_cost"]], dtype=jnp.float64) + ), + "arb_fees": ( + arrays.arb_fees + if dynamic_input_flags["has_dynamic_arb_fees"] + else jnp.asarray([static_dict["arb_fees"]], dtype=jnp.float64) + ), + "lp_supply": ( + arrays.lp_supply + if dynamic_input_flags["has_lp_supply"] + else jnp.ones((1,), dtype=jnp.float64) + ), + } diff --git a/quantammsim/core_simulator/forward_pass.py b/quantammsim/core_simulator/forward_pass.py index 205feed..4dd74f6 100644 --- a/quantammsim/core_simulator/forward_pass.py +++ b/quantammsim/core_simulator/forward_pass.py @@ -54,11 +54,28 @@ import numpy as np from functools import partial +from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputArrays, + default_dynamic_input_flags, + empty_dynamic_input_arrays, + resolve_dynamic_input_flags, +) np.seterr(all="raise") np.seterr(under="print") +def _resolve_dynamic_inputs(dynamic_inputs, static_dict): + """Return a stable hot-path bundle plus static dispatch flags.""" + dynamic_input_flags = resolve_dynamic_input_flags( + dynamic_inputs, + static_dict.get("dynamic_input_flags"), + ) + if dynamic_inputs is None: + dynamic_inputs = empty_dynamic_input_arrays() + return dynamic_inputs, dynamic_input_flags + + def _apply_price_noise(prices, sigma, seed_int): """Apply multiplicative log-normal noise to prices. @@ -703,15 +720,12 @@ def _calculate_return_value( return return_metrics[return_val]() -@partial(jit, static_argnums=(7, 8)) +@partial(jit, static_argnums=(4, 5)) def forward_pass( params, start_index, prices, - trades_array=None, - fees_array=None, - gas_cost_array=None, - arb_fees_array=None, + dynamic_inputs=None, pool=None, static_dict=None, ): @@ -734,17 +748,8 @@ def forward_pass( prices : array-like A 2D array of market prices for the assets involved in the simulation. - trades_array : array-like, optional - An array of trades to be considered in the simulation. Defaults to None. - - fees_array : array-like, optional - An array of fees to be applied during the simulation. Defaults to None. - - gas_cost_array : array-like, optional - An array of gas costs to be considered in the simulation. Defaults to None. - - arb_fees_array : array-like, optional - An array of arbitrage fees to be applied during the simulation. Defaults to None. + dynamic_inputs : DynamicInputArrays, optional + Fixed-structure bundle of dynamic trades/fees/gas/arb/LP arrays. pool : object An instance of a pool object that provides methods @@ -794,8 +799,8 @@ def forward_pass( - The function handles different cases for fees and trades, adjusting the calculation method accordingly: - 1. If any of `fees_array`, `gas_cost_array`, `arb_fees_array`, - or `trades_array` is provided, it uses `pool.calculate_reserves_with_dynamic_inputs`. + 1. If any dynamic-input flags are enabled, it uses + `pool.calculate_reserves_with_dynamic_inputs`. 2. If any of `fees`, `gas_cost`, or `arb_fees` in `static_dict` is a nonzero scalar value, it uses `pool.calculate_reserves_with_fees`. @@ -836,6 +841,7 @@ def forward_pass( "training_data_kind": "historic", "arb_frequency": 1, "do_trades": False, + "dynamic_input_flags": default_dynamic_input_flags(), } # 'pool' has default of None only to handle how partial function @@ -864,28 +870,20 @@ def forward_pass( # 1. Any of Fees, gas costs, and arb fees are provided as arrays, or trades are provided # 2. Any of Fees, gas costs, and arb fees are nonzero scalar values, with no trades provided # 3. Fees, gas costs, and arb fees are all zero, with no trades provided + dynamic_inputs, dynamic_input_flags = _resolve_dynamic_inputs( + dynamic_inputs, static_dict + ) + fee_revenue = None - if any( - ele is not None - for ele in [fees_array, gas_cost_array, arb_fees_array, trades_array] - ): - # Case 1, at least one of fees, gas costs, or arb fees is not None - if fees_array is None: - fees_array = jnp.array([static_dict["fees"]]) - if gas_cost_array is None: - gas_cost_array = jnp.array([static_dict["gas_cost"]]) - if arb_fees_array is None: - arb_fees_array = jnp.array([static_dict["arb_fees"]]) + if dynamic_input_flags["use_dynamic_inputs"]: + # Case 1, at least one dynamic input is enabled if hasattr(pool, "calculate_reserves_and_fee_revenue_with_dynamic_inputs"): reserves, fee_revenue = pool.calculate_reserves_and_fee_revenue_with_dynamic_inputs( params, static_dict, prices, start_index, - fees_array=fees_array, - arb_thresh_array=gas_cost_array, - arb_fees_array=arb_fees_array, - trade_array=trades_array, + dynamic_inputs=dynamic_inputs, ) else: reserves = pool.calculate_reserves_with_dynamic_inputs( @@ -893,10 +891,7 @@ def forward_pass( static_dict, prices, start_index, - fees_array=fees_array, - arb_thresh_array=gas_cost_array, - arb_fees_array=arb_fees_array, - trade_array=trades_array, + dynamic_inputs=dynamic_inputs, ) elif True in ( ele > 0.0 @@ -1003,15 +998,12 @@ def forward_pass( return base_metric -@partial(jit, static_argnums=(7, 8)) +@partial(jit, static_argnums=(4, 5)) def forward_pass_nograd( params, start_index, prices, - trades_array=None, - fees_array=None, - gas_cost_array=None, - arb_fees_array=None, + dynamic_inputs=None, pool=None, static_dict=None, ): @@ -1036,17 +1028,8 @@ def forward_pass_nograd( prices : array-like A 2D array of market prices for the assets involved in the simulation. - trades_array : array-like, optional - An array of trades to be considered in the simulation. Defaults to None. - - fees_array : array-like, optional - An array of fees to be applied during the simulation. Defaults to None. - - gas_cost_array : array-like, optional - An array of gas costs to be considered in the simulation. Defaults to None. - - arb_fees_array : array-like, optional - An array of arbitrage fees to be applied during the simulation. Defaults to None. + dynamic_inputs : DynamicInputArrays, optional + Fixed-structure bundle of dynamic trades/fees/gas/arb/LP arrays. pool : object An instance of a pool object that provides methods @@ -1096,8 +1079,8 @@ def forward_pass_nograd( - The function handles different cases for fees and trades, adjusting the calculation method accordingly: - 1. If any of `fees_array`, `gas_cost_array`, `arb_fees_array`, - or `trades_array` is provided, it uses `pool.calculate_reserves_with_dynamic_inputs`. + 1. If any dynamic-input flags are enabled, it uses + `pool.calculate_reserves_with_dynamic_inputs`. 2. If any of `fees`, `gas_cost`, or `arb_fees` in `static_dict` is a nonzero scalar value, it uses `pool.calculate_reserves_with_fees`. @@ -1122,14 +1105,15 @@ def forward_pass_nograd( params = {k: stop_gradient(v) for k, v in params.items()} start_index = stop_gradient(start_index) prices = stop_gradient(prices) + if dynamic_inputs is not None: + dynamic_inputs = DynamicInputArrays( + *(stop_gradient(arr) for arr in dynamic_inputs) + ) return forward_pass( params, start_index, prices, - trades_array, - fees_array, - gas_cost_array, - arb_fees_array, + dynamic_inputs, pool, static_dict, ) diff --git a/quantammsim/hooks/dynamic_fee_base_hook.py b/quantammsim/hooks/dynamic_fee_base_hook.py index 40a9714..9e00af6 100644 --- a/quantammsim/hooks/dynamic_fee_base_hook.py +++ b/quantammsim/hooks/dynamic_fee_base_hook.py @@ -3,6 +3,10 @@ import jax.numpy as jnp from jax.lax import dynamic_slice +from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputArrays, + empty_dynamic_input_arrays, +) class BaseDynamicFeeHook(ABC): """Mixin class to add dynamic fee calculation capabilities to pools. @@ -113,16 +117,21 @@ def calculate_reserves_with_fees( (int((bout_length) / chunk_period), 1), ) dynamic_fees = raw_dynamic_fees.repeat(chunk_period, axis=0).squeeze() - # Use existing dynamic inputs infrastructure + empty_inputs = empty_dynamic_input_arrays() + dynamic_inputs = DynamicInputArrays( + trades=empty_inputs.trades, + fees=dynamic_fees, + gas_cost=jnp.asarray(run_fingerprint["gas_cost"], dtype=jnp.float64), + arb_fees=jnp.asarray(run_fingerprint["arb_fees"], dtype=jnp.float64), + lp_supply=empty_inputs.lp_supply, + ) + return self.calculate_reserves_with_dynamic_inputs( params, run_fingerprint, prices, start_index, - dynamic_fees, - run_fingerprint["gas_cost"], - run_fingerprint["arb_fees"], - dynamic_fees, + dynamic_inputs, additional_oracle_input, ) diff --git a/quantammsim/pools/ECLP/gyroscope.py b/quantammsim/pools/ECLP/gyroscope.py index ad7884a..5aebcef 100644 --- a/quantammsim/pools/ECLP/gyroscope.py +++ b/quantammsim/pools/ECLP/gyroscope.py @@ -322,11 +322,7 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: # Gyroscope ECLP pools are only defined for 2 assets @@ -346,6 +342,11 @@ def calculate_reserves_with_dynamic_inputs( else: arb_acted_upon_local_prices = local_prices + fees_array = dynamic_inputs.fees + arb_thresh_array = dynamic_inputs.gas_cost + arb_fees_array = dynamic_inputs.arb_fees + trade_array = dynamic_inputs.trades + # calculate initial reserves initial_pool_value = run_fingerprint["initial_pool_value"] initial_reserves = initialise_gyroscope_reserves_given_value( diff --git a/quantammsim/pools/FM_AMM/cow_pool.py b/quantammsim/pools/FM_AMM/cow_pool.py index 5a9536f..96299ce 100644 --- a/quantammsim/pools/FM_AMM/cow_pool.py +++ b/quantammsim/pools/FM_AMM/cow_pool.py @@ -58,10 +58,9 @@ class CowPool(AbstractPool): start_index, additional_oracle_input=None) -> jnp.ndarray: Calculates the reserves of the pool without considering fees. - calculate_reserves_with_dynamic_inputs(params, run_fingerprint, prices, - start_index, fees_array, arb_thresh_array, arb_fees_array, trade_array, - additional_oracle_input=None) -> jnp.ndarray: - Calculates the reserves of the pool with dynamic inputs for fees, + calculate_reserves_with_dynamic_inputs(params, run_fingerprint, prices, + start_index, dynamic_inputs, additional_oracle_input=None) -> jnp.ndarray: + Calculates the reserves of the pool with dynamic inputs for fees, arbitrage thresholds, arbitrage fees, and trades. init_base_parameters(initial_values_dict, run_fingerprint, n_assets, @@ -196,11 +195,7 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: bout_length = run_fingerprint["bout_length"] @@ -216,6 +211,11 @@ def calculate_reserves_with_dynamic_inputs( else: arb_acted_upon_local_prices = local_prices + fees_array = dynamic_inputs.fees + arb_thresh_array = dynamic_inputs.gas_cost + arb_fees_array = dynamic_inputs.arb_fees + trade_array = dynamic_inputs.trades + initial_pool_value = run_fingerprint["initial_pool_value"] initial_value_per_token = weights * initial_pool_value initial_reserves = initial_value_per_token / arb_acted_upon_local_prices[0] diff --git a/quantammsim/pools/G3M/balancer/balancer.py b/quantammsim/pools/G3M/balancer/balancer.py index 0d7ec30..986d49c 100644 --- a/quantammsim/pools/G3M/balancer/balancer.py +++ b/quantammsim/pools/G3M/balancer/balancer.py @@ -257,11 +257,7 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: """ @@ -289,14 +285,8 @@ def calculate_reserves_with_dynamic_inputs( Price history array start_index : jnp.ndarray Starting index for the calculation window - fees_array : jnp.ndarray - Time-varying trading fees - arb_thresh_array : jnp.ndarray - Time-varying arbitrage thresholds - arb_fees_array : jnp.ndarray - Time-varying arbitrage fees - trade_array : jnp.ndarray - Custom trade sequence + dynamic_inputs : DynamicInputArrays + Fixed-structure bundle of dynamic inputs. Returns ------- @@ -316,6 +306,12 @@ def calculate_reserves_with_dynamic_inputs( else: arb_acted_upon_local_prices = local_prices + fees_array = dynamic_inputs.fees + arb_thresh_array = dynamic_inputs.gas_cost + arb_fees_array = dynamic_inputs.arb_fees + trade_array = dynamic_inputs.trades + lp_supply_array = dynamic_inputs.lp_supply + initial_pool_value = run_fingerprint["initial_pool_value"] initial_value_per_token = weights * initial_pool_value initial_reserves = initial_value_per_token / arb_acted_upon_local_prices[0] diff --git a/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py b/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py index e15b361..b087b5a 100644 --- a/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py +++ b/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py @@ -254,11 +254,7 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: bout_length = run_fingerprint["bout_length"] @@ -278,6 +274,12 @@ def calculate_reserves_with_dynamic_inputs( arb_acted_upon_weights = weights arb_acted_upon_local_prices = local_prices + fees_array = dynamic_inputs.fees + arb_thresh_array = dynamic_inputs.gas_cost + arb_fees_array = dynamic_inputs.arb_fees + trade_array = dynamic_inputs.trades + lp_supply_array = dynamic_inputs.lp_supply + initial_pool_value = run_fingerprint["initial_pool_value"] initial_value_per_token = arb_acted_upon_weights[0] * initial_pool_value initial_reserves = initial_value_per_token / arb_acted_upon_local_prices[0] diff --git a/quantammsim/pools/base_pool.py b/quantammsim/pools/base_pool.py index cc2d7b3..d8c6d75 100644 --- a/quantammsim/pools/base_pool.py +++ b/quantammsim/pools/base_pool.py @@ -93,11 +93,7 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs: Any, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: pass diff --git a/quantammsim/pools/hodl_pool.py b/quantammsim/pools/hodl_pool.py index 4c1085e..c0740fb 100644 --- a/quantammsim/pools/hodl_pool.py +++ b/quantammsim/pools/hodl_pool.py @@ -39,9 +39,9 @@ class HODLPool(AbstractPool): additional_oracle_input=None): Calculates the reserves without fees, assuming no trading activity. - calculate_reserves_with_dynamic_inputs(params, run_fingerprint, prices, start_index, - fees_array, arb_thresh_array, arb_fees_array, trade_array, additional_oracle_input=None): - Calculates the reserves with dynamic inputs, which in this case is + calculate_reserves_with_dynamic_inputs(params, run_fingerprint, prices, start_index, + dynamic_inputs, additional_oracle_input=None): + Calculates the reserves with dynamic inputs, which in this case is the same as reserves without fees due to no activity. init_base_parameters(initial_values_dict, run_fingerprint, n_assets, @@ -126,11 +126,7 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: # hodl means no activity, so reserves are just the initial reserves diff --git a/quantammsim/pools/reCLAMM/reclamm.py b/quantammsim/pools/reCLAMM/reclamm.py index 152a264..c73b6c4 100644 --- a/quantammsim/pools/reCLAMM/reclamm.py +++ b/quantammsim/pools/reCLAMM/reclamm.py @@ -275,11 +275,7 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ): """Calculate reserves and LP fee revenue with time-varying inputs. @@ -291,6 +287,9 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( LP fee revenue per timestep in USD. """ s = self._init_pool_state(params, run_fingerprint, prices, start_index) + fees_array = dynamic_inputs.fees + arb_thresh_array = dynamic_inputs.gas_cost + arb_fees_array = dynamic_inputs.arb_fees bout_length = run_fingerprint["bout_length"] max_len = bout_length - 1 @@ -367,14 +366,13 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: s = self._init_pool_state(params, run_fingerprint, prices, start_index) + fees_array = dynamic_inputs.fees + arb_thresh_array = dynamic_inputs.gas_cost + arb_fees_array = dynamic_inputs.arb_fees bout_length = run_fingerprint["bout_length"] max_len = bout_length - 1 diff --git a/quantammsim/runners/__init__.py b/quantammsim/runners/__init__.py index 987a800..511800b 100644 --- a/quantammsim/runners/__init__.py +++ b/quantammsim/runners/__init__.py @@ -28,7 +28,7 @@ from .jax_runner_utils import ( nan_rollback, Hashabledict, - get_trades_and_fees, + prepare_dynamic_inputs, get_unique_tokens, OptunaManager, generate_evaluation_points, @@ -80,7 +80,7 @@ # Utilities "nan_rollback", "Hashabledict", - "get_trades_and_fees", + "prepare_dynamic_inputs", "get_unique_tokens", "OptunaManager", "generate_evaluation_points", diff --git a/quantammsim/runners/jax_runner_utils.py b/quantammsim/runners/jax_runner_utils.py index 399cd37..64c916d 100644 --- a/quantammsim/runners/jax_runner_utils.py +++ b/quantammsim/runners/jax_runner_utils.py @@ -14,6 +14,12 @@ raw_fee_like_amounts_to_fee_like_array, raw_trades_to_trade_array, ) +from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputArrays, + DynamicInputFrames, + dynamic_input_flags_from_frames, + empty_dynamic_input_arrays, +) from quantammsim.apis.rest_apis.simulator_dtos.simulation_run_dto import ( LiquidityPoolCoinDto, @@ -1215,39 +1221,40 @@ def unpermute_list_of_params(list_of_params): return list_of_params_to_return -def get_trades_and_fees( - run_fingerprint, raw_trades, fees_df, gas_cost_df, arb_fees_df, lp_supply_df, do_test_period=False -): - """ - Process trade and fee data for a simulation run. +def _to_dynamic_input_arrays( + trades_array, + fees_array, + gas_cost_array, + arb_fees_array, + lp_supply_array, +) -> DynamicInputArrays: + """Normalize optional numpy arrays into the fixed hot-path container.""" + empty = empty_dynamic_input_arrays() + return DynamicInputArrays( + trades=empty.trades if trades_array is None else jnp.asarray(trades_array, dtype=jnp.float64), + fees=empty.fees if fees_array is None else jnp.asarray(fees_array, dtype=jnp.float64), + gas_cost=empty.gas_cost if gas_cost_array is None else jnp.asarray(gas_cost_array, dtype=jnp.float64), + arb_fees=empty.arb_fees if arb_fees_array is None else jnp.asarray(arb_fees_array, dtype=jnp.float64), + lp_supply=empty.lp_supply if lp_supply_array is None else jnp.asarray(lp_supply_array, dtype=jnp.float64), + ) - Takes raw trades, fees, gas costs and arbitrage fees and converts them into arrays - suitable for simulation. Handles both training and test periods if specified. - Parameters - ---------- - run_fingerprint : dict - Dictionary containing run configuration including start/end dates and tokens - raw_trades : pd.DataFrame, optional - DataFrame containing raw trade data - fees_df : pd.DataFrame, optional - DataFrame containing fee data - gas_cost_df : pd.DataFrame, optional - DataFrame containing gas cost data - arb_fees_df : pd.DataFrame, optional - DataFrame containing arbitrage fee data - lp_supply_df : pd.DataFrame, optional - DataFrame containing LP supply data - do_test_period : bool, optional - Whether to process data for a test period after training period (default False) +def prepare_dynamic_inputs( + run_fingerprint, + dynamic_input_frames: Optional[DynamicInputFrames] = None, + do_test_period: bool = False, +): + """Convert optional pandas inputs into fixed-structure dynamic input bundles.""" + if dynamic_input_frames is None: + dynamic_input_frames = DynamicInputFrames() + + raw_trades = dynamic_input_frames.trades + fees_df = dynamic_input_frames.fees + gas_cost_df = dynamic_input_frames.gas_cost + arb_fees_df = dynamic_input_frames.arb_fees + lp_supply_df = dynamic_input_frames.lp_supply + dynamic_input_flags = dynamic_input_flags_from_frames(dynamic_input_frames) - Returns - ------- - dict - Contains processed arrays for trades, fees, gas costs and arb fees for both - training and test periods as applicable - """ - # Process raw trades if provided if raw_trades is not None: train_period_trades = raw_trades_to_trade_array( raw_trades, @@ -1265,7 +1272,7 @@ def get_trades_and_fees( else: train_period_trades = None test_period_trades = None - # Process fees, gas costs, and arb fees if provided + fees_array = ( raw_fee_like_amounts_to_fee_like_array( fees_df, @@ -1281,8 +1288,8 @@ def get_trades_and_fees( test_fees_array = ( raw_fee_like_amounts_to_fee_like_array( fees_df, - run_fingerprint["startDateString"], run_fingerprint["endDateString"], + run_fingerprint["endTestDateString"], names=["fees"], fill_method="ffill", ) @@ -1361,25 +1368,32 @@ def get_trades_and_fees( else None ) return { - "train_period_trades": train_period_trades, - "test_period_trades": test_period_trades, - "fees_array": fees_array, - "gas_cost_array": gas_cost_array, - "arb_fees_array": arb_fees_array, - "lp_supply_array": lp_supply_array, - "test_fees_array": test_fees_array, - "test_gas_cost_array": test_gas_cost_array, - "test_arb_fees_array": test_arb_fees_array, - "test_lp_supply_array": test_lp_supply_array, - } - else: - return { - "train_period_trades": train_period_trades, - "fees_array": fees_array, - "gas_cost_array": gas_cost_array, - "arb_fees_array": arb_fees_array, - "lp_supply_array": lp_supply_array, + "train_dynamic_inputs": _to_dynamic_input_arrays( + train_period_trades, + fees_array, + gas_cost_array, + arb_fees_array, + lp_supply_array, + ), + "test_dynamic_inputs": _to_dynamic_input_arrays( + test_period_trades, + test_fees_array, + test_gas_cost_array, + test_arb_fees_array, + test_lp_supply_array, + ), + "dynamic_input_flags": dynamic_input_flags, } + return { + "train_dynamic_inputs": _to_dynamic_input_arrays( + train_period_trades, + fees_array, + gas_cost_array, + arb_fees_array, + lp_supply_array, + ), + "dynamic_input_flags": dynamic_input_flags, + } def create_daily_unix_array(start_date_str, end_date_str): @@ -1622,24 +1636,33 @@ def try_forward_pass(n_sets: int) -> bool: "n_assets": n_tokens, "training_data_kind": probe_fingerprint["optimisation_settings"]["training_data_kind"], "do_trades": False, + "dynamic_input_flags": { + "use_dynamic_inputs": False, + "has_trades": False, + "has_dynamic_fees": False, + "has_dynamic_gas_cost": False, + "has_dynamic_arb_fees": False, + "has_lp_supply": False, + }, }, ) # Create vmapped forward pass partial_forward = Partial( forward_pass_nograd, + dynamic_inputs=None, prices=data_dict["prices"], static_dict=static_dict, pool=pool, ) vmapped_forward = jit( - vmap(partial_forward, in_axes=[params_in_axes_dict, None, None]) + vmap(partial_forward, in_axes=[params_in_axes_dict, None]) ) # Run forward pass start_index = (data_dict["start_idx"], 0) - _ = vmapped_forward(params, start_index, None) + _ = vmapped_forward(params, start_index) # Force computation to complete jnp.zeros(1).block_until_ready() diff --git a/quantammsim/runners/jax_runners.py b/quantammsim/runners/jax_runners.py index c403be5..1a8fdf3 100644 --- a/quantammsim/runners/jax_runners.py +++ b/quantammsim/runners/jax_runners.py @@ -52,6 +52,10 @@ forward_pass_nograd, _calculate_return_value, ) +from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputFrames, + resolve_dynamic_input_components, +) from quantammsim.core_simulator.windowing_utils import get_indices, filter_coarse_weights_by_data_indices import hashlib @@ -79,7 +83,7 @@ from quantammsim.runners.jax_runner_utils import ( Hashabledict, - get_trades_and_fees, + prepare_dynamic_inputs, get_unique_tokens, OptunaManager, generate_evaluation_points, @@ -639,6 +643,14 @@ def train_on_historic_data( "n_assets": n_assets, "training_data_kind": run_fingerprint["optimisation_settings"]["training_data_kind"], "do_trades": False, + "dynamic_input_flags": { + "use_dynamic_inputs": False, + "has_trades": False, + "has_dynamic_fees": False, + "has_dynamic_gas_cost": False, + "has_dynamic_arb_fees": False, + "has_lp_supply": False, + }, }, ) @@ -653,6 +665,7 @@ def train_on_historic_data( continuous_static_dict["bout_length"] = original_bout_length + data_dict["bout_length_test"] partial_forward_pass_nograd_batch_continuous = Partial( forward_pass_nograd, + dynamic_inputs=None, static_dict=Hashabledict(continuous_static_dict), pool=pool, ) @@ -798,6 +811,7 @@ def init_optimizer(params): # Build scan-compatible update (prices as explicit arg, not closure) partial_step_no_prices = Partial( forward_pass, + dynamic_inputs=None, static_dict=Hashabledict(base_static_dict), pool=pool, ) @@ -1798,14 +1812,10 @@ def do_run_on_historic_data( root=None, price_data=None, verbose=False, - raw_trades=None, fees=None, gas_cost=None, arb_fees=None, - fees_df=None, - gas_cost_df=None, - arb_fees_df=None, - lp_supply_df=None, + dynamic_input_frames: DynamicInputFrames = None, do_test_period=False, low_data_mode=False, preslice_burnin=True, @@ -1831,23 +1841,14 @@ def do_run_on_historic_data( Pre-loaded price data. When None, loaded from parquet files. verbose : bool, optional Print progress information (default False). - raw_trades : DataFrame, optional - Real trade data to inject. Columns: unix timestamp (minute), - token_in, token_out, amount_in. fees : float, optional Swap fee override (e.g. 0.003 for 30 bps). gas_cost : float, optional Gas cost override per transaction. arb_fees : float, optional Arbitrageur fee override. - fees_df : DataFrame, optional - Time-varying swap fees (columns: unix, fee). - gas_cost_df : DataFrame, optional - Time-varying gas costs (columns: unix, gas_cost). - arb_fees_df : DataFrame, optional - Time-varying arb fees (columns: unix, arb_fee). - lp_supply_df : DataFrame, optional - Time-varying LP supply changes. + dynamic_input_frames : DynamicInputFrames, optional + Optional container of trades / fee / gas / arb / LP supply DataFrames. do_test_period : bool, optional If True, also run the OOS test period defined by ``endDateString`` to ``endTestDateString`` (default False). @@ -1892,15 +1893,21 @@ def do_run_on_historic_data( np.random.seed(0) - dynamic_inputs_dict = get_trades_and_fees( + dynamic_inputs_dict = prepare_dynamic_inputs( run_fingerprint, - raw_trades, - fees_df, - gas_cost_df, - arb_fees_df, - lp_supply_df, + dynamic_input_frames=dynamic_input_frames, do_test_period=do_test_period, ) + train_dynamic_inputs = ( + dynamic_inputs_dict["train_dynamic_inputs"] + if dynamic_inputs_dict["dynamic_input_flags"]["use_dynamic_inputs"] + else None + ) + test_dynamic_inputs = ( + dynamic_inputs_dict.get("test_dynamic_inputs") + if dynamic_inputs_dict["dynamic_input_flags"]["use_dynamic_inputs"] + else None + ) # Load price data if not provided if price_data is None: @@ -1944,7 +1951,8 @@ def do_run_on_historic_data( "fees": fees if fees is not None else run_fingerprint["fees"], "arb_fees": arb_fees if arb_fees is not None else run_fingerprint["arb_fees"], "gas_cost": gas_cost if gas_cost is not None else run_fingerprint["gas_cost"], - "do_trades": False if raw_trades is None else run_fingerprint["do_trades"], + "do_trades": dynamic_inputs_dict["dynamic_input_flags"]["has_trades"], + "dynamic_input_flags": dynamic_inputs_dict["dynamic_input_flags"], # Include date strings for run-time use "startDateString": run_fingerprint["startDateString"], "endDateString": run_fingerprint["endDateString"], @@ -2000,10 +2008,7 @@ def do_run_on_historic_data( param, (data_dict["start_idx"], 0), data_dict["prices"], - dynamic_inputs_dict["train_period_trades"], - dynamic_inputs_dict["fees_array"], - dynamic_inputs_dict["gas_cost_array"], - dynamic_inputs_dict["arb_fees_array"], + train_dynamic_inputs, ) if low_data_mode: output_dict["final_prices"] = output_dict["prices"][-1] @@ -2019,10 +2024,7 @@ def do_run_on_historic_data( param, (data_dict["start_idx_test"], 0), data_dict["prices"], - dynamic_inputs_dict["test_period_trades"], - dynamic_inputs_dict["test_fees_array"], - dynamic_inputs_dict["test_gas_cost_array"], - dynamic_inputs_dict["test_arb_fees_array"], + test_dynamic_inputs, ) if low_data_mode: output_dict_test["final_prices"] = output_dict_test["prices"][-1] @@ -2061,14 +2063,10 @@ def do_run_on_historic_data_with_provided_coarse_weights( root=None, price_data=None, verbose=False, - raw_trades=None, fees=None, gas_cost=None, arb_fees=None, - fees_df=None, - gas_cost_df=None, - arb_fees_df=None, - lp_supply_df=None, + dynamic_input_frames: DynamicInputFrames = None, do_test_period=False, low_data_mode=False, ): @@ -2098,22 +2096,14 @@ def do_run_on_historic_data_with_provided_coarse_weights( Pre-loaded price data. verbose : bool, optional Print progress (default False). - raw_trades : DataFrame, optional - Real trade data to inject. fees : float, optional Swap fee override. gas_cost : float, optional Gas cost override. arb_fees : float, optional Arbitrageur fee override. - fees_df : DataFrame, optional - Time-varying swap fees. - gas_cost_df : DataFrame, optional - Time-varying gas costs. - arb_fees_df : DataFrame, optional - Time-varying arb fees. - lp_supply_df : DataFrame, optional - Time-varying LP supply changes. + dynamic_input_frames : DynamicInputFrames, optional + Optional container of trades / fee / gas / arb / LP supply DataFrames. do_test_period : bool, optional Run OOS test period (default False). low_data_mode : bool, optional @@ -2152,13 +2142,9 @@ def do_run_on_historic_data_with_provided_coarse_weights( np.random.seed(0) - dynamic_inputs_dict = get_trades_and_fees( + dynamic_inputs_dict = prepare_dynamic_inputs( run_fingerprint, - raw_trades, - fees_df, - gas_cost_df, - arb_fees_df, - lp_supply_df, + dynamic_input_frames=dynamic_input_frames, do_test_period=do_test_period, ) @@ -2201,7 +2187,8 @@ def do_run_on_historic_data_with_provided_coarse_weights( "fees": fees if fees is not None else run_fingerprint["fees"], "arb_fees": arb_fees if arb_fees is not None else run_fingerprint["arb_fees"], "gas_cost": gas_cost if gas_cost is not None else run_fingerprint["gas_cost"], - "do_trades": False if raw_trades is None else run_fingerprint["do_trades"], + "do_trades": dynamic_inputs_dict["dynamic_input_flags"]["has_trades"], + "dynamic_input_flags": dynamic_inputs_dict["dynamic_input_flags"], # Include date strings for run-time use "startDateString": run_fingerprint["startDateString"], "endDateString": run_fingerprint["endDateString"], @@ -2268,18 +2255,18 @@ def do_run_on_historic_data_with_provided_coarse_weights( # weights=HashableArrayWrapper(weights), # initial_reserves=HashableArrayWrapper(params["initial_reserves"]), # ) - fees_array = dynamic_inputs_dict.get("fees_array") - arb_thresh_array = dynamic_inputs_dict.get("gas_cost_array") - arb_fees_array = dynamic_inputs_dict.get("arb_fees_array") - trade_array = dynamic_inputs_dict.get("trades") - lp_supply_array = dynamic_inputs_dict.get("lp_supply_array") - - if fees_array is None: - fees_array = jnp.array([static_dict["fees"]]) - if arb_thresh_array is None: - arb_thresh_array = jnp.array([static_dict["gas_cost"]]) - if arb_fees_array is None: - arb_fees_array = jnp.array([static_dict["arb_fees"]]) + dynamic_input_flags = dynamic_inputs_dict["dynamic_input_flags"] + dynamic_inputs = dynamic_inputs_dict["train_dynamic_inputs"] + resolved_dynamic_inputs = resolve_dynamic_input_components( + dynamic_inputs, + dynamic_input_flags, + static_dict, + ) + fees_array = resolved_dynamic_inputs["fees"] + arb_thresh_array = resolved_dynamic_inputs["gas_cost"] + arb_fees_array = resolved_dynamic_inputs["arb_fees"] + trade_array = resolved_dynamic_inputs["trades"] + lp_supply_array = resolved_dynamic_inputs["lp_supply"] # initial_pool_value = run_fingerprint["initial_pool_value"] # initial_value_per_token = arb_acted_upon_weights[0] * initial_pool_value @@ -2298,10 +2285,8 @@ def do_run_on_historic_data_with_provided_coarse_weights( fees_array = fees_array[:max_len] arb_thresh_array = arb_thresh_array[:max_len] - arb_thresh_array = arb_thresh_array * 0.0 arb_fees_array = arb_fees_array[:max_len] - if lp_supply_array is not None: - lp_supply_array = lp_supply_array[:max_len] + lp_supply_array = lp_supply_array[:max_len] if trade_array is not None: trade_array = trade_array[:max_len] # Broadcast input arrays to match the maximum leading dimension. @@ -2316,10 +2301,6 @@ def do_run_on_historic_data_with_provided_coarse_weights( arb_fees_array_broadcast = jnp.broadcast_to( arb_fees_array, (max_len,) + arb_fees_array.shape[1:] ) - # if lp_supply_array is not provided, we set it to a constant of 1.0 - if lp_supply_array is None: - lp_supply_array = jnp.array(1.0) - lp_supply_array_broadcast = jnp.broadcast_to( lp_supply_array, (max_len,) + lp_supply_array.shape[1:] ) diff --git a/quantammsim/runners/multi_period_sgd.py b/quantammsim/runners/multi_period_sgd.py index 7802968..c81e00a 100644 --- a/quantammsim/runners/multi_period_sgd.py +++ b/quantammsim/runners/multi_period_sgd.py @@ -430,6 +430,7 @@ def multi_period_sgd_training( # Create base forward pass base_forward_pass = Partial( forward_pass, + dynamic_inputs=None, prices=data_dict["prices"], static_dict=Hashabledict(static_dict), pool=pool, @@ -506,6 +507,7 @@ def multi_period_sgd_training( partial_nograd = jit(Partial( forward_pass_nograd, + dynamic_inputs=None, prices=data_dict["prices"], static_dict=Hashabledict(static_dict), pool=pool, diff --git a/quantammsim/runners/training_evaluator.py b/quantammsim/runners/training_evaluator.py index e79a071..011b3f7 100644 --- a/quantammsim/runners/training_evaluator.py +++ b/quantammsim/runners/training_evaluator.py @@ -742,6 +742,7 @@ def _compute_metrics( eval_fn = jit(Partial( forward_pass_nograd, + dynamic_inputs=None, prices=data_dict["prices"], static_dict=Hashabledict(static_dict), pool=pool, diff --git a/quantammsim/simulator_analysis_tools/finance/param_financial_calculator.py b/quantammsim/simulator_analysis_tools/finance/param_financial_calculator.py index f92f7da..aa31519 100644 --- a/quantammsim/simulator_analysis_tools/finance/param_financial_calculator.py +++ b/quantammsim/simulator_analysis_tools/finance/param_financial_calculator.py @@ -20,6 +20,7 @@ from quantammsim.runners.jax_runners import do_run_on_historic_data from quantammsim.runners.jax_runner_utils import optimized_output_conversion +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames import quantammsim.simulator_analysis_tools.finance.financial_analysis_calculator as fac import quantammsim.simulator_analysis_tools.finance.financial_analysis_functions as faf import quantammsim.simulator_analysis_tools.finance.financial_analysis_utils as fau @@ -238,6 +239,12 @@ def run_pool_simulation(simulationRunDto): run_fingerprint["fees"] = static_fee fee_steps_df = None + dynamic_input_frames = DynamicInputFrames( + trades=raw_trades, + fees=fee_steps_df, + gas_cost=gas_cost_df, + ) + print("run fingerprint-------------------", run_fingerprint) print("update rule parameter dict converted-------------------", update_rule_parameter_dict_converted) outputDict = do_run_on_historic_data( @@ -247,9 +254,7 @@ def run_pool_simulation(simulationRunDto): price_data=price_data_local, verbose=True, do_test_period=False, - raw_trades=raw_trades, - gas_cost_df=gas_cost_df, - fees_df=fee_steps_df + dynamic_input_frames=dynamic_input_frames, ) print("outputDict: ", outputDict.keys()) resultTimeSteps = optimized_output_conversion(simulationRunDto, outputDict, tokens) @@ -293,9 +298,7 @@ def run_pool_simulation(simulationRunDto): price_data=price_data_local, verbose=False, do_test_period=False, - raw_trades=raw_trades, - gas_cost_df=gas_cost_df, - fees_df=fee_steps_df, + dynamic_input_frames=dynamic_input_frames, ) # Extract final weights from the result. diff --git a/scripts/demo_run_chunks_from_chain_data.py b/scripts/demo_run_chunks_from_chain_data.py index fbd9cf9..09c19d1 100644 --- a/scripts/demo_run_chunks_from_chain_data.py +++ b/scripts/demo_run_chunks_from_chain_data.py @@ -28,6 +28,7 @@ import numpy as np import pandas as pd import matplotlib as mpl +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames import matplotlib.pyplot as plt import jax.numpy as jnp @@ -474,10 +475,12 @@ def _df_meta_and_head(df, name, n=3): run_fingerprint=fingerprint, coarse_weights=cw_window, params=params, - fees_df=scraped["fees_df"], - gas_cost_df=scraped["gas_cost_df"], - lp_supply_df=scraped["lp_supply_df"], - arb_fees_df=scraped["arb_fees_df"], + dynamic_input_frames=DynamicInputFrames( + fees=scraped["fees_df"], + gas_cost=scraped["gas_cost_df"], + lp_supply=scraped["lp_supply_df"], + arb_fees=scraped["arb_fees_df"], + ), ) # ---------------- Correct, window-aligned plotting block (time-aware + plain y) ---------------- diff --git a/scripts/demo_run_from_chain_data.py b/scripts/demo_run_from_chain_data.py index 78ecbc4..b1f4139 100644 --- a/scripts/demo_run_from_chain_data.py +++ b/scripts/demo_run_from_chain_data.py @@ -1,4 +1,5 @@ import jax.numpy as jnp +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames from quantammsim.core_simulator.param_utils import ( memory_days_to_logit_lamb, ) @@ -997,10 +998,12 @@ def generate_daily_variations(start_date_str, end_date_str): run_fingerprint=config["fingerprint"], coarse_weights=config["coarse_weights"], params=config["params"], - fees_df=config["fees_df"], - gas_cost_df=config["gas_cost_df"], - lp_supply_df=config["lp_supply_df"], - arb_fees_df=config["arb_fees_df"], + dynamic_input_frames=DynamicInputFrames( + fees=config["fees_df"], + gas_cost=config["gas_cost_df"], + lp_supply=config["lp_supply_df"], + arb_fees=config["arb_fees_df"], + ), ) print("-" * 80) print(f"Pool Type: {config['fingerprint']['rule']}") @@ -1191,4 +1194,4 @@ def generate_daily_variations(start_date_str, end_date_str): # actual_reserves_np=local_reserves, # actual_unix_values=datetime_array, # ) - # raise Exception("Stop here") \ No newline at end of file + # raise Exception("Stop here") diff --git a/scripts/reclamm/sim_vs_world_comparison.py b/scripts/reclamm/sim_vs_world_comparison.py index 0c754ea..bf7d19d 100644 --- a/scripts/reclamm/sim_vs_world_comparison.py +++ b/scripts/reclamm/sim_vs_world_comparison.py @@ -34,6 +34,7 @@ from pathlib import Path from datetime import datetime, timezone +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames from quantammsim.runners.jax_runners import do_run_on_historic_data # ── On-chain reClAMM params ─────────────────────────────────────────────────── @@ -187,9 +188,18 @@ def sample_at_timestamps(minute_vals, start_unix_sec, timestamps_sec): return minute_vals[indices] -def run_pool(tokens, start, end, rule, fees, params, gas_cost=0.0, - protocol_fee_split=0.0, gas_cost_df=None, - onchain_initial_state=None): +def run_pool( + tokens, + start, + end, + rule, + fees, + params, + gas_cost=0.0, + protocol_fee_split=0.0, + dynamic_input_frames=None, + onchain_initial_state=None, +): """Run a quantammsim pool and return minute-level results. Returns (val_eth, price_ratio, start_unix_sec) where val_eth and @@ -220,7 +230,9 @@ def run_pool(tokens, start, end, rule, fees, params, gas_cost=0.0, fp["reclamm_initial_state"] = onchain_initial_state result = do_run_on_historic_data( - run_fingerprint=fp, params=params, gas_cost_df=gas_cost_df, + run_fingerprint=fp, + params=params, + dynamic_input_frames=dynamic_input_frames, ) # Prices: sorted tokens → [AAVE, ETH] in USD @@ -291,7 +303,8 @@ def run_gas_experiment(args): gas_df = load_gas_csv(pct) val_eth_min, _, _ = run_pool( tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, - protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df, + protocol_fee_split=PROTOCOL_FEE_SPLIT, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df), ) gas_results_min[pct] = val_eth_min @@ -433,7 +446,8 @@ def run_gas_scale_experiment(args): gas_df["trade_gas_cost_usd"] = gas_df_raw["trade_gas_cost_usd"] * scale val_eth_min, pr_min, start_sec = run_pool( tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, - protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df, + protocol_fee_split=PROTOCOL_FEE_SPLIT, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df), onchain_initial_state=onchain_state, ) results_min[(pct, scale)] = (val_eth_min, start_sec) @@ -662,7 +676,8 @@ def run_best_gas_experiment(args): gas_df_50p = load_gas_csv("50p") g50_min, _, _ = run_pool( tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, - protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df_50p, + protocol_fee_split=PROTOCOL_FEE_SPLIT, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df_50p), onchain_initial_state=onchain_state, ) @@ -674,7 +689,8 @@ def run_best_gas_experiment(args): gas_df_75p_scaled["trade_gas_cost_usd"] *= 0.75 g75_min, _, _ = run_pool( tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, - protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df_75p_scaled, + protocol_fee_split=PROTOCOL_FEE_SPLIT, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df_75p_scaled), onchain_initial_state=onchain_state, ) @@ -686,7 +702,8 @@ def run_best_gas_experiment(args): gas_df_90p_scaled["trade_gas_cost_usd"] *= 0.25 g90_min, _, _ = run_pool( tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, - protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df_90p_scaled, + protocol_fee_split=PROTOCOL_FEE_SPLIT, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df_90p_scaled), onchain_initial_state=onchain_state, ) diff --git a/tests/integration/test_dynamic_gas_fees.py b/tests/integration/test_dynamic_gas_fees.py index 5870414..8eea4cd 100644 --- a/tests/integration/test_dynamic_gas_fees.py +++ b/tests/integration/test_dynamic_gas_fees.py @@ -9,6 +9,7 @@ import jax.numpy as jnp from pathlib import Path +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames from quantammsim.runners.jax_runners import do_run_on_historic_data @@ -78,8 +79,12 @@ class TestDynamicGasAndFees: @pytest.mark.requires_data def test_run_with_gas_and_fees(self, base_fingerprint, base_params, gas_df, fees_df, data_root): """Test simulation with both gas costs and dynamic fees.""" + dynamic_input_frames = DynamicInputFrames(gas_cost=gas_df, fees=fees_df) result = do_run_on_historic_data( - base_fingerprint, base_params, root=data_root, gas_cost_df=gas_df, fees_df=fees_df + base_fingerprint, + base_params, + root=data_root, + dynamic_input_frames=dynamic_input_frames, ) assert result is not None @@ -92,8 +97,12 @@ def test_run_with_gas_and_fees(self, base_fingerprint, base_params, gas_df, fees @pytest.mark.requires_data def test_run_with_gas_only(self, base_fingerprint, base_params, gas_df, data_root): """Test simulation with gas costs only.""" + dynamic_input_frames = DynamicInputFrames(gas_cost=gas_df) result = do_run_on_historic_data( - base_fingerprint, base_params, root=data_root, gas_cost_df=gas_df + base_fingerprint, + base_params, + root=data_root, + dynamic_input_frames=dynamic_input_frames, ) assert result is not None @@ -104,8 +113,12 @@ def test_run_with_gas_only(self, base_fingerprint, base_params, gas_df, data_roo @pytest.mark.requires_data def test_run_with_fees_only(self, base_fingerprint, base_params, fees_df, data_root): """Test simulation with dynamic fees only.""" + dynamic_input_frames = DynamicInputFrames(fees=fees_df) result = do_run_on_historic_data( - base_fingerprint, base_params, root=data_root, fees_df=fees_df + base_fingerprint, + base_params, + root=data_root, + dynamic_input_frames=dynamic_input_frames, ) assert result is not None @@ -120,8 +133,12 @@ def test_gas_reduces_final_value(self, base_fingerprint, base_params, gas_df, da result_no_gas = do_run_on_historic_data(base_fingerprint, base_params, root=data_root) # Run with gas + dynamic_input_frames = DynamicInputFrames(gas_cost=gas_df) result_with_gas = do_run_on_historic_data( - base_fingerprint, base_params, root=data_root, gas_cost_df=gas_df + base_fingerprint, + base_params, + root=data_root, + dynamic_input_frames=dynamic_input_frames, ) if "final_value" in result_no_gas and "final_value" in result_with_gas: @@ -139,8 +156,12 @@ def test_fees_reduce_final_value(self, base_fingerprint, base_params, fees_df, d result_no_fees = do_run_on_historic_data(base_fingerprint, base_params, root=data_root) # Run with fees + dynamic_input_frames = DynamicInputFrames(fees=fees_df) result_with_fees = do_run_on_historic_data( - base_fingerprint, base_params, root=data_root, fees_df=fees_df + base_fingerprint, + base_params, + root=data_root, + dynamic_input_frames=dynamic_input_frames, ) if "final_value" in result_no_fees and "final_value" in result_with_fees: diff --git a/tests/pools/reCLAMM/test_reclamm_fee_revenue.py b/tests/pools/reCLAMM/test_reclamm_fee_revenue.py index 9406a96..0e58726 100644 --- a/tests/pools/reCLAMM/test_reclamm_fee_revenue.py +++ b/tests/pools/reCLAMM/test_reclamm_fee_revenue.py @@ -267,6 +267,7 @@ class TestPoolMethodWithFees: """pool.calculate_reserves_and_fee_revenue_with_fees returns correct tuple.""" def test_pool_method_with_fees(self): + from quantammsim.core_simulator.dynamic_inputs import DynamicInputArrays from quantammsim.pools.creator import create_pool from quantammsim.runners.jax_runner_utils import Hashabledict @@ -347,13 +348,20 @@ def test_pool_method_with_dynamic_inputs(self): fees_array = jnp.array([0.003]) arb_thresh_array = jnp.array([0.0]) arb_fees_array = jnp.array([0.0]) + dynamic_inputs = DynamicInputArrays( + trades=jnp.zeros((1, 3)), + fees=fees_array, + gas_cost=arb_thresh_array, + arb_fees=arb_fees_array, + lp_supply=jnp.ones((1,)), + ) reserves, fee_revenue = pool.calculate_reserves_and_fee_revenue_with_dynamic_inputs( - params, run_fingerprint, prices, start_index, - fees_array=fees_array, - arb_thresh_array=arb_thresh_array, - arb_fees_array=arb_fees_array, - trade_array=None, + params, + run_fingerprint, + prices, + start_index, + dynamic_inputs=dynamic_inputs, ) assert reserves.shape == (n_steps, 2) diff --git a/tests/scripts/dynamic_gas_test.py b/tests/scripts/dynamic_gas_test.py index d656b6a..920a7e3 100644 --- a/tests/scripts/dynamic_gas_test.py +++ b/tests/scripts/dynamic_gas_test.py @@ -1,7 +1,9 @@ -from quantammsim.runners.jax_runners import do_run_on_historic_data import jax.numpy as jnp import pandas as pd +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames +from quantammsim.runners.jax_runners import do_run_on_historic_data + # Print the results print("=" * 100) print("Simulation Results:") @@ -32,11 +34,21 @@ run_fingerprint["do_trades"] = False result_w_gas_and_fees = do_run_on_historic_data( - run_fingerprint, params, gas_cost_df=gas_df, fees_df=fees_df + run_fingerprint, + params, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df, fees=fees_df), +) +result_w_gas_only = do_run_on_historic_data( + run_fingerprint, + params, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df), ) -result_w_gas_only = do_run_on_historic_data(run_fingerprint, params, gas_cost_df=gas_df) -result_w_fees_only = do_run_on_historic_data(run_fingerprint, params, fees_df=fees_df) +result_w_fees_only = do_run_on_historic_data( + run_fingerprint, + params, + dynamic_input_frames=DynamicInputFrames(fees=fees_df), +) print(result_w_gas_and_fees["value"][-1440+1]) print(result_w_gas_only["value"][-1440+1]) diff --git a/tests/unit/test_jax_runner_utils.py b/tests/unit/test_jax_runner_utils.py index 9cd95f6..70dbf82 100644 --- a/tests/unit/test_jax_runner_utils.py +++ b/tests/unit/test_jax_runner_utils.py @@ -13,6 +13,7 @@ import pytest import numpy as np import jax.numpy as jnp +import pandas as pd class TestHashabledict: @@ -159,6 +160,21 @@ def test_static_dict_is_hashable_with_real_fingerprint(self): h = hash(hd) assert isinstance(h, int) + def test_static_dict_accepts_dynamic_input_flags(self): + """Nested dynamic-input flags must remain hashable for JIT cache keys.""" + from quantammsim.core_simulator.dynamic_inputs import default_dynamic_input_flags + from quantammsim.runners.jax_runner_utils import create_static_dict, Hashabledict + + fp = self._make_fingerprint() + static = create_static_dict( + fp, + bout_length=10080, + overrides={"dynamic_input_flags": default_dynamic_input_flags()}, + ) + + assert static["dynamic_input_flags"]["use_dynamic_inputs"] is False + assert isinstance(hash(Hashabledict(static)), int) + def test_unknown_array_fields_dropped_with_warning(self): """Arrays not in _TRAINING_ONLY_FIELDS are dropped with a warning. @@ -237,6 +253,226 @@ def test_equality_with_non_dict_returns_false(self): assert d != [1, 2, 3] +class TestDynamicInputPreparation: + """Tests for dynamic input container construction and normalization.""" + + def test_empty_dynamic_input_arrays_have_stable_shapes(self): + """The empty hot-path bundle should have canonical placeholder arrays.""" + from quantammsim.core_simulator.dynamic_inputs import empty_dynamic_input_arrays + + dynamic_inputs = empty_dynamic_input_arrays() + + assert dynamic_inputs.trades.shape == (1, 3) + assert dynamic_inputs.fees.shape == (1,) + assert dynamic_inputs.gas_cost.shape == (1,) + assert dynamic_inputs.arb_fees.shape == (1,) + assert dynamic_inputs.lp_supply.shape == (1,) + + def test_dynamic_input_flags_reflect_present_frames(self): + """Frame-presence flags should drive static dynamic-input dispatch.""" + from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputFrames, + dynamic_input_flags_from_frames, + ) + + flags = dynamic_input_flags_from_frames( + DynamicInputFrames( + trades=pd.DataFrame({"unix": [1], "token_in": ["ETH"], "token_out": ["USDC"], "amount_in": [1.0]}), + fees=pd.DataFrame({"unix": [1], "fees": [0.003]}), + gas_cost=pd.DataFrame({"unix": [1], "trade_gas_cost_usd": [2.0]}), + ) + ) + + assert flags["use_dynamic_inputs"] is True + assert flags["has_trades"] is True + assert flags["has_dynamic_fees"] is True + assert flags["has_dynamic_gas_cost"] is True + assert flags["has_dynamic_arb_fees"] is False + assert flags["has_lp_supply"] is False + + def test_prepare_dynamic_inputs_preserves_fixed_hot_path_structure(self): + """Normalization should return fixed bundles plus static dispatch flags.""" + from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames + from quantammsim.runners.jax_runner_utils import prepare_dynamic_inputs + + run_fingerprint = { + "tokens": ["ETH", "USDC"], + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-01 00:02:00", + "endTestDateString": "2023-01-01 00:04:00", + } + + dynamic_input_frames = DynamicInputFrames( + trades=pd.DataFrame( + { + "unix": [1672531200000, 1672531320000], + "token_in": ["ETH", "USDC"], + "token_out": ["USDC", "ETH"], + "amount_in": [1.5, 2.0], + } + ), + fees=pd.DataFrame({"unix": [1672531200000], "fees": [0.003]}), + gas_cost=pd.DataFrame({"unix": [1672531200000], "trade_gas_cost_usd": [3.25]}), + arb_fees=pd.DataFrame({"unix": [1672531200000], "arb_fees": [0.0005]}), + lp_supply=pd.DataFrame({"unix": [1672531200000], "lp_supply": [1250.0]}), + ) + + prepared = prepare_dynamic_inputs( + run_fingerprint, + dynamic_input_frames=dynamic_input_frames, + do_test_period=True, + ) + + train_inputs = prepared["train_dynamic_inputs"] + test_inputs = prepared["test_dynamic_inputs"] + flags = prepared["dynamic_input_flags"] + + assert flags["use_dynamic_inputs"] is True + assert flags["has_trades"] is True + assert flags["has_dynamic_fees"] is True + assert flags["has_dynamic_gas_cost"] is True + assert flags["has_dynamic_arb_fees"] is True + assert flags["has_lp_supply"] is True + assert train_inputs.trades.shape == (2, 3) + assert train_inputs.fees.shape == (2,) + assert train_inputs.gas_cost.shape == (2,) + assert train_inputs.arb_fees.shape == (2,) + assert train_inputs.lp_supply.shape == (2,) + assert test_inputs.trades.shape == (2, 3) + assert test_inputs.fees.shape == (2,) + assert test_inputs.gas_cost.shape == (2,) + assert test_inputs.arb_fees.shape == (2,) + assert test_inputs.lp_supply.shape == (2,) + np.testing.assert_allclose(np.asarray(train_inputs.fees), np.array([0.003, 0.003])) + + def test_prepare_dynamic_inputs_uses_correct_test_period_values(self): + """Test-period arrays should use values effective from the test window onward.""" + from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames + from quantammsim.runners.jax_runner_utils import prepare_dynamic_inputs + + run_fingerprint = { + "tokens": ["ETH", "USDC"], + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-01 00:02:00", + "endTestDateString": "2023-01-01 00:04:00", + } + end_unix = pd.Timestamp(run_fingerprint["endDateString"]).value // 10**6 + + prepared = prepare_dynamic_inputs( + run_fingerprint, + dynamic_input_frames=DynamicInputFrames( + fees=pd.DataFrame( + {"unix": [1672531200000, end_unix], "fees": [0.003, 0.004]} + ), + gas_cost=pd.DataFrame( + { + "unix": [1672531200000, end_unix], + "trade_gas_cost_usd": [1.5, 2.5], + } + ), + arb_fees=pd.DataFrame( + {"unix": [1672531200000, end_unix], "arb_fees": [0.0001, 0.0002]} + ), + lp_supply=pd.DataFrame( + {"unix": [1672531200000, end_unix], "lp_supply": [1000.0, 2000.0]} + ), + ), + do_test_period=True, + ) + + np.testing.assert_allclose( + np.asarray(prepared["test_dynamic_inputs"].fees), + np.array([0.004, 0.004]), + ) + np.testing.assert_allclose( + np.asarray(prepared["test_dynamic_inputs"].gas_cost), + np.array([2.5, 2.5]), + ) + np.testing.assert_allclose( + np.asarray(prepared["test_dynamic_inputs"].arb_fees), + np.array([0.0002, 0.0002]), + ) + np.testing.assert_allclose( + np.asarray(prepared["test_dynamic_inputs"].lp_supply), + np.array([2000.0, 2000.0]), + ) + + def test_resolve_dynamic_input_flags_promotes_explicit_bundle(self): + """Passing a bundle directly should force dynamic-path dispatch.""" + from quantammsim.core_simulator.dynamic_inputs import ( + empty_dynamic_input_arrays, + resolve_dynamic_input_flags, + ) + + flags = resolve_dynamic_input_flags( + empty_dynamic_input_arrays(), + { + "use_dynamic_inputs": False, + "has_trades": False, + "has_dynamic_fees": False, + "has_dynamic_gas_cost": False, + "has_dynamic_arb_fees": False, + "has_lp_supply": False, + }, + ) + + assert flags["use_dynamic_inputs"] is True + + def test_resolve_dynamic_input_components_falls_back_to_static_scalars(self): + """Static scalar config should materialize as singleton arrays when no frames are present.""" + from quantammsim.core_simulator.dynamic_inputs import ( + default_dynamic_input_flags, + resolve_dynamic_input_components, + ) + + resolved = resolve_dynamic_input_components( + dynamic_inputs=None, + dynamic_input_flags=default_dynamic_input_flags(), + static_dict={"fees": 0.003, "gas_cost": 2.5, "arb_fees": 0.0001}, + ) + + assert resolved["trades"] is None + np.testing.assert_allclose(np.asarray(resolved["fees"]), np.array([0.003])) + np.testing.assert_allclose(np.asarray(resolved["gas_cost"]), np.array([2.5])) + np.testing.assert_allclose(np.asarray(resolved["arb_fees"]), np.array([0.0001])) + np.testing.assert_allclose(np.asarray(resolved["lp_supply"]), np.array([1.0])) + + def test_resolve_dynamic_input_components_prefers_dynamic_values(self): + """Dynamic arrays should override static scalar defaults for enabled fields.""" + from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputArrays, + resolve_dynamic_input_components, + ) + + dynamic_inputs = DynamicInputArrays( + trades=jnp.array([[0.0, 1.0, 5.0]]), + fees=jnp.array([0.004]), + gas_cost=jnp.array([3.0]), + arb_fees=jnp.array([0.0003]), + lp_supply=jnp.array([1500.0]), + ) + flags = { + "use_dynamic_inputs": True, + "has_trades": True, + "has_dynamic_fees": True, + "has_dynamic_gas_cost": True, + "has_dynamic_arb_fees": True, + "has_lp_supply": True, + } + + resolved = resolve_dynamic_input_components( + dynamic_inputs=dynamic_inputs, + dynamic_input_flags=flags, + static_dict={"fees": 0.0, "gas_cost": 0.0, "arb_fees": 0.0}, + ) + + np.testing.assert_allclose(np.asarray(resolved["trades"]), np.array([[0.0, 1.0, 5.0]])) + np.testing.assert_allclose(np.asarray(resolved["fees"]), np.array([0.004])) + np.testing.assert_allclose(np.asarray(resolved["gas_cost"]), np.array([3.0])) + np.testing.assert_allclose(np.asarray(resolved["arb_fees"]), np.array([0.0003])) + np.testing.assert_allclose(np.asarray(resolved["lp_supply"]), np.array([1500.0])) + + class TestGetSigVariations: """Tests for get_sig_variations function.""" diff --git a/tests/unit/test_jax_runners_comprehensive.py b/tests/unit/test_jax_runners_comprehensive.py index dcdd9dd..995e6d8 100644 --- a/tests/unit/test_jax_runners_comprehensive.py +++ b/tests/unit/test_jax_runners_comprehensive.py @@ -10,6 +10,7 @@ """ import pytest import numpy as np +import pandas as pd import jax.numpy as jnp import jax from copy import deepcopy @@ -18,6 +19,7 @@ from quantammsim.runners.jax_runners import ( train_on_historic_data, do_run_on_historic_data, + do_run_on_historic_data_with_provided_coarse_weights, ) from quantammsim.runners.jax_runner_utils import ( NestedHashabledict, @@ -31,8 +33,10 @@ create_static_dict, ) from quantammsim.runners.default_run_fingerprint import run_fingerprint_defaults +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames from quantammsim.core_simulator.param_utils import recursive_default_set, check_run_fingerprint from quantammsim.pools.creator import create_pool +from quantammsim.utils.data_processing.historic_data_utils import get_data_dict from tests.conftest import TEST_DATA_DIR @@ -389,6 +393,219 @@ def test_multiple_param_sets(self, defaulted_run_fingerprint, sample_params): assert isinstance(results, list) assert len(results) == 2 + def test_dynamic_trades_change_balancer_reserves(self, defaulted_run_fingerprint): + """Dynamic trade input should change the reserve path in the runner.""" + fp = deepcopy(defaulted_run_fingerprint) + fp["rule"] = "balancer" + fp["do_arb"] = False + fp["fees"] = 0.0 + fp["gas_cost"] = 0.0 + fp["arb_fees"] = 0.0 + params = {"initial_weights_logits": jnp.array([0.0, 0.0])} + + trade_unix = pd.Timestamp(fp["startDateString"]).value // 10**6 + trades_df = pd.DataFrame( + { + "unix": [trade_unix], + "token_in": ["ETH"], + "token_out": ["USDC"], + "amount_in": [100.0], + } + ) + + result_without_trades = do_run_on_historic_data( + fp, + params=params, + root=TEST_DATA_DIR, + verbose=False, + ) + result_with_trades = do_run_on_historic_data( + fp, + params=params, + root=TEST_DATA_DIR, + verbose=False, + dynamic_input_frames=DynamicInputFrames(trades=trades_df), + ) + + assert not np.allclose( + np.asarray(result_without_trades["reserves"]), + np.asarray(result_with_trades["reserves"]), + ) + assert result_with_trades["reserves"][0, 0] > result_without_trades["reserves"][0, 0] + assert result_with_trades["reserves"][0, 1] < result_without_trades["reserves"][0, 1] + + def test_dynamic_arb_fees_match_scalar_arb_fees(self, defaulted_run_fingerprint): + """Constant dynamic arb fees should match the scalar arb-fee path.""" + fp = deepcopy(defaulted_run_fingerprint) + fp["rule"] = "balancer" + fp["fees"] = 0.003 + fp["do_arb"] = True + params = {"initial_weights_logits": jnp.array([0.0, 0.0])} + + arb_fee = 0.002 + arb_fees_df = pd.DataFrame( + { + "unix": [pd.Timestamp(fp["startDateString"]).value // 10**6], + "arb_fees": [arb_fee], + } + ) + + result_scalar = do_run_on_historic_data( + fp, + params=params, + root=TEST_DATA_DIR, + verbose=False, + arb_fees=arb_fee, + ) + result_dynamic = do_run_on_historic_data( + fp, + params=params, + root=TEST_DATA_DIR, + verbose=False, + dynamic_input_frames=DynamicInputFrames(arb_fees=arb_fees_df), + ) + + np.testing.assert_allclose( + np.asarray(result_dynamic["value"]), + np.asarray(result_scalar["value"]), + rtol=1e-6, + atol=1e-6, + ) + + def test_dynamic_lp_supply_changes_momentum_runner_path(self, defaulted_run_fingerprint, sample_params): + """LP supply changes should affect the main momentum runner path.""" + fp = deepcopy(defaulted_run_fingerprint) + fp["protocol_fee_split"] = 0.25 + + start_unix = pd.Timestamp(fp["startDateString"]).value // 10**6 + midpoint_unix = start_unix + 3 * 1440 * 60 * 1000 + constant_lp_supply_df = pd.DataFrame( + { + "unix": [start_unix], + "lp_supply": [1.0], + } + ) + stepped_lp_supply_df = pd.DataFrame( + { + "unix": [start_unix, midpoint_unix], + "lp_supply": [1.0, 2.0], + } + ) + + result_without_lp_supply = do_run_on_historic_data( + fp, + params=sample_params, + root=TEST_DATA_DIR, + verbose=False, + ) + result_with_constant_lp_supply = do_run_on_historic_data( + fp, + params=sample_params, + root=TEST_DATA_DIR, + verbose=False, + dynamic_input_frames=DynamicInputFrames(lp_supply=constant_lp_supply_df), + ) + result_with_stepped_lp_supply = do_run_on_historic_data( + fp, + params=sample_params, + root=TEST_DATA_DIR, + verbose=False, + dynamic_input_frames=DynamicInputFrames(lp_supply=stepped_lp_supply_df), + ) + + np.testing.assert_allclose( + np.asarray(result_with_constant_lp_supply["value"]), + np.asarray(result_without_lp_supply["value"]), + rtol=1e-6, + atol=1e-6, + ) + assert not np.allclose( + np.asarray(result_with_stepped_lp_supply["reserves"][-1]), + np.asarray(result_without_lp_supply["reserves"][-1]), + ) + assert float(result_with_stepped_lp_supply["final_value"]) != pytest.approx( + float(result_without_lp_supply["final_value"]) + ) + + def test_provided_coarse_weights_respect_scalar_and_dynamic_gas(self, defaulted_run_fingerprint, sample_params): + """Provided-coarse-weight path should honor both scalar gas and dynamic gas arrays.""" + fp = deepcopy(defaulted_run_fingerprint) + fp["protocol_fee_split"] = 0.0 + + data_dict = get_data_dict( + list_of_tickers=fp["tokens"], + run_fingerprint=fp, + data_kind=fp["optimisation_settings"]["training_data_kind"], + root=TEST_DATA_DIR, + max_memory_days=fp["max_memory_days"], + start_date_string=fp["startDateString"], + end_time_string=fp["endDateString"], + start_time_test_string=fp["endDateString"], + end_time_test_string=fp["endTestDateString"], + max_mc_version=fp["optimisation_settings"]["max_mc_version"], + do_test_period=False, + ) + + coarse_unix_values = ( + pd.date_range( + start=pd.Timestamp(fp["startDateString"]), + end=pd.Timestamp(fp["endDateString"]), + freq=f"{fp['chunk_period']}min", + ) + .astype(np.int64) + // 10**6 + ) + coarse_weights = { + "weights": jnp.tile(jnp.array([[0.5, 0.5]]), (len(coarse_unix_values), 1)), + "unix_values": jnp.asarray(coarse_unix_values), + } + + params = deepcopy(sample_params) + initial_prices = jnp.asarray(data_dict["prices"][data_dict["start_idx"]], dtype=jnp.float64) + params["initial_reserves"] = (jnp.array([0.5, 0.5]) * fp["initial_pool_value"]) / initial_prices + + gas_cost = 50.0 + gas_cost_df = pd.DataFrame( + { + "unix": [pd.Timestamp(fp["startDateString"]).value // 10**6], + "trade_gas_cost_usd": [gas_cost], + } + ) + + result_no_gas = do_run_on_historic_data_with_provided_coarse_weights( + fp, + coarse_weights=coarse_weights, + params=params, + root=TEST_DATA_DIR, + verbose=False, + ) + result_scalar_gas = do_run_on_historic_data_with_provided_coarse_weights( + fp, + coarse_weights=coarse_weights, + params=params, + root=TEST_DATA_DIR, + verbose=False, + gas_cost=gas_cost, + ) + result_dynamic_gas = do_run_on_historic_data_with_provided_coarse_weights( + fp, + coarse_weights=coarse_weights, + params=params, + root=TEST_DATA_DIR, + verbose=False, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_cost_df), + ) + + assert float(result_scalar_gas["final_value"]) != pytest.approx( + float(result_no_gas["final_value"]) + ) + np.testing.assert_allclose( + np.asarray(result_dynamic_gas["value"]), + np.asarray(result_scalar_gas["value"]), + rtol=1e-6, + atol=1e-6, + ) + # ============================================================================ # Validation and Early Stopping Tests diff --git a/tests/unit/test_lint_bugs.py b/tests/unit/test_lint_bugs.py index 95f7ee6..bbb49e6 100644 --- a/tests/unit/test_lint_bugs.py +++ b/tests/unit/test_lint_bugs.py @@ -117,16 +117,13 @@ def calculate_reserves_zero_fees(self, params, static_dict, prices, start_index) prices = jnp.ones((20, 2)) start_index = jnp.array([0, 0]) - # __wrapped__ arg order: params, start_index, prices, trades, fees, - # gas_cost, arb_fees, pool, static_dict + # __wrapped__ arg order: params, start_index, prices, dynamic_inputs, + # pool, static_dict result = forward_pass.__wrapped__( {}, # params start_index, # start_index prices, # prices - None, # trades_array - None, # fees_array - None, # gas_cost_array - None, # arb_fees_array + None, # dynamic_inputs _MockPool(), # pool static_dict, # static_dict ) From 0dcb04d9eb949490a7d054e2e7a3cf18d39a548d Mon Sep 17 00:00:00 2001 From: christian harrington Date: Wed, 4 Mar 2026 15:57:53 +0000 Subject: [PATCH 2/6] add fix for test directory data instead of main data source --- tests/pools/reCLAMM/test_reclamm_reserves.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/pools/reCLAMM/test_reclamm_reserves.py b/tests/pools/reCLAMM/test_reclamm_reserves.py index 3887e47..cf40217 100644 --- a/tests/pools/reCLAMM/test_reclamm_reserves.py +++ b/tests/pools/reCLAMM/test_reclamm_reserves.py @@ -9,6 +9,7 @@ import numpy as np import numpy.testing as npt +from tests.conftest import TEST_DATA_DIR from quantammsim.pools.reCLAMM.reclamm_reserves import ( compute_invariant, compute_price_ratio, @@ -734,8 +735,8 @@ def test_shift_exponent_equivalent_to_base(self): fp_common = { "rule": "reclamm", "tokens": ["ETH", "USDC"], - "startDateString": "2024-06-01 00:00:00", - "endDateString": "2024-06-15 00:00:00", + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-15 00:00:00", "initial_pool_value": 1_000_000.0, "do_arb": True, "fees": 0.0, @@ -748,6 +749,7 @@ def test_shift_exponent_equivalent_to_base(self): "centeredness_margin": jnp.array(0.2), "daily_price_shift_base": jnp.array(base), }, + root=str(TEST_DATA_DIR), ) result_exp = do_run_on_historic_data( run_fingerprint={**fp_common, "reclamm_use_shift_exponent": True}, @@ -756,6 +758,7 @@ def test_shift_exponent_equivalent_to_base(self): "centeredness_margin": jnp.array(0.2), "shift_exponent": jnp.array(shift_exp), }, + root=str(TEST_DATA_DIR), ) np.testing.assert_allclose( @@ -772,10 +775,9 @@ def test_train_on_historic_data_optuna(self): fp = { "rule": "reclamm", "tokens": ["ETH", "USDC"], - "startDateString": "2024-06-01 00:00:00", - "endDateString": "2024-06-15 00:00:00", - "endTestDateString": "2024-07-01 00:00:00", - "endTestDateString": "2024-08-01 00:00:00", + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-15 00:00:00", + "endTestDateString": "2023-02-01 00:00:00", "initial_pool_value": 1_000_000.0, "do_arb": True, "fees": 0.0025, @@ -810,5 +812,5 @@ def test_train_on_historic_data_optuna(self): }, }, } - result = train_on_historic_data(fp, verbose=False) + result = train_on_historic_data(fp, root=str(TEST_DATA_DIR), verbose=False) assert result is not None From ffda35c26a8bebeed9bef9bd4f93606c849e4dc2 Mon Sep 17 00:00:00 2001 From: christian harrington Date: Wed, 4 Mar 2026 15:58:06 +0000 Subject: [PATCH 3/6] import fix --- tests/pools/reCLAMM/test_reclamm_fee_revenue.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pools/reCLAMM/test_reclamm_fee_revenue.py b/tests/pools/reCLAMM/test_reclamm_fee_revenue.py index 0e58726..4a31f1f 100644 --- a/tests/pools/reCLAMM/test_reclamm_fee_revenue.py +++ b/tests/pools/reCLAMM/test_reclamm_fee_revenue.py @@ -9,6 +9,7 @@ import numpy as np import numpy.testing as npt +from quantammsim.core_simulator.dynamic_inputs import DynamicInputArrays from quantammsim.pools.reCLAMM.reclamm_reserves import ( initialise_reclamm_reserves, _jax_calc_reclamm_reserves_with_fees, From 961b796c1c2d8b9ae46496102bbc5ad1b3e1c536 Mon Sep 17 00:00:00 2001 From: christian harrington Date: Wed, 4 Mar 2026 16:28:17 +0000 Subject: [PATCH 4/6] refactor the dynamic inputs given the runtime errors with lax scan. trades is the wrong shape and optional. Centralised in the new materialized function that is used by all dynamic input reserve calcs --- quantammsim/core_simulator/dynamic_inputs.py | 73 +++++++++++++++- quantammsim/core_simulator/forward_pass.py | 15 ++-- quantammsim/core_simulator/windowing_utils.py | 41 ++++----- quantammsim/hooks/dynamic_fee_base_hook.py | 2 +- quantammsim/pools/ECLP/gyroscope.py | 47 ++++------ quantammsim/pools/ECLP/gyroscope_reserves.py | 24 ++--- quantammsim/pools/FM_AMM/cow_pool.py | 39 +++------ quantammsim/pools/FM_AMM/cow_reserves.py | 11 ++- quantammsim/pools/G3M/balancer/balancer.py | 40 +++------ .../pools/G3M/balancer/balancer_reserves.py | 30 +++---- .../pools/G3M/quantamm/TFMM_base_pool.py | 54 ++++-------- .../pools/G3M/quantamm/quantamm_reserves.py | 50 ++++++----- quantammsim/pools/reCLAMM/reclamm.py | 53 +++++------ quantammsim/runners/jax_runner_utils.py | 22 ++++- quantammsim/runners/jax_runners.py | 87 +++++++------------ 15 files changed, 289 insertions(+), 299 deletions(-) diff --git a/quantammsim/core_simulator/dynamic_inputs.py b/quantammsim/core_simulator/dynamic_inputs.py index b1eedac..c249088 100644 --- a/quantammsim/core_simulator/dynamic_inputs.py +++ b/quantammsim/core_simulator/dynamic_inputs.py @@ -16,9 +16,9 @@ class DynamicInputFrames: class DynamicInputArrays(NamedTuple): - """Fixed-structure JAX pytree for dynamic simulation inputs.""" + """JAX pytree for dynamic simulation inputs with optional trade data.""" - trades: jnp.ndarray + trades: Optional[jnp.ndarray] fees: jnp.ndarray gas_cost: jnp.ndarray arb_fees: jnp.ndarray @@ -70,9 +70,9 @@ def resolve_dynamic_input_flags( def empty_dynamic_input_arrays() -> DynamicInputArrays: - """Create a canonical empty bundle with stable pytree structure.""" + """Create a canonical empty bundle.""" return DynamicInputArrays( - trades=jnp.zeros((1, 3), dtype=jnp.float64), + trades=None, fees=jnp.zeros((1,), dtype=jnp.float64), gas_cost=jnp.zeros((1,), dtype=jnp.float64), arb_fees=jnp.zeros((1,), dtype=jnp.float64), @@ -110,3 +110,68 @@ def resolve_dynamic_input_components( else jnp.ones((1,), dtype=jnp.float64) ), } + + +def _broadcast_dynamic_input_leaf( + input_name: str, + values: jnp.ndarray, + scan_len: int, + dtype, +) -> jnp.ndarray: + """Broadcast a singleton dynamic-input leaf to the scan length.""" + values = jnp.asarray(values, dtype=dtype) + if values.ndim == 0: + values = values.reshape((1,)) + if values.shape[0] == scan_len: + return values + if values.shape[0] == 1: + return jnp.broadcast_to(values, (scan_len,) + values.shape[1:]) + raise ValueError( + f"{input_name} has leading axis {values.shape[0]}, expected 1 or {scan_len}" + ) + + +def materialize_dynamic_inputs( + dynamic_inputs: Optional[DynamicInputArrays], + dynamic_input_flags: Optional[dict], + static_dict: dict, + scan_len: int, + do_trades: bool, + dtype=jnp.float64, +) -> DynamicInputArrays: + """Resolve and broadcast dynamic inputs for a specific scan length.""" + if dynamic_input_flags is None and dynamic_inputs is not None: + flags = { + "use_dynamic_inputs": True, + "has_trades": do_trades, + "has_dynamic_fees": True, + "has_dynamic_gas_cost": True, + "has_dynamic_arb_fees": True, + "has_lp_supply": True, + } + else: + flags = resolve_dynamic_input_flags(dynamic_inputs, dynamic_input_flags) + + resolved = resolve_dynamic_input_components(dynamic_inputs, flags, static_dict) + + trades = None + if do_trades: + if resolved["trades"] is None: + raise ValueError("Trades must be provided when do_trades=True.") + trades = _broadcast_dynamic_input_leaf( + "trades", resolved["trades"], scan_len, dtype + ) + + return DynamicInputArrays( + trades=trades, + fees=_broadcast_dynamic_input_leaf("fees", resolved["fees"], scan_len, dtype), + gas_cost=_broadcast_dynamic_input_leaf( + "gas_cost", resolved["gas_cost"], scan_len, dtype + ), + arb_fees=_broadcast_dynamic_input_leaf( + "arb_fees", resolved["arb_fees"], scan_len, dtype + ), + lp_supply=_broadcast_dynamic_input_leaf( + "lp_supply", resolved["lp_supply"], scan_len, dtype + ), + ) diff --git a/quantammsim/core_simulator/forward_pass.py b/quantammsim/core_simulator/forward_pass.py index 4dd74f6..d2a32cd 100644 --- a/quantammsim/core_simulator/forward_pass.py +++ b/quantammsim/core_simulator/forward_pass.py @@ -57,7 +57,6 @@ from quantammsim.core_simulator.dynamic_inputs import ( DynamicInputArrays, default_dynamic_input_flags, - empty_dynamic_input_arrays, resolve_dynamic_input_flags, ) @@ -66,13 +65,11 @@ def _resolve_dynamic_inputs(dynamic_inputs, static_dict): - """Return a stable hot-path bundle plus static dispatch flags.""" + """Return the incoming bundle plus static dispatch flags.""" dynamic_input_flags = resolve_dynamic_input_flags( dynamic_inputs, static_dict.get("dynamic_input_flags"), ) - if dynamic_inputs is None: - dynamic_inputs = empty_dynamic_input_arrays() return dynamic_inputs, dynamic_input_flags @@ -1107,7 +1104,15 @@ def forward_pass_nograd( prices = stop_gradient(prices) if dynamic_inputs is not None: dynamic_inputs = DynamicInputArrays( - *(stop_gradient(arr) for arr in dynamic_inputs) + trades=( + None + if dynamic_inputs.trades is None + else stop_gradient(dynamic_inputs.trades) + ), + fees=stop_gradient(dynamic_inputs.fees), + gas_cost=stop_gradient(dynamic_inputs.gas_cost), + arb_fees=stop_gradient(dynamic_inputs.arb_fees), + lp_supply=stop_gradient(dynamic_inputs.lp_supply), ) return forward_pass( params, diff --git a/quantammsim/core_simulator/windowing_utils.py b/quantammsim/core_simulator/windowing_utils.py index 28d7524..c6cde8c 100644 --- a/quantammsim/core_simulator/windowing_utils.py +++ b/quantammsim/core_simulator/windowing_utils.py @@ -206,11 +206,12 @@ def raw_fee_like_amounts_to_fee_like_array( ).astype(int) // 10**6 )[:-1] + fill_value = np.nan if fill_method == "ffill" else 0.0 full_index_df = pd.DataFrame( - index=full_index, - columns=names, - data=0, - dtype=np.float64 + index=full_index, + columns=names, + data=fill_value, + dtype=np.float64, ) # Map raw data to the full index DataFrame @@ -236,15 +237,16 @@ def raw_fee_like_amounts_to_fee_like_array( # Ensure unix values are valid valid_unix = pd.to_numeric(raw_inputs['unix'], errors='coerce') valid_mask = valid_unix.notna() + valid_inputs = raw_inputs.loc[valid_mask].copy() + valid_inputs["unix"] = valid_unix.loc[valid_mask].astype(np.int64) + valid_inputs = valid_inputs.sort_values("unix") for name in names: initial_value = None - if valid_mask.any(): + if not valid_inputs.empty: # Try to get the last value before our start date - previous_values = raw_inputs[ - valid_mask & (valid_unix < start_unix) - ] + previous_values = valid_inputs[valid_inputs["unix"] < start_unix] if not previous_values.empty: try: @@ -254,9 +256,7 @@ def raw_fee_like_amounts_to_fee_like_array( if initial_value is None or pd.isna(initial_value): # Try to get first value in our date range - in_range_values = raw_inputs[ - valid_mask & (valid_unix >= start_unix) - ] + in_range_values = valid_inputs[valid_inputs["unix"] >= start_unix] if not in_range_values.empty: try: initial_value = pd.to_numeric(in_range_values[name].iloc[0]) @@ -264,17 +264,12 @@ def raw_fee_like_amounts_to_fee_like_array( initial_value = None if initial_value is not None and pd.notna(initial_value): - # this more complex logic is because of how we have started with prior-to-start values - # filled in, and then we want to ffill the rest - # Fill initial values - full_index_df[name] = full_index_df[name].mask( - full_index_df[name] == 0, - initial_value - ) - # Use ffill() - full_index_df[name] = full_index_df[name].where( - full_index_df[name] != 0 - ).ffill() + # Seed only the leading gap; explicit in-range updates must remain intact. + first_row = full_index_df.index[0] + if pd.isna(full_index_df.at[first_row, name]): + full_index_df.at[first_row, name] = initial_value + + full_index_df[name] = full_index_df[name].ffill().fillna(0.0) except (ValueError, KeyError, TypeError) as e: print(f"Warning: Error during ffill processing: {str(e)}") # On any error, return the original zero-filled DataFrame @@ -387,4 +382,4 @@ def filter_reserves_by_given_timestamp(reserves, unix_values, timestamp): unix_values == timestamp )[0][0] - return reserves[reserves_index].copy() \ No newline at end of file + return reserves[reserves_index].copy() diff --git a/quantammsim/hooks/dynamic_fee_base_hook.py b/quantammsim/hooks/dynamic_fee_base_hook.py index 9e00af6..ad64a5f 100644 --- a/quantammsim/hooks/dynamic_fee_base_hook.py +++ b/quantammsim/hooks/dynamic_fee_base_hook.py @@ -119,7 +119,7 @@ def calculate_reserves_with_fees( dynamic_fees = raw_dynamic_fees.repeat(chunk_period, axis=0).squeeze() empty_inputs = empty_dynamic_input_arrays() dynamic_inputs = DynamicInputArrays( - trades=empty_inputs.trades, + trades=None, fees=dynamic_fees, gas_cost=jnp.asarray(run_fingerprint["gas_cost"], dtype=jnp.float64), arb_fees=jnp.asarray(run_fingerprint["arb_fees"], dtype=jnp.float64), diff --git a/quantammsim/pools/ECLP/gyroscope.py b/quantammsim/pools/ECLP/gyroscope.py index 5aebcef..83a7dc2 100644 --- a/quantammsim/pools/ECLP/gyroscope.py +++ b/quantammsim/pools/ECLP/gyroscope.py @@ -32,6 +32,7 @@ from typing import Dict, Any, Optional, Tuple import numpy as np +from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs from quantammsim.pools.base_pool import AbstractPool from quantammsim.pools.ECLP.gyroscope_reserves import ( @@ -342,11 +343,6 @@ def calculate_reserves_with_dynamic_inputs( else: arb_acted_upon_local_prices = local_prices - fees_array = dynamic_inputs.fees - arb_thresh_array = dynamic_inputs.gas_cost - arb_fees_array = dynamic_inputs.arb_fees - trade_array = dynamic_inputs.trades - # calculate initial reserves initial_pool_value = run_fingerprint["initial_pool_value"] initial_reserves = initialise_gyroscope_reserves_given_value( @@ -358,34 +354,27 @@ def calculate_reserves_with_dynamic_inputs( sin=jnp.sin(phi), cos=jnp.cos(phi), ) - # any of fees_array, arb_thresh_array, arb_fees_array, trade_array - # can be singletons, in which case we repeat them for the length of the bout - - # Determine the maximum leading dimension max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=max_len, + do_trades=run_fingerprint["do_trades"], + dtype=arb_acted_upon_local_prices.dtype, + ) # Handle trade array reordering if needed if run_fingerprint["do_trades"]: - # if we are doing trades, the trades array must be of the same length as the other arrays - assert trade_array.shape[0] == max_len if needs_swap: # Swap trade indices (0->1, 1->0) but keep amounts unchanged - trade_array = trade_array.at[:, :2].set(1 - trade_array[:, :2]) - - # Broadcast input arrays to match the maximum leading dimension. - # If they are singletons, this will just repeat them for the length of the bout. - # If they are arrays of length bout_length, this will cause no change. - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] - ) + materialized_inputs = materialized_inputs._replace( + trades=materialized_inputs.trades.at[:, :2].set( + 1 - materialized_inputs.trades[:, :2] + ) + ) # Calculate reserves reserves = _jax_calc_gyroscope_reserves_with_dynamic_inputs( @@ -396,10 +385,10 @@ def calculate_reserves_with_dynamic_inputs( sin=jnp.sin(phi), cos=jnp.cos(phi), lam=lam, - fees=fees_array_broadcast, - arb_thresh=arb_thresh_array_broadcast, - arb_fees=arb_fees_array_broadcast, - trades=trade_array, + fees=materialized_inputs.fees, + arb_thresh=materialized_inputs.gas_cost, + arb_fees=materialized_inputs.arb_fees, + trades=materialized_inputs.trades, do_trades=run_fingerprint["do_trades"], ) # Restore original order if we swapped diff --git a/quantammsim/pools/ECLP/gyroscope_reserves.py b/quantammsim/pools/ECLP/gyroscope_reserves.py index da09072..6807310 100644 --- a/quantammsim/pools/ECLP/gyroscope_reserves.py +++ b/quantammsim/pools/ECLP/gyroscope_reserves.py @@ -605,7 +605,7 @@ def _jax_calc_gyroscope_reserves_with_dynamic_fees_and_trades_scan_function_usin gamma = input_list[1] arb_thresh = input_list[2] arb_fees = input_list[3] - trade = input_list[4] + trade = input_list[4] if do_trades else None @@ -727,6 +727,8 @@ def _jax_calc_gyroscope_reserves_with_dynamic_inputs( arb_fees = jnp.where( arb_fees.size == 1, jnp.full(prices.shape[0], arb_fees), arb_fees ) + if do_trades and trades is None: + raise ValueError("Trades must be provided when do_trades=True.") scan_fn = Partial( _jax_calc_gyroscope_reserves_with_dynamic_fees_and_trades_scan_function_using_precalcs, @@ -745,17 +747,15 @@ def _jax_calc_gyroscope_reserves_with_dynamic_inputs( initial_reserves, 0 ] - carry_list_end, reserves = scan( - scan_fn, - carry_list_init, - [ - prices, - gamma, - arb_thresh, - arb_fees, - trades, - ], - ) + scan_inputs = [ + prices, + gamma, + arb_thresh, + arb_fees, + ] + if do_trades: + scan_inputs.append(trades) + carry_list_end, reserves = scan(scan_fn, carry_list_init, scan_inputs) return reserves diff --git a/quantammsim/pools/FM_AMM/cow_pool.py b/quantammsim/pools/FM_AMM/cow_pool.py index 96299ce..b581c61 100644 --- a/quantammsim/pools/FM_AMM/cow_pool.py +++ b/quantammsim/pools/FM_AMM/cow_pool.py @@ -32,6 +32,7 @@ from functools import partial import numpy as np +from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs from quantammsim.pools.base_pool import AbstractPool from quantammsim.pools.FM_AMM.cow_reserves import ( _jax_calc_cowamm_reserves_with_fees, @@ -211,47 +212,31 @@ def calculate_reserves_with_dynamic_inputs( else: arb_acted_upon_local_prices = local_prices - fees_array = dynamic_inputs.fees - arb_thresh_array = dynamic_inputs.gas_cost - arb_fees_array = dynamic_inputs.arb_fees - trade_array = dynamic_inputs.trades - initial_pool_value = run_fingerprint["initial_pool_value"] initial_value_per_token = weights * initial_pool_value initial_reserves = initial_value_per_token / arb_acted_upon_local_prices[0] - # any of fees_array, arb_thresh_array, arb_fees_array, trade_array - # can be singletons, in which case we repeat them for the length of the bout - - # Determine the maximum leading dimension max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] - # Broadcast input arrays to match the maximum leading dimension. - # If they are singletons, this will just repeat them for the length of the bout. - # If they are arrays of length bout_length, this will cause no change. - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=max_len, + do_trades=run_fingerprint["do_trades"], + dtype=arb_acted_upon_local_prices.dtype, ) - # if we are doing trades, the trades array must be of the same length as the other arrays - if run_fingerprint["do_trades"]: - assert trade_array.shape[0] == max_len reserves = _jax_calc_cowamm_reserves_with_dynamic_inputs( initial_reserves, arb_acted_upon_local_prices, - fees_array_broadcast, - arb_thresh_array_broadcast, - arb_fees_array_broadcast, + materialized_inputs.fees, + materialized_inputs.gas_cost, + materialized_inputs.arb_fees, weights, run_fingerprint["arb_quality"], - trade_array, + materialized_inputs.trades, run_fingerprint["do_trades"], run_fingerprint["do_arb"], noise_trader_ratio=run_fingerprint["noise_trader_ratio"], diff --git a/quantammsim/pools/FM_AMM/cow_reserves.py b/quantammsim/pools/FM_AMM/cow_reserves.py index f876d51..ab340a0 100644 --- a/quantammsim/pools/FM_AMM/cow_reserves.py +++ b/quantammsim/pools/FM_AMM/cow_reserves.py @@ -729,7 +729,7 @@ def _jax_calc_cowamm_reserves_with_dynamic_fees_and_trades_scan_function( gamma = input_list[1] arb_thresh = input_list[2] arb_fees = input_list[3] - trade = input_list[4] + trade = input_list[4] if do_trades else None if do_arb: reserves_with_perfect_arb = _jax_calc_cowamm_reserves_with_fees_scan_function( @@ -827,6 +827,8 @@ def _jax_calc_cowamm_reserves_with_dynamic_inputs( initial_prices = prices[0] gamma = 1.0 - fees + if do_trades and trades is None: + raise ValueError("Trades must be provided when do_trades=True.") scan_fn = Partial( _jax_calc_cowamm_reserves_with_dynamic_fees_and_trades_scan_function, @@ -838,8 +840,9 @@ def _jax_calc_cowamm_reserves_with_dynamic_inputs( ) carry_list_init = [initial_prices, initial_reserves] - _, reserves = scan( - scan_fn, carry_list_init, [prices, gamma, arb_thresh, arb_fees, trades] - ) + scan_inputs = [prices, gamma, arb_thresh, arb_fees] + if do_trades: + scan_inputs.append(trades) + _, reserves = scan(scan_fn, carry_list_init, scan_inputs) return reserves diff --git a/quantammsim/pools/G3M/balancer/balancer.py b/quantammsim/pools/G3M/balancer/balancer.py index 986d49c..49bd5ca 100644 --- a/quantammsim/pools/G3M/balancer/balancer.py +++ b/quantammsim/pools/G3M/balancer/balancer.py @@ -8,6 +8,7 @@ import jax.numpy as jnp from jax.lax import dynamic_slice +from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs from quantammsim.pools.base_pool import AbstractPool from quantammsim.pools.G3M.balancer.balancer_reserves import ( _jax_calc_balancer_reserve_ratios, @@ -306,47 +307,30 @@ def calculate_reserves_with_dynamic_inputs( else: arb_acted_upon_local_prices = local_prices - fees_array = dynamic_inputs.fees - arb_thresh_array = dynamic_inputs.gas_cost - arb_fees_array = dynamic_inputs.arb_fees - trade_array = dynamic_inputs.trades - lp_supply_array = dynamic_inputs.lp_supply - initial_pool_value = run_fingerprint["initial_pool_value"] initial_value_per_token = weights * initial_pool_value initial_reserves = initial_value_per_token / arb_acted_upon_local_prices[0] - # any of fees_array, arb_thresh_array, arb_fees_array, trade_array - # can be singletons, in which case we repeat them for the length of the bout - - # Determine the maximum leading dimension max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] - # Broadcast input arrays to match the maximum leading dimension. - # If they are singletons, this will just repeat them for the length of the bout. - # If they are arrays of length bout_length, this will cause no change. - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=max_len, + do_trades=run_fingerprint["do_trades"], + dtype=arb_acted_upon_local_prices.dtype, ) - # if we are doing trades, the trades array must be of the same length as the other arrays - if run_fingerprint["do_trades"]: - assert trade_array.shape[0] == max_len reserves = _jax_calc_balancer_reserves_with_dynamic_inputs( initial_reserves, weights, arb_acted_upon_local_prices, - fees_array_broadcast, - arb_thresh_array_broadcast, - arb_fees_array_broadcast, + materialized_inputs.fees, + materialized_inputs.gas_cost, + materialized_inputs.arb_fees, jnp.array(run_fingerprint["all_sig_variations"]), - trade_array, + materialized_inputs.trades, run_fingerprint["do_trades"], run_fingerprint["do_arb"], ) diff --git a/quantammsim/pools/G3M/balancer/balancer_reserves.py b/quantammsim/pools/G3M/balancer/balancer_reserves.py index 240ae6a..d6fd68c 100644 --- a/quantammsim/pools/G3M/balancer/balancer_reserves.py +++ b/quantammsim/pools/G3M/balancer/balancer_reserves.py @@ -364,7 +364,7 @@ def _jax_calc_balancer_reserves_with_dynamic_fees_and_trades_scan_function_using gamma = input_list[4] arb_thresh = input_list[5] arb_fees = input_list[6] - trade = input_list[7] + trade = input_list[7] if do_trades else None fees_are_being_charged = gamma != 1.0 @@ -499,6 +499,8 @@ def _jax_calc_balancer_reserves_with_dynamic_inputs( arb_fees = jnp.where( arb_fees.size == 1, jnp.full(prices.shape[0], arb_fees), arb_fees ) + if do_trades and trades is None: + raise ValueError("Trades must be provided when do_trades=True.") # pre-calculate some values that are repeatedly used in optimal arb calculations _, active_trade_directions, tokens_to_drop, leave_one_out_idxs = ( @@ -533,19 +535,17 @@ def _jax_calc_balancer_reserves_with_dynamic_inputs( initial_reserves, 0, ] - _, reserves = scan( - scan_fn, - carry_list_init, - [ - prices, - active_initial_weights, - per_asset_ratios, - all_other_assets_ratios, - gamma, - arb_thresh, - arb_fees, - trades, - ], - ) + scan_inputs = [ + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + gamma, + arb_thresh, + arb_fees, + ] + if do_trades: + scan_inputs.append(trades) + _, reserves = scan(scan_fn, carry_list_init, scan_inputs) return reserves diff --git a/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py b/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py index b087b5a..d5a1654 100644 --- a/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py +++ b/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py @@ -20,6 +20,7 @@ from jax.lax import dynamic_slice, scan, fori_loop from jax.tree_util import Partial +from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs from quantammsim.pools.base_pool import AbstractPool from quantammsim.pools.G3M.quantamm.quantamm_reserves import ( _jax_calc_quantAMM_reserve_ratios, @@ -274,59 +275,35 @@ def calculate_reserves_with_dynamic_inputs( arb_acted_upon_weights = weights arb_acted_upon_local_prices = local_prices - fees_array = dynamic_inputs.fees - arb_thresh_array = dynamic_inputs.gas_cost - arb_fees_array = dynamic_inputs.arb_fees - trade_array = dynamic_inputs.trades - lp_supply_array = dynamic_inputs.lp_supply - initial_pool_value = run_fingerprint["initial_pool_value"] initial_value_per_token = arb_acted_upon_weights[0] * initial_pool_value initial_reserves = initial_value_per_token / arb_acted_upon_local_prices[0] - # any of fees_array, arb_thresh_array, arb_fees_array, trade_array, and lp_supply_array - # can be singletons, in which case we repeat them for the length of the bout. - - # Determine the maximum leading dimension max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] - # Broadcast input arrays to match the maximum leading dimension. - # If they are singletons, this will just repeat them for the length of the bout. - # If they are arrays of length bout_length, this will cause no change. - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] - ) - # if lp_supply_array is not provided, we set it to a constant of 1.0 - if lp_supply_array is None: - lp_supply_array = jnp.array(1.0) - - lp_supply_array_broadcast = jnp.broadcast_to( - lp_supply_array, (max_len,) + lp_supply_array.shape[1:] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=max_len, + do_trades=run_fingerprint["do_trades"], + dtype=arb_acted_upon_local_prices.dtype, ) - # if we are doing trades, the trades array must be of the same length as the other arrays - if run_fingerprint["do_trades"]: - assert trade_array.shape[0] == max_len protocol_fee_split = run_fingerprint.get("protocol_fee_split", 0.0) reserves = _jax_calc_quantAMM_reserves_with_dynamic_inputs( initial_reserves, arb_acted_upon_weights, arb_acted_upon_local_prices, - fees_array_broadcast, - arb_thresh_array_broadcast, - arb_fees_array_broadcast, + materialized_inputs.fees, + materialized_inputs.gas_cost, + materialized_inputs.arb_fees, jnp.array(run_fingerprint["all_sig_variations"]), - trade_array, + materialized_inputs.trades, run_fingerprint["do_trades"], run_fingerprint["do_arb"], run_fingerprint["noise_trader_ratio"], - lp_supply_array_broadcast, + materialized_inputs.lp_supply, protocol_fee_split=protocol_fee_split, ) return reserves @@ -1461,11 +1438,16 @@ def calculate_weights_direct( initial_weights, minimum_weight, params, + jnp.zeros_like(initial_weights), + jnp.ones_like(initial_weights), local_fingerprint["max_memory_days"], local_fingerprint["chunk_period"], local_fingerprint["weight_interpolation_period"], maximum_change, False, + False, + False, + False, ) return target_weights_cpu diff --git a/quantammsim/pools/G3M/quantamm/quantamm_reserves.py b/quantammsim/pools/G3M/quantamm/quantamm_reserves.py index 351a788..77bbe11 100644 --- a/quantammsim/pools/G3M/quantamm/quantamm_reserves.py +++ b/quantammsim/pools/G3M/quantamm/quantamm_reserves.py @@ -545,9 +545,14 @@ def _jax_calc_quantAMM_reserves_with_dynamic_fees_and_trades_scan_function_using gamma = input_list[8] arb_thresh = input_list[9] arb_fees = input_list[10] - trade = input_list[11] - do_arb = input_list[12] - lp_supply = input_list[13] + if do_trades: + trade = input_list[11] + do_arb = input_list[12] + lp_supply = input_list[13] + else: + trade = None + do_arb = input_list[11] + lp_supply = input_list[12] fees_are_being_charged = gamma != 1.0 protocol_fee_amount_step = jnp.zeros_like(prev_reserves) @@ -831,6 +836,8 @@ def _jax_calc_quantAMM_reserves_with_dynamic_inputs( arb_fees = jnp.where( arb_fees.size == 1, jnp.full(weights.shape[0], arb_fees), arb_fees ) + if do_trades and trades is None: + raise ValueError("Trades must be provided when do_trades=True.") if lp_supply_array is None: lp_supply_array = jnp.array(1.0) @@ -904,25 +911,22 @@ def _jax_calc_quantAMM_reserves_with_dynamic_inputs( ] # carry_list_init = [initial_weights, initial_i] # nojit_scan = jax.disable_jit()(jax.lax.scan) - carry_list_end, reserves = scan( - scan_fn, - carry_list_init, - [ - weights, - prices, - active_initial_weights, - per_asset_ratios, - all_other_assets_ratios, - lagged_active_initial_weights, - lagged_per_asset_ratios, - lagged_all_other_assets_ratios, - gamma, - arb_thresh, - arb_fees, - trades, - do_arb, - lp_supply_array, - ], - ) + scan_inputs = [ + weights, + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + lagged_active_initial_weights, + lagged_per_asset_ratios, + lagged_all_other_assets_ratios, + gamma, + arb_thresh, + arb_fees, + ] + if do_trades: + scan_inputs.append(trades) + scan_inputs.extend([do_arb, lp_supply_array]) + carry_list_end, reserves = scan(scan_fn, carry_list_init, scan_inputs) return reserves diff --git a/quantammsim/pools/reCLAMM/reclamm.py b/quantammsim/pools/reCLAMM/reclamm.py index c73b6c4..da36710 100644 --- a/quantammsim/pools/reCLAMM/reclamm.py +++ b/quantammsim/pools/reCLAMM/reclamm.py @@ -16,6 +16,7 @@ from typing import Dict, Any, Optional, NamedTuple import numpy as np +from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs from quantammsim.pools.base_pool import AbstractPool from quantammsim.pools.reCLAMM.reclamm_reserves import ( initialise_reclamm_reserves, @@ -287,23 +288,17 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( LP fee revenue per timestep in USD. """ s = self._init_pool_state(params, run_fingerprint, prices, start_index) - fees_array = dynamic_inputs.fees - arb_thresh_array = dynamic_inputs.gas_cost - arb_fees_array = dynamic_inputs.arb_fees - bout_length = run_fingerprint["bout_length"] max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] - - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=max_len, + do_trades=False, + dtype=s.arb_prices.dtype, ) return _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( @@ -312,9 +307,9 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( s.centeredness_margin, s.daily_price_shift_base, s.seconds_per_step, - fees=fees_array_broadcast, - arb_thresh=arb_thresh_array_broadcast, - arb_fees=arb_fees_array_broadcast, + fees=materialized_inputs.fees, + arb_thresh=materialized_inputs.gas_cost, + arb_fees=materialized_inputs.arb_fees, all_sig_variations=jnp.array( run_fingerprint["all_sig_variations"] ), @@ -370,23 +365,17 @@ def calculate_reserves_with_dynamic_inputs( additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: s = self._init_pool_state(params, run_fingerprint, prices, start_index) - fees_array = dynamic_inputs.fees - arb_thresh_array = dynamic_inputs.gas_cost - arb_fees_array = dynamic_inputs.arb_fees - bout_length = run_fingerprint["bout_length"] max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] - - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=max_len, + do_trades=False, + dtype=s.arb_prices.dtype, ) return _jax_calc_reclamm_reserves_with_dynamic_inputs( @@ -395,9 +384,9 @@ def calculate_reserves_with_dynamic_inputs( s.centeredness_margin, s.daily_price_shift_base, s.seconds_per_step, - fees=fees_array_broadcast, - arb_thresh=arb_thresh_array_broadcast, - arb_fees=arb_fees_array_broadcast, + fees=materialized_inputs.fees, + arb_thresh=materialized_inputs.gas_cost, + arb_fees=materialized_inputs.arb_fees, all_sig_variations=jnp.array( run_fingerprint["all_sig_variations"] ), diff --git a/quantammsim/runners/jax_runner_utils.py b/quantammsim/runners/jax_runner_utils.py index 64c916d..e9c74f9 100644 --- a/quantammsim/runners/jax_runner_utils.py +++ b/quantammsim/runners/jax_runner_utils.py @@ -1060,8 +1060,9 @@ def get_unique_tokens(run_fingerprint): >>> get_unique_tokens(fingerprint) ['BTC', 'DAI', 'ETH'] """ + subsidary_pools = run_fingerprint.get("subsidary_pools", []) all_tokens = [run_fingerprint["tokens"]] + [ - cprd["tokens"] for cprd in run_fingerprint["subsidary_pools"] + cprd["tokens"] for cprd in subsidary_pools ] all_tokens = [item for sublist in all_tokens for item in sublist] unique_tokens = list(set(all_tokens)) @@ -1228,10 +1229,10 @@ def _to_dynamic_input_arrays( arb_fees_array, lp_supply_array, ) -> DynamicInputArrays: - """Normalize optional numpy arrays into the fixed hot-path container.""" + """Normalize optional numpy arrays into the hot-path container.""" empty = empty_dynamic_input_arrays() return DynamicInputArrays( - trades=empty.trades if trades_array is None else jnp.asarray(trades_array, dtype=jnp.float64), + trades=None if trades_array is None else jnp.asarray(trades_array, dtype=jnp.float64), fees=empty.fees if fees_array is None else jnp.asarray(fees_array, dtype=jnp.float64), gas_cost=empty.gas_cost if gas_cost_array is None else jnp.asarray(gas_cost_array, dtype=jnp.float64), arb_fees=empty.arb_fees if arb_fees_array is None else jnp.asarray(arb_fees_array, dtype=jnp.float64), @@ -1244,7 +1245,7 @@ def prepare_dynamic_inputs( dynamic_input_frames: Optional[DynamicInputFrames] = None, do_test_period: bool = False, ): - """Convert optional pandas inputs into fixed-structure dynamic input bundles.""" + """Convert optional pandas inputs into dynamic input bundles.""" if dynamic_input_frames is None: dynamic_input_frames = DynamicInputFrames() @@ -1367,6 +1368,19 @@ def prepare_dynamic_inputs( if lp_supply_df is not None else None ) + + # Unit LP supply is the neutral case; keep it on the static hot path. + if lp_supply_array is not None and np.allclose(lp_supply_array, 1.0): + lp_supply_array = None + if not do_test_period or test_lp_supply_array is None or np.allclose(test_lp_supply_array, 1.0): + dynamic_input_flags["has_lp_supply"] = False + dynamic_input_flags["use_dynamic_inputs"] = any( + value for key, value in dynamic_input_flags.items() if key != "use_dynamic_inputs" + ) + + if do_test_period and test_lp_supply_array is not None and np.allclose(test_lp_supply_array, 1.0): + test_lp_supply_array = None + if do_test_period: return { "train_dynamic_inputs": _to_dynamic_input_arrays( train_period_trades, diff --git a/quantammsim/runners/jax_runners.py b/quantammsim/runners/jax_runners.py index 1a8fdf3..aad11bd 100644 --- a/quantammsim/runners/jax_runners.py +++ b/quantammsim/runners/jax_runners.py @@ -54,7 +54,7 @@ ) from quantammsim.core_simulator.dynamic_inputs import ( DynamicInputFrames, - resolve_dynamic_input_components, + materialize_dynamic_inputs, ) from quantammsim.core_simulator.windowing_utils import get_indices, filter_coarse_weights_by_data_indices @@ -160,7 +160,10 @@ def _build_scan_infrastructure( run_scan_chunk : callable ``@jit`` wrapped ``lax.scan(scan_body, carry, None, length=chunk_size)``. scan_body : callable - The raw scan body (for partial-chunk Python fallback). + The raw scan body. + run_scan_step : callable + ``@jit`` wrapped single-step execution used for remainder iterations so + partial chunks follow the same numerics as the full scan path. """ # Local aliases for closed-over constants _start_idx = start_idx @@ -310,7 +313,11 @@ def scan_body(carry, _): def _run_scan_chunk(carry): return lax.scan(scan_body, carry, None, length=chunk_size) - return _run_scan_chunk, scan_body + @jit + def _run_scan_step(carry): + return scan_body(carry, None) + + return _run_scan_chunk, scan_body, _run_scan_step def train_on_historic_data( @@ -806,7 +813,7 @@ def init_optimizer(params): ) if config_key in _scan_infra_cache: - _run_scan_chunk, scan_body = _scan_infra_cache[config_key] + _run_scan_chunk, scan_body, _run_scan_step = _scan_infra_cache[config_key] else: # Build scan-compatible update (prices as explicit arg, not closure) partial_step_no_prices = Partial( @@ -825,7 +832,7 @@ def init_optimizer(params): partial_step_no_prices, params_in_axes_dict, ) - _run_scan_chunk, scan_body = _build_scan_infrastructure( + _run_scan_chunk, scan_body, _run_scan_step = _build_scan_infrastructure( chunk_size, partial_step_no_prices=partial_step_no_prices, forward_nograd_continuous=partial_forward_pass_nograd_continuous, @@ -851,7 +858,7 @@ def init_optimizer(params): swa_freq=swa_freq, n_parameter_sets=n_parameter_sets, ) - _scan_infra_cache[config_key] = (_run_scan_chunk, scan_body) + _scan_infra_cache[config_key] = (_run_scan_chunk, scan_body, _run_scan_step) # ── Initialize carry (prices & nan_bank in carry, not closures) ── carry = { @@ -906,7 +913,7 @@ def init_optimizer(params): "params": {k: [] for k in carry["params"]}, } for _ in range(actual): - carry, step_out = scan_body(carry, None) + carry, step_out = _run_scan_step(carry) all_per_steps["objective"].append(step_out["objective"]) all_per_steps["train_metrics"].append(step_out["train_metrics"]) all_per_steps["test_metrics"].append(step_out["test_metrics"]) @@ -2219,11 +2226,16 @@ def do_run_on_historic_data_with_provided_coarse_weights( initial_weights, minimum_weight, params, + jnp.zeros_like(initial_weights), + jnp.ones_like(initial_weights), run_fingerprint["max_memory_days"], chunk_period, chunk_period, 1.0, False, + False, + False, + False, ) weights = _jax_fine_weights_from_actual_starts_and_diffs( @@ -2257,70 +2269,33 @@ def do_run_on_historic_data_with_provided_coarse_weights( # ) dynamic_input_flags = dynamic_inputs_dict["dynamic_input_flags"] dynamic_inputs = dynamic_inputs_dict["train_dynamic_inputs"] - resolved_dynamic_inputs = resolve_dynamic_input_components( - dynamic_inputs, - dynamic_input_flags, - static_dict, - ) - fees_array = resolved_dynamic_inputs["fees"] - arb_thresh_array = resolved_dynamic_inputs["gas_cost"] - arb_fees_array = resolved_dynamic_inputs["arb_fees"] - trade_array = resolved_dynamic_inputs["trades"] - lp_supply_array = resolved_dynamic_inputs["lp_supply"] - - # initial_pool_value = run_fingerprint["initial_pool_value"] - # initial_value_per_token = arb_acted_upon_weights[0] * initial_pool_value - # initial_reserves = initial_value_per_token / arb_acted_upon_local_prices[0] - initial_reserves = params["initial_reserves"] - - # any of fees_array, arb_thresh_array, arb_fees_array, trade_array, and lp_supply_array - # can be singletons, in which case we repeat them for the length of the bout. - - # Determine the maximum leading dimension max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] - - fees_array = fees_array[:max_len] - arb_thresh_array = arb_thresh_array[:max_len] - arb_fees_array = arb_fees_array[:max_len] - lp_supply_array = lp_supply_array[:max_len] - if trade_array is not None: - trade_array = trade_array[:max_len] - # Broadcast input arrays to match the maximum leading dimension. - # If they are singletons, this will just repeat them for the length of the bout. - # If they are arrays of length bout_length, this will cause no change. - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] - ) - lp_supply_array_broadcast = jnp.broadcast_to( - lp_supply_array, (max_len,) + lp_supply_array.shape[1:] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + dynamic_input_flags, + static_dict, + scan_len=max_len, + do_trades=run_fingerprint["do_trades"], + dtype=local_prices.dtype, ) - # if we are doing trades, the trades array must be of the same length as the other arrays - if run_fingerprint["do_trades"]: - assert trade_array.shape[0] == max_len protocol_fee_split = run_fingerprint.get("protocol_fee_split", 0.0) reserves = _jax_calc_quantAMM_reserves_with_dynamic_inputs( initial_reserves, weights, local_prices, - fees_array_broadcast, - arb_thresh_array_broadcast, - arb_fees_array_broadcast, + materialized_inputs.fees, + materialized_inputs.gas_cost, + materialized_inputs.arb_fees, jnp.array(static_dict["all_sig_variations"]), - None, + materialized_inputs.trades, run_fingerprint["do_trades"], run_fingerprint["do_arb"], run_fingerprint["noise_trader_ratio"], - lp_supply_array_broadcast, + materialized_inputs.lp_supply, protocol_fee_split=protocol_fee_split, ) From 42d07679efd15ae106048cfcb5520802680841e5 Mon Sep 17 00:00:00 2001 From: christian harrington Date: Wed, 4 Mar 2026 16:28:31 +0000 Subject: [PATCH 5/6] missing file commit --- tests/unit/test_jax_runner_utils.py | 55 +++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_jax_runner_utils.py b/tests/unit/test_jax_runner_utils.py index 70dbf82..620ecb8 100644 --- a/tests/unit/test_jax_runner_utils.py +++ b/tests/unit/test_jax_runner_utils.py @@ -257,12 +257,12 @@ class TestDynamicInputPreparation: """Tests for dynamic input container construction and normalization.""" def test_empty_dynamic_input_arrays_have_stable_shapes(self): - """The empty hot-path bundle should have canonical placeholder arrays.""" + """The empty hot-path bundle should use singleton fee-like placeholders only.""" from quantammsim.core_simulator.dynamic_inputs import empty_dynamic_input_arrays dynamic_inputs = empty_dynamic_input_arrays() - assert dynamic_inputs.trades.shape == (1, 3) + assert dynamic_inputs.trades is None assert dynamic_inputs.fees.shape == (1,) assert dynamic_inputs.gas_cost.shape == (1,) assert dynamic_inputs.arb_fees.shape == (1,) @@ -472,6 +472,57 @@ def test_resolve_dynamic_input_components_prefers_dynamic_values(self): np.testing.assert_allclose(np.asarray(resolved["arb_fees"]), np.array([0.0003])) np.testing.assert_allclose(np.asarray(resolved["lp_supply"]), np.array([1500.0])) + def test_materialize_dynamic_inputs_leaves_trades_optional(self): + """No-trade paths should not expand placeholder trades into the scan inputs.""" + from quantammsim.core_simulator.dynamic_inputs import ( + empty_dynamic_input_arrays, + materialize_dynamic_inputs, + ) + + materialized = materialize_dynamic_inputs( + empty_dynamic_input_arrays(), + { + "use_dynamic_inputs": True, + "has_trades": False, + "has_dynamic_fees": False, + "has_dynamic_gas_cost": False, + "has_dynamic_arb_fees": False, + "has_lp_supply": False, + }, + static_dict={"fees": 0.003, "gas_cost": 2.5, "arb_fees": 0.0001}, + scan_len=4, + do_trades=False, + ) + + assert materialized.trades is None + np.testing.assert_allclose(np.asarray(materialized.fees), np.full(4, 0.003)) + np.testing.assert_allclose(np.asarray(materialized.gas_cost), np.full(4, 2.5)) + np.testing.assert_allclose(np.asarray(materialized.arb_fees), np.full(4, 0.0001)) + np.testing.assert_allclose(np.asarray(materialized.lp_supply), np.ones(4)) + + def test_materialize_dynamic_inputs_requires_trades_when_enabled(self): + """Trade-enabled scans should fail fast if no trade path is available.""" + from quantammsim.core_simulator.dynamic_inputs import ( + empty_dynamic_input_arrays, + materialize_dynamic_inputs, + ) + + with pytest.raises(ValueError, match="Trades must be provided"): + materialize_dynamic_inputs( + empty_dynamic_input_arrays(), + { + "use_dynamic_inputs": True, + "has_trades": False, + "has_dynamic_fees": True, + "has_dynamic_gas_cost": False, + "has_dynamic_arb_fees": False, + "has_lp_supply": False, + }, + static_dict={"fees": 0.003, "gas_cost": 0.0, "arb_fees": 0.0}, + scan_len=2, + do_trades=True, + ) + class TestGetSigVariations: """Tests for get_sig_variations function.""" From 35719bbd168db179ad53f6b861ae5fc7e19f92a7 Mon Sep 17 00:00:00 2001 From: MatthewWilletts <16085750+MatthewWilletts@users.noreply.github.com> Date: Fri, 6 Mar 2026 00:46:13 +0000 Subject: [PATCH 6/6] fix: replace stale fees_array references with dynamic_inputs check The merge of dev into reclamm-phase-1 reintroduced a reference to the old fees_array/gas_cost_array/arb_fees_array/trades_array parameters in the fused reserves guard. These were replaced by the DynamicInputArrays container in the dynamic inputs refactor. Replace the stale check with `dynamic_inputs is None`, which is the correct guard under the new API. --- quantammsim/core_simulator/forward_pass.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/quantammsim/core_simulator/forward_pass.py b/quantammsim/core_simulator/forward_pass.py index 5d018fe..c1509e9 100644 --- a/quantammsim/core_simulator/forward_pass.py +++ b/quantammsim/core_simulator/forward_pass.py @@ -1011,10 +1011,7 @@ def forward_pass( and static_dict["arb_frequency"] == 1 and static_dict.get("turnover_penalty", 0.0) == 0.0 and static_dict.get("price_noise_sigma", 0.0) == 0.0 - and all( - ele is None - for ele in [fees_array, gas_cost_array, arb_fees_array, trades_array] - ) + and dynamic_inputs is None and 1440 % static_dict["chunk_period"] == 0 # chunk_period divides metric_period and not pool._rule_outputs_are_weights # only delta-based pools validated and static_dict["bout_length"] > 1440 * 2 # need ≥2 metric periods