Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions quantammsim/core_simulator/dynamic_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
from dataclasses import dataclass
from typing import Any, NamedTuple, Optional

import jax.numpy as jnp


@dataclass(frozen=True)
class DynamicInputFrames:
"""Outer-layer container for optional pandas-backed dynamic inputs."""

trades: Optional[Any] = None
fees: Optional[Any] = None
gas_cost: Optional[Any] = None
arb_fees: Optional[Any] = None
lp_supply: Optional[Any] = None


class DynamicInputArrays(NamedTuple):
"""JAX pytree for dynamic simulation inputs with optional trade data."""

trades: Optional[jnp.ndarray]
fees: jnp.ndarray
gas_cost: jnp.ndarray
arb_fees: jnp.ndarray
lp_supply: jnp.ndarray


def default_dynamic_input_flags() -> dict:
"""Static dispatch flags for forward-pass path selection."""
return {
"use_dynamic_inputs": False,
"has_trades": False,
"has_dynamic_fees": False,
"has_dynamic_gas_cost": False,
"has_dynamic_arb_fees": False,
"has_lp_supply": False,
}


def dynamic_input_flags_from_frames(dynamic_input_frames: Optional[DynamicInputFrames]) -> dict:
"""Build stable dispatch flags from the outer-layer frame container."""
if dynamic_input_frames is None:
return default_dynamic_input_flags()

flags = {
"use_dynamic_inputs": False,
"has_trades": dynamic_input_frames.trades is not None,
"has_dynamic_fees": dynamic_input_frames.fees is not None,
"has_dynamic_gas_cost": dynamic_input_frames.gas_cost is not None,
"has_dynamic_arb_fees": dynamic_input_frames.arb_fees is not None,
"has_lp_supply": dynamic_input_frames.lp_supply is not None,
}
flags["use_dynamic_inputs"] = any(flags.values())
return flags


def resolve_dynamic_input_flags(
dynamic_inputs: Optional[DynamicInputArrays],
dynamic_input_flags: Optional[dict] = None,
) -> dict:
"""Return a safe dispatch flag set for the provided hot-path bundle."""
flags = (
default_dynamic_input_flags()
if dynamic_input_flags is None
else dict(dynamic_input_flags)
)
if dynamic_inputs is not None:
flags["use_dynamic_inputs"] = True
return flags


def empty_dynamic_input_arrays() -> DynamicInputArrays:
"""Create a canonical empty bundle."""
return DynamicInputArrays(
trades=None,
fees=jnp.zeros((1,), dtype=jnp.float64),
gas_cost=jnp.zeros((1,), dtype=jnp.float64),
arb_fees=jnp.zeros((1,), dtype=jnp.float64),
lp_supply=jnp.ones((1,), dtype=jnp.float64),
)


def resolve_dynamic_input_components(
dynamic_inputs: Optional[DynamicInputArrays],
dynamic_input_flags: dict,
static_dict: dict,
) -> dict:
"""Resolve dynamic-input leaves against static scalar defaults."""
arrays = empty_dynamic_input_arrays() if dynamic_inputs is None else dynamic_inputs
return {
"trades": arrays.trades if dynamic_input_flags["has_trades"] else None,
"fees": (
arrays.fees
if dynamic_input_flags["has_dynamic_fees"]
else jnp.asarray([static_dict["fees"]], dtype=jnp.float64)
),
"gas_cost": (
arrays.gas_cost
if dynamic_input_flags["has_dynamic_gas_cost"]
else jnp.asarray([static_dict["gas_cost"]], dtype=jnp.float64)
),
"arb_fees": (
arrays.arb_fees
if dynamic_input_flags["has_dynamic_arb_fees"]
else jnp.asarray([static_dict["arb_fees"]], dtype=jnp.float64)
),
"lp_supply": (
arrays.lp_supply
if dynamic_input_flags["has_lp_supply"]
else jnp.ones((1,), dtype=jnp.float64)
),
}


def _broadcast_dynamic_input_leaf(
input_name: str,
values: jnp.ndarray,
scan_len: int,
dtype,
) -> jnp.ndarray:
"""Broadcast a singleton dynamic-input leaf to the scan length."""
values = jnp.asarray(values, dtype=dtype)
if values.ndim == 0:
values = values.reshape((1,))
if values.shape[0] == scan_len:
return values
if values.shape[0] == 1:
return jnp.broadcast_to(values, (scan_len,) + values.shape[1:])
raise ValueError(
f"{input_name} has leading axis {values.shape[0]}, expected 1 or {scan_len}"
)


