Skip to content

Commit

Permalink
refactored and corrected, working well
Browse files Browse the repository at this point in the history
  • Loading branch information
jepidoptera committed Jun 13, 2024
1 parent fad3154 commit 7702f41
Showing 1 changed file with 46 additions and 41 deletions.
87 changes: 46 additions & 41 deletions hydradx/model/amm/concentrated_liquidity_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def calculate_sell_from_buy(self, tkn_sell: str, tkn_buy: str, buy_quantity: flo
else:
sell_quantity = buy_quantity * x_virtual / (y_virtual - buy_quantity)

return sell_quantity / (1 - self.fee)
return sell_quantity * (1 + self.fee)

def get_virtual_reserves(self):
x_virtual = self.liquidity[self.asset_x] + self.x_offset
Expand All @@ -143,7 +143,7 @@ def price(self, tkn: str, denomination: str = '') -> float:
def buy_spot(self, tkn_buy: str, tkn_sell: str, fee: float = None):
if fee is None:
fee = self.fee
return self.price(tkn_buy) / (1 - fee)
return self.price(tkn_buy) * (1 + fee)

def sell_spot(self, tkn_sell: str, tkn_buy: str, fee: float = None):
if fee is None:
Expand Down Expand Up @@ -203,7 +203,6 @@ def __init__(
def initialize_tick(self, tick: int, liquidity_net: float):
if tick % self.tick_spacing != 0:
raise ValueError(f"Tick values must be multiples of the tick spacing ({self.tick_spacing}).")
max_tick = tick + self.tick_spacing
price = tick_to_price(tick)
new_tick = Tick(
liquidity_net=liquidity_net,
Expand All @@ -215,7 +214,7 @@ def initialize_tick(self, tick: int, liquidity_net: float):

@property
def current_tick(self):
return round(price_to_tick(self.sqrt_price ** 2))
return round(price_to_tick(self.sqrt_price ** 2, self.tick_spacing))

def next_initialized_tick(self, zero_for_one):
search_direction = -1 if zero_for_one else 1
Expand Down Expand Up @@ -256,7 +255,7 @@ def swap(
else:
sqrt_price_limit = sqrt(price_limit)

while amountSpecifiedRemaining != 0:
while abs(amountSpecifiedRemaining) > 1e-12:
next_tick: Tick = self.next_initialized_tick(zeroForOne)

# get the price for the next tick
Expand All @@ -265,14 +264,13 @@ def swap(
)

# compute values to swap to the target tick, price limit, or point where input/output amount is exhausted
sqrt_price_current, amountIn, amountOut, feeAmount = self.compute_swap_step(
self.sqrt_price, amountIn, amountOut, feeAmount = self.compute_swap_step(
self.sqrt_price,
sqrt_ratio_target=(
sqrt_price_limit if
(sqrt_price_next < sqrt_price_limit if zeroForOne else sqrt_price_next > sqrt_price_limit)
else sqrt_price_next
),
liquidity=self.liquidity,
amount_remaining=amountSpecifiedRemaining
)

Expand All @@ -294,27 +292,27 @@ def swap(
# state.feeGrowthGlobalX128 += feeAmount / self.liquidity

# shift tick if we reached the next price
if sqrt_price_current == sqrt_price_next:
if self.sqrt_price == sqrt_price_next:
# if the tick is initialized, run the tick transition
self.liquidity += next_tick.liquidityNet if zeroForOne else -next_tick.liquidityNet
self.sqrt_price = sqrt_price_next

# update the agent's holdings
if tkn_buy not in agent.holdings:
agent.holdings[tkn_buy] = 0
if exact_input:
agent.holdings[tkn_buy] += amountCalculated
agent.holdings[tkn_buy] -= amountCalculated
agent.holdings[tkn_sell] -= sell_quantity
else:
agent.holdings[tkn_buy] += buy_quantity
agent.holdings[tkn_sell] -= amountCalculated

return self


def compute_swap_step(
self,
sqrt_ratio_current: float,
sqrt_ratio_target: float,
liquidity: float,
amount_remaining: float
) -> tuple[float, float, float, float]: # sqrt_ratio_nex, amountIn, amountOut, feeAmount

Expand All @@ -325,7 +323,7 @@ def compute_swap_step(

if exactIn:
amountRemainingLessFee = amount_remaining * (1 - self.fee)
amountIn = (
amountIn = ( # calculate amount that it would take to reach our sqrt price target
self.getAmount0Delta(sqrt_ratio_target, sqrt_ratio_current)
if zeroForOne else
self.getAmount1Delta(sqrt_ratio_current, sqrt_ratio_target)
Expand All @@ -338,7 +336,7 @@ def compute_swap_step(
zero_for_one=zeroForOne
)
else:
amountOut = (
amountOut = ( # calculate amount that it would take to reach our sqrt price target
self.getAmount1Delta(sqrt_ratio_target, sqrt_ratio_current)
if zeroForOne else
self.getAmount0Delta(sqrt_ratio_current, sqrt_ratio_target)
Expand All @@ -356,20 +354,20 @@ def compute_swap_step(
# get the input/output amounts
if zeroForOne:
if not (is_max and exactIn):
amountIn = -amount_remaining # self.getAmount0Delta(sqrt_ratio_next, sqrt_ratio_current)
amountIn = self.getAmount0Delta(sqrt_ratio_next, sqrt_ratio_current)
if exactIn or not is_max:
amountOut = self.getAmount1Delta(sqrt_ratio_next, sqrt_ratio_current)
else:
if not(is_max and exactIn):
amountIn = self.getAmount1Delta(sqrt_ratio_current, sqrt_ratio_next)
if exactIn or not is_max:
amountOut = -amount_remaining # self.getAmount0Delta(sqrt_ratio_current, sqrt_ratio_next)
amountOut = self.getAmount0Delta(sqrt_ratio_current, sqrt_ratio_next)

# cap the output amount to not exceed the remaining output amount
if not exactIn and amountOut > -amount_remaining:
if not exactIn and amountOut > amount_remaining:
amountOut = -amount_remaining

if exactIn and sqrt_ratio_next != sqrt_ratio_target:
if exactIn and not is_max:
# we didn't reach the target, so take the remainder of the maximum input as fee
feeAmount = amount_remaining - amountIn
else:
Expand All @@ -396,9 +394,9 @@ def getNextSqrtPriceFromInput(
):
# // round to make sure that we don't pass the target price
return (
self.getNextSqrtPriceFromAmount0RoundingUp(amount=amount_in, add=True)
self.getNextSqrtPriceFromAmount0(amount=amount_in, add=True)
if zero_for_one else
self.getNextSqrtPriceFromAmount1RoundingDown(amount=amount_in, add=True)
self.getNextSqrtPriceFromAmount1(amount=amount_in, add=True)
)

# /// @notice Gets the next sqrt price given an output amount of token0 or token1
Expand All @@ -415,9 +413,9 @@ def getNextSqrtPriceFromOutput(
):
# round to make sure that we pass the target price
return (
self.getNextSqrtPriceFromAmount1RoundingDown(amount=amount_out, add=False)
self.getNextSqrtPriceFromAmount1(amount=amount_out, add=False)
if zero_for_one else
self.getNextSqrtPriceFromAmount0RoundingUp(amount=amount_out, add=False)
self.getNextSqrtPriceFromAmount0(amount=amount_out, add=False)
)

# Gets the next sqrt price given a delta of token0
Expand All @@ -426,23 +424,19 @@ def getNextSqrtPriceFromOutput(
# @param amount How much of token0 to add or remove from virtual reserves
# @param add Whether to add or remove the amount of token0
# @return The price after adding or removing amount, depending on add
def getNextSqrtPriceFromAmount0RoundingUp(
def getNextSqrtPriceFromAmount0(
self,
amount: float,
add: bool,
sqrt_ratio: float = None,
liquidity: float = None,
):
# we short circuit amount == 0 because the result is otherwise not guaranteed to equal the input price
sqrt_ratio = sqrt_ratio or self.sqrt_price
liquidity = liquidity or self.liquidity
if amount == 0:
return sqrt_ratio
# we short circuit amount == 0 because the result is otherwise not guaranteed to equal the input price
return self.sqrt_price
if add:
denominator = liquidity + amount * sqrt_ratio
denominator = self.liquidity + amount * self.sqrt_price
else:
denominator = liquidity - amount * sqrt_ratio
return liquidity * sqrt_ratio / denominator
denominator = self.liquidity - amount * self.sqrt_price
return self.liquidity * self.sqrt_price / denominator


# Gets the next sqrt price given a delta of token1
Expand All @@ -451,18 +445,29 @@ def getNextSqrtPriceFromAmount0RoundingUp(
# @param amount How much of token1 to add, or remove, from virtual reserves
# @param add Whether to add, or remove, the amount of token1
# @return The price after adding or removing `amount`
def getNextSqrtPriceFromAmount1RoundingDown(
def getNextSqrtPriceFromAmount1(
self,
amount: float,
add: bool,
sqrt_ratio: float = None,
liquidity: float = None,
add: bool
):
# if we're adding (subtracting), rounding down requires rounding the quotient down (up)
# in both cases, avoid a mulDiv for most inputs
sqrt_ratio = sqrt_ratio or self.sqrt_price
liquidity = liquidity or self.liquidity
if add:
return sqrt_ratio + amount / liquidity
return self.sqrt_price + amount / self.liquidity
else:
return self.sqrt_price - amount / self.liquidity


def buy_spot(self, tkn_buy: str, tkn_sell: str, fee: float = None):
if fee is None:
fee = self.fee
if tkn_buy == self.asset_list[0]:
return self.sqrt_price ** 2 * (1 + fee)
else:
return 1 / (self.sqrt_price ** 2) * (1 + fee)

def sell_spot(self, tkn_sell: str, tkn_buy: str, fee: float = None):
if fee is None:
fee = self.fee
if tkn_sell == self.asset_list[0]:
return self.sqrt_price ** 2 * (1 - fee)
else:
return sqrt_ratio - amount / liquidity
return 1 / (self.sqrt_price ** 2) * (1 - fee)

0 comments on commit 7702f41

Please sign in to comment.