diff --git a/contracts/mocks/ERC20Rebasing.vy b/contracts/mocks/ERC20Rebasing.vy index 14e9579c..0b24b47d 100644 --- a/contracts/mocks/ERC20Rebasing.vy +++ b/contracts/mocks/ERC20Rebasing.vy @@ -63,9 +63,12 @@ def allowance(_owner: address, _spender: address) -> uint256: @external def transfer(_to: address, _value: uint256) -> bool: - self._rebase() _shares: uint256 = self._get_shares_by_coins(_value) + if _shares > 0: + # only rebase on nonzero transfers + self._rebase() + self.shares[msg.sender] -= _shares self.shares[_to] += _shares log Transfer(msg.sender, _to, _value) @@ -74,9 +77,13 @@ def transfer(_to: address, _value: uint256) -> bool: @external def transferFrom(_from: address, _to: address, _value: uint256) -> bool: - self._rebase() _shares: uint256 = self._get_shares_by_coins(_value) _shares = min(self.shares[_from], _shares) + + if _shares > 0: + # only rebase on nonzero transfers + self._rebase() + # Value can be less than expected even if self.shares[_from] > _shares _new_value: uint256 = self._get_coins_by_shares(_shares) @@ -99,20 +106,20 @@ def approve(_spender: address, _value: uint256) -> bool: @view def _share_price() -> uint256: if self.totalShares == 0: - return 10 ** self.decimals - return self.totalCoin * 10 ** self.decimals / self.totalShares + return 10**self.decimals + return self.totalCoin * 10**self.decimals / self.totalShares @internal @view def _get_coins_by_shares(_shares: uint256) -> uint256: - return _shares * self._share_price() / 10 ** self.decimals + return _shares * self._share_price() / 10**self.decimals @internal @view def _get_shares_by_coins(_coins: uint256) -> uint256: - return _coins * 10 ** self.decimals / self._share_price() + return _coins * 10**self.decimals / self._share_price() @external