def materialize_dynamic_inputs(
dynamic_inputs: Optional[DynamicInputArrays],
dynamic_input_flags: Optional[dict],
static_dict: dict,
scan_len: int,
do_trades: bool,
dtype=jnp.float64,
) -> DynamicInputArrays:
"""Resolve and broadcast dynamic inputs for a specific scan length."""
if dynamic_input_flags is None and dynamic_inputs is not None:
flags = {
"use_dynamic_inputs": True,
"has_trades": do_trades,
"has_dynamic_fees": True,
"has_dynamic_gas_cost": True,
"has_dynamic_arb_fees": True,
"has_lp_supply": True,
}
else:
flags = resolve_dynamic_input_flags(dynamic_inputs, dynamic_input_flags)

resolved = resolve_dynamic_input_components(dynamic_inputs, flags, static_dict)

trades = None
if do_trades:
if resolved["trades"] is None:
raise ValueError("Trades must be provided when do_trades=True.")
trades = _broadcast_dynamic_input_leaf(
"trades", resolved["trades"], scan_len, dtype
)

return DynamicInputArrays(
trades=trades,
fees=_broadcast_dynamic_input_leaf("fees", resolved["fees"], scan_len, dtype),
gas_cost=_broadcast_dynamic_input_leaf(
"gas_cost", resolved["gas_cost"], scan_len, dtype
),
arb_fees=_broadcast_dynamic_input_leaf(
"arb_fees", resolved["arb_fees"], scan_len, dtype
),
lp_supply=_broadcast_dynamic_input_leaf(
"lp_supply", resolved["lp_supply"], scan_len, dtype
),
)
112 changes: 49 additions & 63 deletions quantammsim/core_simulator/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,25 @@
import numpy as np

from functools import partial
from quantammsim.core_simulator.dynamic_inputs import (
DynamicInputArrays,
default_dynamic_input_flags,
resolve_dynamic_input_flags,
)

np.seterr(all="raise")
np.seterr(under="print")


def _resolve_dynamic_inputs(dynamic_inputs, static_dict):
"""Return the incoming bundle plus static dispatch flags."""
dynamic_input_flags = resolve_dynamic_input_flags(
dynamic_inputs,
static_dict.get("dynamic_input_flags"),
)
return dynamic_inputs, dynamic_input_flags


