From 1431339c0676f2cd2b19f6efba7a0ed0147b345e Mon Sep 17 00:00:00 2001 From: poliwop Date: Tue, 10 Sep 2024 16:26:39 -0500 Subject: [PATCH] Fixed some tests --- hydradx/model/amm/omnipool_amm.py | 49 ++++++++++------------------ hydradx/tests/test_omnipool_amm.py | 9 ++--- hydradx/tests/test_omnipool_state.py | 2 +- 3 files changed, 23 insertions(+), 37 deletions(-) diff --git a/hydradx/model/amm/omnipool_amm.py b/hydradx/model/amm/omnipool_amm.py index 2ae68cdf..007416fa 100644 --- a/hydradx/model/amm/omnipool_amm.py +++ b/hydradx/model/amm/omnipool_amm.py @@ -1184,7 +1184,7 @@ def remove_liquidity(self, agent: Agent, quantity: float = None, tkn_remove: str elif quantity is not None and agent.nfts[nft_id].shares < quantity: return self.fail_transaction('Agent does not have enough shares in specified position.', agent) - if self.remove_liquidity_volatility_threshold: + if self.remove_liquidity_volatility_threshold and self.remove_liquidity_volatility_threshold < float('inf'): if self.oracles['price']: volatility = abs( self.oracles['price'].price[tkn_remove] / self.current_block.price[tkn_remove] - 1 @@ -1577,43 +1577,28 @@ def value_assets(prices: dict, assets: dict) -> float: ]) +def _turn_off_validations(omnipool: OmnipoolState) -> OmnipoolState: + new_state = omnipool.copy() + new_state.remove_liquidity_volatility_threshold = float('inf') + new_state.max_withdrawal_per_block = float('inf') + return new_state + + def cash_out_omnipool(omnipool: OmnipoolState, agent: Agent, prices) -> float: """ return the value of the agent's holdings if they withdraw all liquidity and then sell at current spot prices """ - if 'LRNA' not in agent.holdings: - agent.holdings['LRNA'] = 0 - agent_holdings = {tkn: agent.holdings[tkn] for tkn in list(agent.holdings.keys())} - liquidity_removed = {tkn: 0 for tkn in omnipool.asset_list} - lrna_removed = {tkn: 0 for tkn in omnipool.asset_list} - - for key in agent.holdings.keys(): - if isinstance(key, tuple): - tkn = key[1] - del agent_holdings[key] - - if agent.holdings[key] > 0: - val = omnipool.calculate_remove_liquidity(agent, tkn_remove=tkn) - delta_qa, delta_r, delta_q, delta_s, delta_b, delta_l = val[:6] - agent_holdings['LRNA'] += delta_qa - if tkn not in agent_holdings: - agent_holdings[tkn] = 0 - agent_holdings[tkn] -= delta_r - liquidity_removed[tkn] -= delta_r - lrna_removed[tkn] -= delta_q - for nft_id in agent.nfts: - nft = agent.nfts[nft_id] - if isinstance(nft, OmnipoolLiquidityPosition): - val = omnipool.calculate_remove_liquidity(agent, nft_id=nft_id) - delta_qa, delta_r, delta_q, delta_s, delta_b, delta_l = val[:6] - agent_holdings['LRNA'] += delta_qa - if nft.tkn not in agent_holdings: - agent_holdings[nft.tkn] = 0 - agent_holdings[nft.tkn] -= delta_r - liquidity_removed[nft.tkn] -= delta_r - lrna_removed[nft.tkn] -= delta_q + new_state, new_agent = _turn_off_validations(omnipool), agent.copy() + if 'LRNA' not in new_agent.holdings: + new_agent.holdings['LRNA'] = 0 + for tkn in omnipool.asset_list: + new_state, new_agent = simulate_remove_liquidity(new_state, new_agent, tkn_remove=tkn) + + agent_holdings = new_agent.holdings + lrna_removed = {tkn: omnipool.lrna[tkn] - new_state.lrna[tkn] for tkn in omnipool.lrna} + liquidity_removed = {tkn: omnipool.liquidity[tkn] - new_state.liquidity[tkn] for tkn in omnipool.liquidity} if 'LRNA' in prices: raise ValueError('LRNA price should not be given.') diff --git a/hydradx/tests/test_omnipool_amm.py b/hydradx/tests/test_omnipool_amm.py index 0eee5912..570d7d30 100644 --- a/hydradx/tests/test_omnipool_amm.py +++ b/hydradx/tests/test_omnipool_amm.py @@ -3121,14 +3121,15 @@ def test_cash_out_multiple_positions_works_with_lrna(price1: float, price2: floa tkn = 'DOT' amt1 = r * initial_state.shares[tkn] / 5 amt2 = initial_state.shares[tkn] / 5 - amt1 - holdings1 = {(initial_state.unique_id, tkn): amt1, (initial_state.unique_id + '_1', tkn): amt2} - prices1 = {(initial_state.unique_id, tkn): price1, (initial_state.unique_id + '_1', tkn): price2} - agent = Agent(holdings=holdings1, share_prices=prices1) + holdings1 = {(initial_state.unique_id, tkn): amt1} + prices1 = {(initial_state.unique_id, tkn): price1} + nft = OmnipoolLiquidityPosition(tkn, price2, amt2, 0, initial_state.unique_id) + agent = Agent(holdings=holdings1, share_prices=prices1, nfts={'pos1': nft}) spot_prices = {tkn: initial_state.price(initial_state, tkn, 'USD') for tkn in initial_state.asset_list} cash_out = cash_out_omnipool(initial_state, agent, spot_prices) state = initial_state.copy() - state.remove_all_liquidity(agent, tkn) + state.remove_liquidity(agent, tkn_remove=tkn) dot_value = agent.holdings['DOT'] * spot_prices['DOT'] lrna_value = agent.holdings['LRNA'] * initial_state.price(initial_state, 'LRNA', 'USD') assert dot_value < cash_out < dot_value + lrna_value # cash_out will be less than dot + lrna due to slippage diff --git a/hydradx/tests/test_omnipool_state.py b/hydradx/tests/test_omnipool_state.py index aed0aa52..40b1c8e2 100644 --- a/hydradx/tests/test_omnipool_state.py +++ b/hydradx/tests/test_omnipool_state.py @@ -311,7 +311,7 @@ def test_cash_out_accuracy(omnipool: oamm.OmnipoolState, share_price_ratio, lp_i ) lrna_profits[tkn] = withdraw_agent.holdings[tkn] - agent_holdings - del withdraw_agent.holdings['LRNA'] + del withdraw_agent.holdings['LRNA'] cash_count = sum([market_prices[tkn] * withdraw_agent.holdings[tkn] for tkn in withdraw_agent.holdings]) if cash_count != pytest.approx(cash_out, rel=1e-15): raise AssertionError('Cash out calculation is not accurate.')