Skip to content

Commit

Permalink
add tests for erc4626 tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
bout3fiddy committed Oct 12, 2023
1 parent fa9731a commit d23f6e0
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 32 deletions.
12 changes: 4 additions & 8 deletions contracts/main/CurveStableSwapMetaNG.vy
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,14 @@ N_COINS: public(constant(uint256)) = 2
N_COINS_128: constant(int128) = 2
PRECISION: constant(uint256) = 10 ** 18

POOL_IS_REBASING_IMPLEMENTATION: public(immutable(bool))

BASE_POOL: public(immutable(address))
BASE_N_COINS: public(immutable(uint256))
BASE_COINS: public(immutable(DynArray[address, MAX_COINS]))

math: immutable(Math)
factory: immutable(Factory)
coins: public(immutable(DynArray[address, MAX_COINS]))
asset_types: public(immutable(DynArray[uint8, MAX_COINS]))
asset_types: immutable(DynArray[uint8, MAX_COINS])
stored_balances: DynArray[uint256, MAX_COINS]

# Fee specific vars
Expand Down Expand Up @@ -316,8 +314,6 @@ def __init__(
asset_types = _asset_types
rate_multipliers = _rate_multipliers

POOL_IS_REBASING_IMPLEMENTATION = 2 in _asset_types

for i in range(MAX_COINS):
if i < BASE_N_COINS:
# Approval needed for add_liquidity operation on base pool in
Expand Down Expand Up @@ -539,7 +535,7 @@ def _balances() -> DynArray[uint256, MAX_COINS]:

for i in range(N_COINS_128):

if POOL_IS_REBASING_IMPLEMENTATION:
if 2 in asset_types:
balances_i = ERC20(coins[i]).balanceOf(self) - self.admin_balances[i]
else:
balances_i = self.stored_balances[i] - self.admin_balances[i]
Expand Down Expand Up @@ -604,7 +600,7 @@ def exchange_received(
@param _min_dy Minimum amount of `j` to receive
@return Actual amount of `j` received
"""
assert not POOL_IS_REBASING_IMPLEMENTATION # dev: exchange_received not supported if pool contains rebasing tokens
assert not 2 in asset_types # dev: exchange_received not supported if pool contains rebasing tokens
return self._exchange(
msg.sender,
i,
Expand Down Expand Up @@ -664,7 +660,7 @@ def exchange_underlying_received(
@param _receiver Address that receives `j`
@return Actual amount of `j` received
"""
assert not POOL_IS_REBASING_IMPLEMENTATION # dev: exchange_received not supported if pool contains rebasing tokens
assert not 2 in asset_types # dev: exchange_received not supported if pool contains rebasing tokens
return self._exchange_underlying(
msg.sender,
i,
Expand Down
11 changes: 4 additions & 7 deletions contracts/main/CurveStableSwapNG.vy
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# @version 0.3.10
#pragma optimize gas
#pragma optimize codesize
"""
@title CurveStableSwapNG
@author Curve.Fi
Expand Down Expand Up @@ -139,11 +139,9 @@ N_COINS: public(immutable(uint256))
N_COINS_128: immutable(int128)
PRECISION: constant(uint256) = 10 ** 18

POOL_IS_REBASING_IMPLEMENTATION: public(immutable(bool))

factory: immutable(Factory)
coins: public(immutable(DynArray[address, MAX_COINS]))
asset_types: public(immutable(DynArray[uint8, MAX_COINS]))
asset_types: immutable(DynArray[uint8, MAX_COINS])
stored_balances: DynArray[uint256, MAX_COINS]

# Fee specific vars
Expand Down Expand Up @@ -264,7 +262,6 @@ def __init__(
N_COINS_128 = convert(__n_coins, int128)

rate_multipliers = _rate_multipliers
POOL_IS_REBASING_IMPLEMENTATION = 2 in _asset_types

factory = Factory(msg.sender)

Expand Down Expand Up @@ -469,7 +466,7 @@ def _balances() -> DynArray[uint256, MAX_COINS]:
if i == N_COINS_128:
break

if POOL_IS_REBASING_IMPLEMENTATION:
if 2 in asset_types:
balances_i = ERC20(coins[i]).balanceOf(self) - self.admin_balances[i]
else:
balances_i = self.stored_balances[i] - self.admin_balances[i]
Expand Down Expand Up @@ -534,7 +531,7 @@ def exchange_received(
@param _min_dy Minimum amount of `j` to receive
@return Actual amount of `j` received
"""
assert not POOL_IS_REBASING_IMPLEMENTATION # dev: exchange_received not supported if pool contains rebasing tokens
assert not 2 in asset_types # dev: exchange_received not supported if pool contains rebasing tokens
return self._exchange(
msg.sender,
i,
Expand Down
148 changes: 148 additions & 0 deletions contracts/mocks/ERC20RebasingConditional.vy
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# @version ^0.3.9

"""
@notice Rebasing ERC20 mock with rebase by 1% on every transfer
@dev This is for testing only, it is NOT safe for use
@dev Based on stEth implementation
"""


event Transfer:
_from: indexed(address)
_to: indexed(address)
_value: uint256


event Approval:
_owner: indexed(address)
_spender: indexed(address)
_value: uint256


name: public(String[64])
symbol: public(String[32])
decimals: public(uint256)
allowances: HashMap[address, HashMap[address, uint256]]

# <--- Rebase Parameters --->
totalCoin: public(uint256)
totalShares: public(uint256)
shares: public(HashMap[address, uint256])
IS_UP: immutable(bool)

# asset type
asset_type: public(constant(uint8)) = 2


@external
def __init__(_name: String[64], _symbol: String[32], _decimals: uint256, is_up: bool):
self.name = _name
self.symbol = _symbol
self.decimals = _decimals
IS_UP = is_up


@external
@view
def totalSupply() -> uint256:
# Rebase is pegged to total pooled coin
return self.totalCoin


@external
@view
def balanceOf(_user: address) -> uint256:
return self._get_coins_by_shares(self.shares[_user])


@external
@view
def allowance(_owner: address, _spender: address) -> uint256:
return self.allowances[_owner][_spender]


@external
def transfer(_to: address, _value: uint256) -> bool:
_shares: uint256 = self._get_shares_by_coins(_value)

self.shares[msg.sender] -= _shares
self.shares[_to] += _shares
log Transfer(msg.sender, _to, _value)
return True


@external
def transferFrom(_from: address, _to: address, _value: uint256) -> bool:
_shares: uint256 = self._get_shares_by_coins(_value)
_shares = min(self.shares[_from], _shares)
# Value can be less than expected even if self.shares[_from] > _shares
_new_value: uint256 = self._get_coins_by_shares(_shares)

self.shares[_from] -= _shares
self.shares[_to] += _shares
self.allowances[_from][msg.sender] -= _new_value

log Transfer(_from, _to, _new_value)
return True


@external
def approve(_spender: address, _value: uint256) -> bool:
self.allowances[msg.sender][_spender] = _value
log Approval(msg.sender, _spender, _value)
return True


@internal
@view
def _share_price() -> uint256:
if self.totalShares == 0:
return 10 ** self.decimals
return self.totalCoin * 10 ** self.decimals / self.totalShares


@internal
@view
def _get_coins_by_shares(_shares: uint256) -> uint256:
return _shares * self._share_price() / 10 ** self.decimals


@internal
@view
def _get_shares_by_coins(_coins: uint256) -> uint256:
return _coins * 10 ** self.decimals / self._share_price()


@external
@view
def share_price() -> uint256:
return self._share_price()


@external
def rebase():
if IS_UP:
self.totalCoin = self.totalCoin * 1000001 / 1000000
else:
self.totalCoin = self.totalCoin * 999999 / 1000000


@external
def set_total_coin(total_coin: uint256) -> bool:
assert self.totalShares != 0, "no shares"

self.totalCoin = total_coin
return True


@external
def _mint_for_testing(_target: address, _value: uint256) -> bool:
_shares: uint256 = self._get_shares_by_coins(_value)

self.totalCoin += _value
self.totalShares += _shares
self.shares[_target] += _shares

log Transfer(empty(address), _target, _value)

return True
10 changes: 6 additions & 4 deletions contracts/mocks/ERC4626.vy
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# @version 0.3.10
# From: https://github.com/fubuloubu/ERC4626/blob/main/contracts/VyperVault.vy
from vyper.interfaces import ERC20
from vyper.interfaces import ERC20Detailed

import ERC4626 as ERC4626

Expand All @@ -16,6 +17,7 @@ allowance: public(HashMap[address, HashMap[address, uint256]])
NAME: immutable(String[10])
SYMBOL: immutable(String[5])
DECIMALS: immutable(uint8)
_DECIMALS_OFFSET: immutable(uint8)

event Transfer:
sender: indexed(address)
Expand Down Expand Up @@ -57,6 +59,8 @@ def __init__(
DECIMALS = _decimals
asset = _asset

_DECIMALS_OFFSET = _decimals - ERC20Detailed(_asset.address).decimals()


@view
@external
Expand Down Expand Up @@ -113,8 +117,7 @@ def _convertToAssets(shareAmount: uint256) -> uint256:
if totalSupply == 0:
return 0

# NOTE: `shareAmount = 0` is extremely rare case, not optimizing for it
return shareAmount * asset.balanceOf(self) / totalSupply
return shareAmount * asset.balanceOf(self) / (totalSupply)


@view
Expand All @@ -129,9 +132,8 @@ def _convertToShares(assetAmount: uint256) -> uint256:
totalSupply: uint256 = self.totalSupply
totalAssets: uint256 = asset.balanceOf(self)
if totalAssets == 0 or totalSupply == 0:
return assetAmount # 1:1 price
return assetAmount * 10**convert(_DECIMALS_OFFSET, uint256) # 1:1 price

# NOTE: `assetAmount = 0` is extremely rare case, not optimizing for it
return assetAmount * totalSupply / totalAssets


Expand Down
Loading

0 comments on commit d23f6e0

Please sign in to comment.