diff --git a/hydradx/model/amm/amm.py b/hydradx/model/amm/amm.py index aa43f7a7..60f49ad9 100644 --- a/hydradx/model/amm/amm.py +++ b/hydradx/model/amm/amm.py @@ -12,10 +12,9 @@ def __init__(self, fee_function: Callable, name: str): self.tkn = None def assign(self, exchange, tkn=''): - copy_self = copy.deepcopy(self) - copy_self.exchange = exchange - copy_self.tkn = tkn - return copy_self + self.exchange = exchange + self.tkn = tkn + return self def compute(self, tkn: str = '', delta_tkn: float = 0) -> float: return self.fee_function( diff --git a/hydradx/model/amm/global_state.py b/hydradx/model/amm/global_state.py index 540cdc7e..4de47448 100644 --- a/hydradx/model/amm/global_state.py +++ b/hydradx/model/amm/global_state.py @@ -90,15 +90,14 @@ def copy(self): def archive(self): if self.archive_all and not self.save_data: - return {'state': self.copy()} + return self.copy() elif self.save_data: return { datastream: self.save_data[datastream](self) for datastream in self.save_data } else: - record_state = ArchiveState(self) - return record_state + return ArchiveState(self) def evolve(self): self.time_step += 1 @@ -140,7 +139,7 @@ def market_prices(self, shares: dict) -> dict: if isinstance(share_id, tuple): pool_id = share_id[0] tkn_id = share_id[1] - prices[share_id] = self.pools[pool_id].usd_price(tkn_id) + prices[share_id] = self.pools[pool_id].usd_price(self.pools[pool_id], tkn_id) return prices @@ -184,7 +183,7 @@ def cash_out(self, agent: Agent) -> float: } prices = self.market_prices(withdraw_holdings) - return self.value_assets(prices, withdraw_holdings) + return value_assets(prices, withdraw_holdings) def pool_val(self, pool: AMM): """ get the total value of all liquidity in the pool. """ @@ -197,7 +196,7 @@ def impermanent_loss(self, agent: Agent) -> float: return self.cash_out(agent) / self.deposit_val(agent) - 1 def deposit_val(self, agent: Agent) -> float: - return self.value_assets( + return value_assets( self.market_prices(agent.holdings), agent.initial_holdings ) @@ -235,7 +234,7 @@ class ArchiveState: def __init__(self, state: GlobalState): self.time_step = state.time_step self.external_market = {k: v for k, v in state.external_market.items()} - self.pools = {k: OmnipoolArchiveState(v) for (k, v) in state.pools.items()} + self.pools = {k: v.archive() for (k, v) in state.pools.items()} self.agents = {k: AgentArchiveState(v) for (k, v) in state.agents.items()} @@ -322,9 +321,8 @@ def transform(state: GlobalState) -> GlobalState: def historical_prices(price_list: list[dict[str: float]]) -> Callable: def transform(state: GlobalState) -> GlobalState: - new_prices = price_list[state.time_step] - for tkn in new_prices: - state.external_market[tkn] = new_prices[tkn] + for tkn in price_list[state.time_step]: + state.external_market[tkn] = price_list[state.time_step][tkn] return state return transform diff --git a/hydradx/model/amm/oracle.py b/hydradx/model/amm/oracle.py index d9dd5ce7..1eae8684 100644 --- a/hydradx/model/amm/oracle.py +++ b/hydradx/model/amm/oracle.py @@ -22,13 +22,14 @@ def __init__(self, first_block: Block = None, decay_factor: float = 0, sma_equiv else: raise ValueError('Either decay_factor or sma_equivalent_length must be specified') self.length = sma_equivalent_length or 2 / self.decay_factor - 1 - self.asset_list = [] if last_values is not None: + self.asset_list = list((last_values['liquidity']).keys()) self.liquidity = {k: v for (k, v) in last_values['liquidity'].items()} self.price = {k: v for (k, v) in last_values['price'].items()} self.volume_in = {k: v for (k, v) in last_values['volume_in'].items()} self.volume_out = {k: v for (k, v) in last_values['volume_out'].items()} elif first_block is not None: + self.asset_list = first_block.asset_list self.liquidity = first_block.liquidity self.price = first_block.price self.volume_in = first_block.volume_in @@ -59,3 +60,12 @@ def update(self, block: Block): (1 - self.decay_factor) * self.volume_out[tkn] + self.decay_factor * block.volume_out[tkn] ) if tkn in self.volume_out else block.volume_out[tkn] return self + + +class OracleArchiveState: + def __init__(self, oracle: Oracle): + self.liquidity = {tkn: oracle.liquidity[tkn] for tkn in oracle.asset_list} + self.price = {tkn: oracle.price[tkn] for tkn in oracle.asset_list} + self.volume_in = {tkn: oracle.volume_in[tkn] for tkn in oracle.asset_list} + self.volume_out = {tkn: oracle.volume_out[tkn] for tkn in oracle.asset_list} + self.age = oracle.age diff --git a/hydradx/model/amm/trade_strategies.py b/hydradx/model/amm/trade_strategies.py index 8f758d79..b595c4f0 100644 --- a/hydradx/model/amm/trade_strategies.py +++ b/hydradx/model/amm/trade_strategies.py @@ -3,7 +3,7 @@ from .global_state import GlobalState, swap, add_liquidity, external_market_trade, withdraw_all_liquidity from .agents import Agent from .basilisk_amm import ConstantProductPoolState -from .omnipool_amm import OmnipoolState, usd_price +from .omnipool_amm import OmnipoolState from . import omnipool_amm as oamm from .stableswap_amm import StableSwapPoolState from typing import Callable @@ -34,7 +34,7 @@ def combo_function(state, agent_id) -> GlobalState: new_state = self.execute(state, agent_id) return other.execute(new_state, agent_id) - return TradeStrategy(combo_function, name='\n + '.join([self.name, other.name])) + return TradeStrategy(combo_function, name='\n'.join([self.name, other.name])) def random_swaps(