diff --git a/src/BaseWrapper.sol b/src/BaseWrapper.sol index 342a7b1..ff86328 100644 --- a/src/BaseWrapper.sol +++ b/src/BaseWrapper.sol @@ -68,7 +68,7 @@ abstract contract BaseWrapper is IERC7399 { /// @dev Handle the common parts of bridging the callback from legacy to ERC7399. Transfer the funds to the loan /// receiver. Call the callback supplied by the original borrower. Approve the repayment if necessary. If there is /// any result, it is kept in a storage variable to be retrieved on `flash` after the legacy flash loan is finished. - function bridgeToCallback(address asset, uint256 amount, uint256 fee, bytes memory params) internal { + function _bridgeToCallback(address asset, uint256 amount, uint256 fee, bytes memory params) internal { Data memory data = abi.decode(params, (Data)); _transferAssets(asset, amount, data.loanReceiver); diff --git a/src/aave/AaveWrapper.sol b/src/aave/AaveWrapper.sol index 2d6e071..be4c17c 100644 --- a/src/aave/AaveWrapper.sol +++ b/src/aave/AaveWrapper.sol @@ -12,7 +12,6 @@ import { FixedPointMathLib } from "lib/solmate/src/utils/FixedPointMathLib.sol"; import { BaseWrapper, IERC7399, ERC20 } from "../BaseWrapper.sol"; - /// @dev Aave Flash Lender that uses the Aave Pool as source of liquidity. /// Aave doesn't allow flow splitting or pushing repayments, so this wrapper is completely vanilla. contract AaveWrapper is BaseWrapper, IFlashLoanSimpleReceiver { @@ -28,28 +27,15 @@ contract AaveWrapper is BaseWrapper, IFlashLoanSimpleReceiver { } /// @inheritdoc IERC7399 - function maxFlashLoan(address asset) external view returns (uint256 max) { - DataTypes.ReserveData memory reserve = POOL.getReserveData(asset); - DataTypes.ReserveConfigurationMap memory configuration = reserve.configuration; - - max = !configuration.getPaused() && configuration.getActive() && configuration.getFlashLoanEnabled() - ? ERC20(asset).balanceOf(reserve.aTokenAddress) - : 0; + function maxFlashLoan(address asset) external view returns (uint256) { + return _maxFlashLoan(asset); } /// @inheritdoc IERC7399 - function flashFee(address, uint256 amount) external view returns (uint256) { - return amount.mulWadUp(POOL.FLASHLOAN_PREMIUM_TOTAL() * 0.0001e18); - } - - function _flashLoan(address asset, uint256 amount, bytes memory data) internal override { - POOL.flashLoanSimple({ - receiverAddress: address(this), - asset: asset, - amount: amount, - params: data, - referralCode: 0 - }); + function flashFee(address asset, uint256 amount) external view returns (uint256) { + uint256 max = _maxFlashLoan(asset); + require(max > 0, "Unsupported currency"); + return amount >= max ? type(uint256).max : _flashFee(amount); } /// @inheritdoc IFlashLoanSimpleReceiver @@ -67,8 +53,31 @@ contract AaveWrapper is BaseWrapper, IFlashLoanSimpleReceiver { require(msg.sender == address(POOL), "AaveFlashLoanProvider: not pool"); require(initiator == address(this), "AaveFlashLoanProvider: not initiator"); - bridgeToCallback(asset, amount, fee, params); + _bridgeToCallback(asset, amount, fee, params); return true; } + + function _flashLoan(address asset, uint256 amount, bytes memory data) internal override { + POOL.flashLoanSimple({ + receiverAddress: address(this), + asset: asset, + amount: amount, + params: data, + referralCode: 0 + }); + } + + function _maxFlashLoan(address asset) internal view returns (uint256 max) { + DataTypes.ReserveData memory reserve = POOL.getReserveData(asset); + DataTypes.ReserveConfigurationMap memory configuration = reserve.configuration; + + max = !configuration.getPaused() && configuration.getActive() && configuration.getFlashLoanEnabled() + ? ERC20(asset).balanceOf(reserve.aTokenAddress) + : 0; + } + + function _flashFee(uint256 amount) internal view returns (uint256) { + return amount.mulWadUp(POOL.FLASHLOAN_PREMIUM_TOTAL() * 0.0001e18); + } } diff --git a/src/balancer/BalancerWrapper.sol b/src/balancer/BalancerWrapper.sol index 3ea0e95..125a65e 100644 --- a/src/balancer/BalancerWrapper.sol +++ b/src/balancer/BalancerWrapper.sol @@ -11,7 +11,6 @@ import { FixedPointMathLib } from "lib/solmate/src/utils/FixedPointMathLib.sol"; import { BaseWrapper, IERC7399, ERC20 } from "../BaseWrapper.sol"; - /// @dev Balancer Flash Lender that uses Balancer Pools as source of liquidity. /// Balancer allows pushing repayments, so we override `_repayTo`. contract BalancerWrapper is BaseWrapper, IFlashLoanRecipient { @@ -29,17 +28,14 @@ contract BalancerWrapper is BaseWrapper, IFlashLoanRecipient { /// @inheritdoc IERC7399 function maxFlashLoan(address asset) external view returns (uint256) { - return ERC20(asset).balanceOf(address(balancer)); + return _maxFlashLoan(asset); } /// @inheritdoc IERC7399 - function flashFee(address, uint256 amount) external view returns (uint256) { - return amount.mulWadUp(balancer.getProtocolFeesCollector().getFlashLoanFeePercentage()); - } - - function _flashLoan(address asset, uint256 amount, bytes memory data) internal override { - flashLoanDataHash = keccak256(data); - balancer.flashLoan(this, asset.toArray(), amount.toArray(), data); + function flashFee(address asset, uint256 amount) external view returns (uint256) { + uint256 max = _maxFlashLoan(asset); + require(max > 0, "Unsupported currency"); + return amount >= max ? type(uint256).max : _flashFee(amount); } /// @inheritdoc IFlashLoanRecipient @@ -56,10 +52,23 @@ contract BalancerWrapper is BaseWrapper, IFlashLoanRecipient { require(keccak256(params) == flashLoanDataHash, "BalancerWrapper: params hash mismatch"); delete flashLoanDataHash; - bridgeToCallback(assets[0], amounts[0], fees[0], params); + _bridgeToCallback(assets[0], amounts[0], fees[0], params); + } + + function _flashLoan(address asset, uint256 amount, bytes memory data) internal override { + flashLoanDataHash = keccak256(data); + balancer.flashLoan(this, asset.toArray(), amount.toArray(), data); } function _repayTo() internal view override returns (address) { return address(balancer); } + + function _flashFee(uint256 amount) internal view returns (uint256) { + return amount.mulWadUp(balancer.getProtocolFeesCollector().getFlashLoanFeePercentage()); + } + + function _maxFlashLoan(address asset) internal view returns (uint256) { + return ERC20(asset).balanceOf(address(balancer)); + } } diff --git a/src/erc3156/ERC3156Wrapper.sol b/src/erc3156/ERC3156Wrapper.sol index 0704f0c..5d52796 100644 --- a/src/erc3156/ERC3156Wrapper.sol +++ b/src/erc3156/ERC3156Wrapper.sol @@ -30,23 +30,14 @@ contract ERC3156Wrapper is BaseWrapper, IERC3156FlashBorrower { /// @inheritdoc IERC7399 function maxFlashLoan(address asset) external view returns (uint256) { IERC3156FlashLender lender = lenders[asset]; - require(address(lender) != address(0), "Unsupported currency"); - return lender.maxFlashLoan(asset); + return address(lender) != address(0) ? _maxFlashLoan(lender, asset) : 0; } /// @inheritdoc IERC7399 function flashFee(address asset, uint256 amount) external view returns (uint256) { IERC3156FlashLender lender = lenders[asset]; require(address(lender) != address(0), "Unsupported currency"); - return lender.flashFee(asset, amount); - } - - function _flashLoan(address asset, uint256 amount, bytes memory data) internal override { - IERC3156FlashLender lender = lenders[asset]; - require(address(lender) != address(0), "Unsupported currency"); - - // We get funds from an ERC3156 lender to serve the ERC7399 flash loan in our ERC3156 callback - lender.flashLoan(this, address(asset), amount, data); + return amount >= _maxFlashLoan(lender, asset) ? type(uint256).max : _flashFee(lender, asset, amount); } /// @inheritdoc IERC3156FlashBorrower @@ -63,8 +54,24 @@ contract ERC3156Wrapper is BaseWrapper, IERC3156FlashBorrower { require(erc3156initiator == address(this), "External loan initiator"); require(msg.sender == address(lenders[asset]), "Unknown lender"); - bridgeToCallback(asset, amount, fee, params); + _bridgeToCallback(asset, amount, fee, params); return CALLBACK_SUCCESS; } + + function _flashLoan(address asset, uint256 amount, bytes memory data) internal override { + IERC3156FlashLender lender = lenders[asset]; + require(address(lender) != address(0), "Unsupported currency"); + + // We get funds from an ERC3156 lender to serve the ERC7399 flash loan in our ERC3156 callback + lender.flashLoan(this, address(asset), amount, data); + } + + function _maxFlashLoan(IERC3156FlashLender lender, address asset) internal view returns (uint256) { + return lender.maxFlashLoan(asset); + } + + function _flashFee(IERC3156FlashLender lender, address asset, uint256 amount) internal view returns (uint256) { + return lender.flashFee(asset, amount); + } } diff --git a/src/uniswapV3/UniswapV3Wrapper.sol b/src/uniswapV3/UniswapV3Wrapper.sol index c33a1ad..37a93e1 100644 --- a/src/uniswapV3/UniswapV3Wrapper.sol +++ b/src/uniswapV3/UniswapV3Wrapper.sol @@ -57,41 +57,15 @@ contract UniswapV3Wrapper is BaseWrapper, IUniswapV3FlashCallback { } /// @inheritdoc IERC7399 - function maxFlashLoan(address asset) external view returns (uint256 max) { - // Try a stable pair first - IUniswapV3Pool pool = _pool(asset, asset == usdc ? usdt : usdc, 0.0001e6); - if (address(pool) != address(0)) { - max = pool.balance(asset); - } - - uint16[3] memory fees = [0.0005e6, 0.003e6, 0.01e6]; - address assetOther = asset == weth ? usdc : weth; - for (uint256 i = 0; i < 3; i++) { - pool = _pool(asset, assetOther, fees[i]); - uint256 _balance = pool.balance(asset); - if (address(pool) != address(0) && _balance > max) { - max = _balance; - } - } + function maxFlashLoan(address asset) external view returns (uint256) { + return _maxFlashLoan(asset); } /// @inheritdoc IERC7399 function flashFee(address asset, uint256 amount) external view returns (uint256) { - IUniswapV3Pool pool = cheapestPool(asset, amount); - require(address(pool) != address(0), "Unsupported currency"); - return amount * uint256(pool.fee()) / 1e6; - } - - function _flashLoan(address asset, uint256 amount, bytes memory data) internal override { - IUniswapV3Pool pool = cheapestPool(asset, amount); - require(address(pool) != address(0), "Unsupported currency"); - - address asset0 = address(pool.token0()); - address asset1 = address(pool.token1()); - uint256 amount0 = asset == asset0 ? amount : 0; - uint256 amount1 = asset == asset1 ? amount : 0; - - pool.flash(address(this), amount0, amount1, abi.encode(asset0, asset1, pool.fee(), amount, data)); + uint256 max = _maxFlashLoan(asset); + require(max > 0, "Unsupported currency"); + return amount >= max ? type(uint256).max : _flashFee(asset, amount); } /// @inheritdoc IUniswapV3FlashCallback @@ -108,7 +82,19 @@ contract UniswapV3Wrapper is BaseWrapper, IUniswapV3FlashCallback { require(msg.sender == address(_pool(asset, other, feeTier)), "UniswapV3Wrapper: Unknown pool"); uint256 fee = fee0 > 0 ? fee0 : fee1; - bridgeToCallback(asset, amount, fee, data); + _bridgeToCallback(asset, amount, fee, data); + } + + function _flashLoan(address asset, uint256 amount, bytes memory data) internal override { + IUniswapV3Pool pool = cheapestPool(asset, amount); + require(address(pool) != address(0), "Unsupported currency"); + + address asset0 = address(pool.token0()); + address asset1 = address(pool.token1()); + uint256 amount0 = asset == asset0 ? amount : 0; + uint256 amount1 = asset == asset1 ? amount : 0; + + pool.flash(address(this), amount0, amount1, abi.encode(asset0, asset1, pool.fee(), amount, data)); } function _repayTo() internal view override returns (address) { @@ -119,6 +105,30 @@ contract UniswapV3Wrapper is BaseWrapper, IUniswapV3FlashCallback { PoolAddress.PoolKey memory poolKey = PoolAddress.getPoolKey(address(asset), address(other), fee); pool = IUniswapV3Pool(factory.computeAddress(poolKey)); } + + function _maxFlashLoan(address asset) internal view returns (uint256 max) { + // Try a stable pair first + IUniswapV3Pool pool = _pool(asset, asset == usdc ? usdt : usdc, 0.0001e6); + if (address(pool) != address(0)) { + max = pool.balance(asset); + } + + uint16[3] memory fees = [0.0005e6, 0.003e6, 0.01e6]; + address assetOther = asset == weth ? usdc : weth; + for (uint256 i = 0; i < 3; i++) { + pool = _pool(asset, assetOther, fees[i]); + uint256 _balance = pool.balance(asset); + if (address(pool) != address(0) && _balance > max) { + max = _balance; + } + } + } + + function _flashFee(address asset, uint256 amount) internal view returns (uint256) { + IUniswapV3Pool pool = cheapestPool(asset, amount); + require(address(pool) != address(0), "Unsupported currency"); + return amount * uint256(pool.fee()) / 1e6; + } } function canLoan(IUniswapV3Pool pool, address asset, uint256 amount) view returns (bool) { diff --git a/test/BaseWrapper.t.sol b/test/BaseWrapper.t.sol index 0281364..c681d7c 100644 --- a/test/BaseWrapper.t.sol +++ b/test/BaseWrapper.t.sol @@ -111,7 +111,7 @@ contract FooWrapper is BaseWrapper { } function flashLoanCallback(address asset, bytes memory params) external virtual { - bridgeToCallback(asset, ERC20(asset).balanceOf(address(this)), 0, params); + _bridgeToCallback(asset, ERC20(asset).balanceOf(address(this)), 0, params); ERC20(asset).transfer(msg.sender, ERC20(asset).balanceOf(address(this))); }