diff --git a/src/driftpy/clearing_house_user.py b/src/driftpy/clearing_house_user.py index e545dbdf..5b16af25 100644 --- a/src/driftpy/clearing_house_user.py +++ b/src/driftpy/clearing_house_user.py @@ -54,18 +54,18 @@ async def get_user_account(self): async def get_user_positions_account(self) -> UserPositions: user_account = self.get_user_account() - positions = cast( + positions_account = cast( UserPositions, await self.clearing_house.program.account["UserPositions"].fetch( user_account.positions ), ) - return positions + return positions_account async def get_user_position(self, market_index) -> MarketPosition: - positions = await self.get_user_positions_account() - for position in positions: + positions_account = await self.get_user_positions_account() + for position in positions_account.positions: if position.market_index == market_index: return position return MarketPosition( @@ -73,10 +73,10 @@ async def get_user_position(self, market_index) -> MarketPosition: ) async def get_unrealised_pnl(self, market_index=None): - positions = await self.get_user_positions_account() + positions_account = await self.get_user_positions_account() pnl = 0 - for position in positions: + for position in positions_account.positions: if market_index is not None and position.market_index == market_index: market = self.clearing_house.get_market( position.market_index @@ -90,9 +90,9 @@ async def get_total_collateral(self): return collateral + self.get_unrealised_pnl() async def get_total_position_value(self): - positions = await self.get_user_positions_account() + positions_account = await self.get_user_positions_account() value = 0 - for position in positions: + for position in positions_account.positions: market = self.clearing_house.get_market( position.market_index ) # todo repeat querying @@ -101,9 +101,9 @@ async def get_total_position_value(self): return value async def get_position_value(self, market_index=None): - positions = await self.get_user_positions_account() + positions_account = await self.get_user_positions_account() value = 0 - for position in positions: + for position in positions_account.positions: if market_index is not None and position.market_index == market_index: market = self.clearing_house.get_market( position.market_index diff --git a/src/driftpy/math/amm.py b/src/driftpy/math/amm.py index 2d435c66..1397ea0d 100644 --- a/src/driftpy/math/amm.py +++ b/src/driftpy/math/amm.py @@ -1,5 +1,16 @@ from driftpy.constants.numeric_constants import MARK_PRICE_PRECISION, PEG_PRECISION -from driftpy.types import SwapDirection, AssetType, PositionDirection +from driftpy.types import PositionDirection +from sumtypes import constructor # type: ignore + + +class SwapDirection: + ADD = constructor() + REMOVE = constructor() + + +class AssetType: + QUOTE = constructor() + BASE = constructor() def calculate_price(base_asset_amount, quote_asset_amount, peg_multiplier): @@ -9,14 +20,19 @@ def calculate_price(base_asset_amount, quote_asset_amount, peg_multiplier): return (quote_asset_amount * peg_multiplier / PEG_PRECISION) / base_asset_amount -def calculate_swap_output(input_asset_reserve, swap_amount, invariant): - new_input_asset_reserve = input_asset_reserve + swap_amount +def calculate_swap_output( + input_asset_reserve, swap_amount, swap_direction: SwapDirection, invariant +): + if swap_direction == SwapDirection.ADD: + new_input_asset_reserve = input_asset_reserve + swap_amount + else: + new_input_asset_reserve = input_asset_reserve - swap_amount new_output_asset_reserve = invariant / new_input_asset_reserve return [new_input_asset_reserve, new_output_asset_reserve] def calculate_amm_reserves_after_swap( - amm, input_asset_type: AssetType, swap_amount, swap_direction: PositionDirection + amm, input_asset_type: AssetType, swap_amount, swap_direction: SwapDirection ): if input_asset_type == AssetType.QUOTE: @@ -31,6 +47,7 @@ def calculate_amm_reserves_after_swap( [new_quote_asset_reserve, new_base_asset_reserve] = calculate_swap_output( amm.quote_asset_reserve / MARK_PRICE_PRECISION, swap_amount, + swap_direction, (amm.sqrt_k / MARK_PRICE_PRECISION) ** 2, ) @@ -41,6 +58,7 @@ def calculate_amm_reserves_after_swap( [new_base_asset_reserve, new_quote_asset_reserve] = calculate_swap_output( amm.base_asset_reserve / MARK_PRICE_PRECISION, swap_amount, + swap_direction, (amm.sqrt_k / MARK_PRICE_PRECISION) ** 2, ) diff --git a/src/driftpy/math/positions.py b/src/driftpy/math/positions.py index f9f2625b..debf480e 100644 --- a/src/driftpy/math/positions.py +++ b/src/driftpy/math/positions.py @@ -27,6 +27,8 @@ AMM_TIMES_PEG_TO_QUOTE_PRECISION_RATIO, ) +from driftpy.math.amm import AssetType + def calculate_base_asset_value(market: Market, user_position: MarketPosition) -> int: @@ -41,9 +43,9 @@ def calculate_base_asset_value(market: Market, user_position: MarketPosition) -> new_quote_asset_reserve, _ = calculate_amm_reserves_after_swap( market.amm, - "base", + AssetType.BASE, abs(user_position.base_asset_amount), - get_swap_direction("base", direction_to_close), + get_swap_direction(AssetType.BASE, direction_to_close), ) result = None @@ -65,7 +67,7 @@ def calculate_base_asset_value(market: Market, user_position: MarketPosition) -> def calculate_position_pnl( market: Market, market_position: MarketPosition, with_funding=False ): - pnl = 0 + pnl = 0.0 if market_position.base_asset_amount == 0: return pnl @@ -79,13 +81,13 @@ def calculate_position_pnl( if with_funding: funding_rate_pnl = 0.0 - pnl += funding_rate_pnl / PRICE_TO_QUOTE_PRECISION + pnl += funding_rate_pnl / float(PRICE_TO_QUOTE_PRECISION) return pnl def calculate_position_funding_pnl(market: Market, market_position: MarketPosition): - funding_pnl = 0 + funding_pnl = 0.0 if market_position.base_asset_amount == 0: return funding_pnl @@ -100,7 +102,7 @@ def calculate_position_funding_pnl(market: Market, market_position: MarketPositi market_position.last_cumulative_funding_rate - amm_cum_funding_rate ) * market_position.base_asset_amount - funding_pnl /= int(AMM_RESERVE_PRECISION * FUNDING_PRECISION) + funding_pnl /= float(AMM_RESERVE_PRECISION * FUNDING_PRECISION) return funding_pnl diff --git a/src/driftpy/math/trade.py b/src/driftpy/math/trade.py index b9cd61a0..4e809ff7 100644 --- a/src/driftpy/math/trade.py +++ b/src/driftpy/math/trade.py @@ -2,7 +2,7 @@ from driftpy.math.amm import calculate_price, calculate_amm_reserves_after_swap from driftpy.math.market import calculate_mark_price -from driftpy.constants import ( +from driftpy.constants.numeric_constants import ( MARK_PRICE_PRECISION, PEG_PRECISION, FUNDING_PRECISION,