def _apply_price_noise(prices, sigma, seed_int):
"""Apply multiplicative log-normal noise to prices.

Expand Down Expand Up @@ -839,15 +853,12 @@ def _calculate_return_value(
return return_metrics[return_val]()


@partial(jit, static_argnums=(7, 8))
@partial(jit, static_argnums=(4, 5))
def forward_pass(
params,
start_index,
prices,
trades_array=None,
fees_array=None,
gas_cost_array=None,
arb_fees_array=None,
dynamic_inputs=None,
pool=None,
static_dict=None,
):
Expand All @@ -870,17 +881,8 @@ def forward_pass(
prices : array-like
A 2D array of market prices for the assets involved in the simulation.

trades_array : array-like, optional
An array of trades to be considered in the simulation. Defaults to None.

fees_array : array-like, optional
An array of fees to be applied during the simulation. Defaults to None.

gas_cost_array : array-like, optional
An array of gas costs to be considered in the simulation. Defaults to None.

arb_fees_array : array-like, optional
An array of arbitrage fees to be applied during the simulation. Defaults to None.
dynamic_inputs : DynamicInputArrays, optional
Fixed-structure bundle of dynamic trades/fees/gas/arb/LP arrays.

pool : object
An instance of a pool object that provides methods
Expand Down Expand Up @@ -930,8 +932,8 @@ def forward_pass(
- The function handles different cases for fees and trades,
adjusting the calculation method accordingly:

1. If any of `fees_array`, `gas_cost_array`, `arb_fees_array`,
or `trades_array` is provided, it uses `pool.calculate_reserves_with_dynamic_inputs`.
1. If any dynamic-input flags are enabled, it uses
`pool.calculate_reserves_with_dynamic_inputs`.

2. If any of `fees`, `gas_cost`, or `arb_fees` in `static_dict` is a nonzero scalar value,
it uses `pool.calculate_reserves_with_fees`.
Expand Down Expand Up @@ -972,6 +974,7 @@ def forward_pass(
"training_data_kind": "historic",
"arb_frequency": 1,
"do_trades": False,
"dynamic_input_flags": default_dynamic_input_flags(),
}

# 'pool' has default of None only to handle how partial function
Expand Down Expand Up @@ -1008,10 +1011,7 @@ def forward_pass(
and static_dict["arb_frequency"] == 1
and static_dict.get("turnover_penalty", 0.0) == 0.0
and static_dict.get("price_noise_sigma", 0.0) == 0.0
and all(
ele is None
for ele in [fees_array, gas_cost_array, arb_fees_array, trades_array]
)
and dynamic_inputs is None
and 1440 % static_dict["chunk_period"] == 0 # chunk_period divides metric_period
and not pool._rule_outputs_are_weights # only delta-based pools validated
and static_dict["bout_length"] > 1440 * 2 # need ≥2 metric periods
Expand All @@ -1031,39 +1031,28 @@ def forward_pass(
# 1. Any of Fees, gas costs, and arb fees are provided as arrays, or trades are provided
# 2. Any of Fees, gas costs, and arb fees are nonzero scalar values, with no trades provided
# 3. Fees, gas costs, and arb fees are all zero, with no trades provided
dynamic_inputs, dynamic_input_flags = _resolve_dynamic_inputs(
dynamic_inputs, static_dict
)

fee_revenue = None
if any(
ele is not None
for ele in [fees_array, gas_cost_array, arb_fees_array, trades_array]
):
# Case 1, at least one of fees, gas costs, or arb fees is not None
if fees_array is None:
fees_array = jnp.array([static_dict["fees"]])
if gas_cost_array is None:
gas_cost_array = jnp.array([static_dict["gas_cost"]])
if arb_fees_array is None:
arb_fees_array = jnp.array([static_dict["arb_fees"]])
if dynamic_input_flags["use_dynamic_inputs"]:
# Case 1, at least one dynamic input is enabled
if hasattr(pool, "calculate_reserves_and_fee_revenue_with_dynamic_inputs"):
reserves, fee_revenue = pool.calculate_reserves_and_fee_revenue_with_dynamic_inputs(
params,
static_dict,
prices,
start_index,
fees_array=fees_array,
arb_thresh_array=gas_cost_array,
arb_fees_array=arb_fees_array,
trade_array=trades_array,
dynamic_inputs=dynamic_inputs,
)
else:
reserves = pool.calculate_reserves_with_dynamic_inputs(
params,
static_dict,
prices,
start_index,
fees_array=fees_array,
arb_thresh_array=gas_cost_array,
arb_fees_array=arb_fees_array,
trade_array=trades_array,
dynamic_inputs=dynamic_inputs,
)
elif True in (
ele > 0.0
Expand Down Expand Up @@ -1170,15 +1159,12 @@ def forward_pass(
return base_metric


@partial(jit, static_argnums=(7, 8))
@partial(jit, static_argnums=(4, 5))
def forward_pass_nograd(
params,
start_index,
prices,
trades_array=None,
fees_array=None,
gas_cost_array=None,
arb_fees_array=None,
dynamic_inputs=None,
pool=None,
static_dict=None,
):
Expand All @@ -1203,17 +1189,8 @@ def forward_pass_nograd(
prices : array-like
A 2D array of market prices for the assets involved in the simulation.

trades_array : array-like, optional
An array of trades to be considered in the simulation. Defaults to None.

fees_array : array-like, optional
An array of fees to be applied during the simulation. Defaults to None.

gas_cost_array : array-like, optional
An array of gas costs to be considered in the simulation. Defaults to None.

arb_fees_array : array-like, optional
An array of arbitrage fees to be applied during the simulation. Defaults to None.
dynamic_inputs : DynamicInputArrays, optional
Fixed-structure bundle of dynamic trades/fees/gas/arb/LP arrays.

pool : object
An instance of a pool object that provides methods
Expand Down Expand Up @@ -1263,8 +1240,8 @@ def forward_pass_nograd(
- The function handles different cases for fees and trades,
adjusting the calculation method accordingly:

1. If any of `fees_array`, `gas_cost_array`, `arb_fees_array`,
or `trades_array` is provided, it uses `pool.calculate_reserves_with_dynamic_inputs`.
1. If any dynamic-input flags are enabled, it uses
`pool.calculate_reserves_with_dynamic_inputs`.

2. If any of `fees`, `gas_cost`, or `arb_fees` in `static_dict` is a nonzero scalar value,
it uses `pool.calculate_reserves_with_fees`.
Expand All @@ -1289,14 +1266,23 @@ def forward_pass_nograd(
params = {k: stop_gradient(v) for k, v in params.items()}
start_index = stop_gradient(start_index)
prices = stop_gradient(prices)
if dynamic_inputs is not None:
dynamic_inputs = DynamicInputArrays(
trades=(
None
if dynamic_inputs.trades is None
else stop_gradient(dynamic_inputs.trades)
),
fees=stop_gradient(dynamic_inputs.fees),
gas_cost=stop_gradient(dynamic_inputs.gas_cost),
arb_fees=stop_gradient(dynamic_inputs.arb_fees),
lp_supply=stop_gradient(dynamic_inputs.lp_supply),
)
return forward_pass(
params,
start_index,
prices,
trades_array,
fees_array,
gas_cost_array,
arb_fees_array,
dynamic_inputs,
pool,
static_dict,
)
Loading