Skip to content

Commit

Permalink
Merge pull request #3 from oraichain/fix/potential-overflow-math
Browse files Browse the repository at this point in the history
Fix/potential overflow math
  • Loading branch information
tubackkhoa authored Jul 16, 2024
2 parents 1d6b64b + 801a361 commit 549177d
Show file tree
Hide file tree
Showing 15 changed files with 122 additions and 196 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,7 @@ data/

# code analyzer
.scannerwork/
clippy-report.json
clippy-report.json

# code coverage
*.html
19 changes: 9 additions & 10 deletions contracts/oraiswap-v3/src/entrypoints/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,26 +103,25 @@ pub fn calculate_swap(

// make remaining amount smaller
if by_amount_in {
remaining_amount = remaining_amount
.checked_sub(result.amount_in + result.fee_amount)
.map_err(|_| ContractError::Sub)?;
remaining_amount =
remaining_amount.checked_sub(result.amount_in.checked_add(result.fee_amount)?)?;
} else {
remaining_amount = remaining_amount
.checked_sub(result.amount_out)
.map_err(|_| ContractError::Sub)?;
remaining_amount = remaining_amount.checked_sub(result.amount_out)?;
}

pool.add_fee(
result.fee_amount,
x_to_y,
state::CONFIG.load(store)?.protocol_fee,
)?;
event_fee_amount += result.fee_amount;
event_fee_amount = event_fee_amount.checked_add(result.fee_amount)?;

pool.sqrt_price = result.next_sqrt_price;

total_amount_in += result.amount_in + result.fee_amount;
total_amount_out += result.amount_out;
total_amount_in = total_amount_in
.checked_add(result.amount_in)?
.checked_add(result.fee_amount)?;
total_amount_out = total_amount_out.checked_add(result.amount_out)?;

// Fail if price would go over swap limit
if pool.sqrt_price == sqrt_price_limit && !remaining_amount.is_zero() {
Expand Down Expand Up @@ -155,7 +154,7 @@ pub fn calculate_swap(
)?;

remaining_amount = amount_after_tick_update;
total_amount_in += amount_to_add;
total_amount_in = total_amount_in.checked_add(amount_to_add)?;

if let UpdatePoolTick::TickInitialized(tick) = tick_update {
if has_crossed {
Expand Down
7 changes: 2 additions & 5 deletions contracts/oraiswap-v3/src/entrypoints/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,7 @@ pub fn create_position(
return Err(ContractError::InvalidTickIndex {});
}
let pool_key_db = pool_key.key();
let mut pool = POOLS
.load(deps.storage, &pool_key_db)
.map_err(|_| ContractError::PoolNotFound {})?;
let mut pool = POOLS.load(deps.storage, &pool_key_db)?;

let mut lower_tick = match state::get_tick(deps.storage, &pool_key, lower_tick) {
Ok(tick) => tick,
Expand Down Expand Up @@ -586,8 +584,7 @@ pub fn create_pool(
current_timestamp,
fee_tier.tick_spacing,
config.admin,
)
.map_err(|_| ContractError::CreatePoolError)?;
)?;

POOLS.save(deps.storage, &db_key, &pool)?;

Expand Down
16 changes: 13 additions & 3 deletions contracts/oraiswap-v3/src/entrypoints/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,23 @@ pub fn get_liquidity_ticks_amount(
let max_chunk = state::get_bitmap_item(deps.storage, max_chunk_index, &pool_key).unwrap_or(0);

let mut amount: u32 = 0;
amount += active_bits_in_range(min_chunk, min_bit, (CHUNK_SIZE - 1) as u8);
amount += active_bits_in_range(max_chunk, 0, max_bit);
amount = amount
.checked_add(active_bits_in_range(
min_chunk,
min_bit,
(CHUNK_SIZE - 1) as u8,
))
.ok_or(ContractError::Add)?;
amount = amount
.checked_add(active_bits_in_range(max_chunk, 0, max_bit))
.ok_or(ContractError::Add)?;

for i in (min_chunk_index + 1)..max_chunk_index {
let chunk = state::get_bitmap_item(deps.storage, i, &pool_key).unwrap_or(0);

amount += chunk.count_ones();
amount = amount
.checked_add(chunk.count_ones())
.ok_or(ContractError::Add)?;
}

Ok(amount)
Expand Down
68 changes: 17 additions & 51 deletions contracts/oraiswap-v3/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ pub enum ContractError {
#[error("{0}")]
FromUtf8(#[from] FromUtf8Error),

#[error("{0}")]
CheckMathOverUnderFlowError(String),

#[error("invalid tick spacing")]
InvalidTickSpacing,

Expand Down Expand Up @@ -53,24 +56,9 @@ pub enum ContractError {
#[error("subtraction underflow")]
Sub,

#[error("update_liquidity: liquidity + liquidity_delta overflow")]
UpdateLiquidityPlusOverflow,

#[error("update_liquidity: liquidity - liquidity_delta underflow")]
UpdateLiquidityMinusOverflow,

#[error("empty position pokes")]
EmptyPositionPokes,

#[error("position not found")]
PositionNotFound,

#[error("position add liquidity overflow")]
PositionAddLiquidityOverflow,

#[error("position remove liquidity underflow")]
PositionRemoveLiquidityUnderflow,

#[error("price limit reached")]
PriceLimitReached,

Expand All @@ -80,15 +68,6 @@ pub enum ContractError {
#[error("current_timestamp - pool.start_timestamp underflow")]
TimestampSubOverflow,

#[error("pool not found")]
PoolNotFound,

#[error("pool.liquidity + tick.liquidity_change overflow")]
PoolAddTickLiquidityOverflow,

#[error("pool.liquidity - tick.liquidity_change underflow")]
PoolSubTickLiquidityUnderflow,

#[error("tick limit reached")]
TickLimitReached,

Expand All @@ -98,12 +77,6 @@ pub enum ContractError {
#[error("tick already exist")]
TickAlreadyExist,

#[error("tick add liquidity overflow")]
TickAddLiquidityOverflow,

#[error("tick remove liquidity underflow")]
TickRemoveLiquidityUnderflow,

#[error("invalid tick liquidity")]
InvalidTickLiquidity,

Expand All @@ -128,12 +101,6 @@ pub enum ContractError {
#[error("calcaule_sqrt_price::checked_div division failed")]
CheckedDiv,

#[error("calculate_sqrt_price: parsing scale failed")]
ParseScale,

#[error("extending liquidity overflow")]
ExtendLiquidityOverflow,

#[error("big_liquidity -/+ sqrt_price * x")]
BigLiquidityOverflow,

Expand All @@ -149,15 +116,9 @@ pub enum ContractError {
#[error("Upper Sqrt Price < Current Sqrt Price")]
UpperSqrtPriceLess,

#[error("overflow in calculating liquidity")]
OverflowInCalculatingLiquidity,

#[error("Current Sqrt Price < Lower Sqrt Price")]
CurrentSqrtPriceLess,

#[error("overflow while casting to TokenAmount")]
OverflowCastingTokenAmount,

#[error("unauthorized")]
Unauthorized {},

Expand All @@ -173,21 +134,12 @@ pub enum ContractError {
#[error("no gain swap")]
NoGainSwap,

#[error("swap failed")]
SwapFailed,

#[error("amount under minimum amount out")]
AmountUnderMinimumAmountOut,

#[error("invalid pool key")]
InvalidPoolKey,

#[error("pool already exist")]
PoolAlreadyExist,

#[error("pool not created")]
CreatePoolError,

#[error("FeeTierNotFound")]
FeeTierNotFound,

Expand All @@ -200,3 +152,17 @@ impl From<ContractError> for StdError {
Self::generic_err(source.to_string())
}
}

// Implementing From<String> for ContractError
impl From<String> for ContractError {
fn from(error: String) -> Self {
ContractError::CheckMathOverUnderFlowError(error)
}
}

// Implementing From<&str> for ContractError
impl From<&str> for ContractError {
fn from(error: &str) -> Self {
ContractError::CheckMathOverUnderFlowError(error.to_string())
}
}
25 changes: 8 additions & 17 deletions contracts/oraiswap-v3/src/logic/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ pub fn get_liquidity_by_x_sqrt_price(
* U256::from(nominator.get())
* U256::from(Liquidity::from_integer(1).get())
/ U256::from(denominator.get()))
.try_into()
.map_err(|_| ContractError::OverflowInCalculatingLiquidity)?,
.try_into()?,
);
return Ok(SingleTokenLiquidity {
l: liquidity,
Expand All @@ -154,8 +153,7 @@ pub fn get_liquidity_by_x_sqrt_price(
* U256::from(nominator.get())
* U256::from(Liquidity::from_integer(1).get())
/ U256::from(denominator.get()))
.try_into()
.map_err(|_| ContractError::OverflowInCalculatingLiquidity)?,
.try_into()?,
);

let sqrt_price_diff = current_sqrt_price - lower_sqrt_price;
Expand Down Expand Up @@ -209,8 +207,7 @@ pub fn get_liquidity_by_y_sqrt_price(
* U256::from(SqrtPrice::from_integer(1).get())
* U256::from(Liquidity::from_integer(1).get())
/ U256::from(sqrt_price_diff.get()))
.try_into()
.map_err(|_| ContractError::OverflowInCalculatingLiquidity)?,
.try_into()?,
);
return Ok(SingleTokenLiquidity {
l: liquidity,
Expand All @@ -224,8 +221,7 @@ pub fn get_liquidity_by_y_sqrt_price(
* U256::from(SqrtPrice::from_integer(1).get())
* U256::from(Liquidity::from_integer(1).get())
/ U256::from(sqrt_price_diff.get()))
.try_into()
.map_err(|_| ContractError::OverflowInCalculatingLiquidity)?,
.try_into()?,
);
let denominator =
(current_sqrt_price.big_mul(upper_sqrt_price)).big_div(SqrtPrice::from_integer(1));
Expand All @@ -252,14 +248,11 @@ pub fn calculate_x(
TokenAmount::new(
((U256::from(common) + U256::from(Liquidity::from_integer(1).get()) - U256::from(1))
/ U256::from(Liquidity::from_integer(1).get()))
.try_into()
.map_err(|_| ContractError::OverflowCastingTokenAmount)?,
.try_into()?,
)
} else {
TokenAmount::new(
(U256::from(common) / U256::from(Liquidity::from_integer(1).get()))
.try_into()
.map_err(|_| ContractError::OverflowCastingTokenAmount)?,
(U256::from(common) / U256::from(Liquidity::from_integer(1).get())).try_into()?,
)
})
}
Expand All @@ -275,15 +268,13 @@ pub fn calculate_y(
(((U256::from(sqrt_price_diff.get()) * U256::from(shifted_liquidity))
+ U256::from(SqrtPrice::from_integer(1).get() - 1))
/ U256::from(SqrtPrice::from_integer(1).get()))
.try_into()
.map_err(|_| ContractError::OverflowCastingTokenAmount)?,
.try_into()?,
)
} else {
TokenAmount::new(
(U256::from(sqrt_price_diff.get()) * U256::from(shifted_liquidity)
/ U256::from(SqrtPrice::from_integer(1).get()))
.try_into()
.map_err(|_| ContractError::OverflowCastingTokenAmount)?,
.try_into()?,
)
})
}
Expand Down
Loading

0 comments on commit 549177d

Please sign in to comment.