Skip to content

Commit

Permalink
flash loan tested, there is a niggle in FlashBorrower.sol
Browse files Browse the repository at this point in the history
  • Loading branch information
alcueca committed Jul 25, 2023
1 parent 2dc933a commit a7c6315
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 64 deletions.
71 changes: 71 additions & 0 deletions FlashLender.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity >=0.8.19 <0.9.0;

import { PRBTest } from "@prb/test/PRBTest.sol";
import { console2 } from "forge-std/console2.sol";
import { StdCheats } from "forge-std/StdCheats.sol";

import { FlashLender } from "../src/FlashLender.sol";
import { FlashBorrower } from "../src/test/FlashBorrower.sol";
import { ERC20Mock } from "../src/test/ERC20Mock.sol";
import { IERC20 } from "../src/interfaces/IERC20.sol";


/// @dev If this is your first time with Forge, read this tutorial in the Foundry Book:
/// https://book.getfoundry.sh/forge/writing-tests
contract FlashLenderTest is PRBTest, StdCheats {
FlashLender internal lender;
FlashBorrower internal borrower;
IERC20 internal asset;
IERC20 internal otherAsset;

/// @dev A function invoked before each test case is run.
function setUp() public virtual {
// Instantiate the contract-under-test.

asset = IERC20(address(new ERC20Mock("Asset", "AST")));
otherAsset = IERC20(address(new ERC20Mock("Other", "OTH")));
IERC20[] memory supportedAssets = new IERC20[](2);
supportedAssets[0] = asset;
supportedAssets[1] = otherAsset;
lender = new FlashLender(supportedAssets, 10);
borrower = new FlashBorrower(lender);

asset.transfer(address(lender), 999e18); // Keeping 1e18 for the flash fee.
otherAsset.transfer(address(lender), 999e18); // Keeping 1e18 for the flash fee.
}

/// @dev Simple flash loan test.
function test_flashLoan() external {
console2.log("test_flashLoan");
uint256 lenderBalance = asset.balanceOf(address(lender));
uint256 loan = 1e18;
uint256 fee = lender.flashFee(asset, loan);
asset.transfer(address(borrower), fee);
borrower.flashBorrow(asset, loan);

assertEq(borrower.flashInitiator(), address(borrower));
assertEq(address(borrower.flashAsset()), address(asset));
assertEq(borrower.flashAmount(), loan);
assertEq(borrower.flashBalance(), loan + fee); // The amount we transferred to pay for fees, plus the amount we borrowed
assertEq(borrower.flashFee(), fee);
assertEq(asset.balanceOf(address(lender)), lenderBalance + fee);
}

function test_flashLoanAndReenter() external {
console2.log("test_flashLoanAndReenter");
uint256 lenderBalance = asset.balanceOf(address(lender));
uint256 firstLoan = 1e18;
uint256 secondLoan = 2e18;
uint256 fees = lender.flashFee(asset, firstLoan) + lender.flashFee(asset, secondLoan);
asset.transfer(address(borrower), fees);
borrower.flashBorrowAndReenter(asset, firstLoan);

assertEq(borrower.flashInitiator(), address(borrower));
assertEq(address(borrower.flashAsset()), address(asset));
assertEq(borrower.flashAmount(), firstLoan + secondLoan);
assertEq(borrower.flashBalance(), firstLoan + secondLoan + fees); // The amount we transferred to pay for fees, plus the amount we borrowed
assertEq(borrower.flashFee(), fees);
assertEq(asset.balanceOf(address(lender)), lenderBalance + fees);
}
}
12 changes: 5 additions & 7 deletions src/erc3156/ERC3156Wrapper.sol
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ contract ERC3156Wrapper is IERC3156PPFlashLender, IERC3156FlashBorrower {
using TransferHelper for IERC20;
using RevertMsgExtractor for bytes;

bytes32 constant CALLBACK_SUCCESS = keccak256("ERC3156FlashBorrower.onFlashLoan");
bytes32 public constant CALLBACK_SUCCESS = keccak256("ERC3156FlashBorrower.onFlashLoan");

mapping(IERC20 => IERC3156FlashLender) public lenders;
bytes internal _callbackResult;
Expand All @@ -43,7 +43,7 @@ contract ERC3156Wrapper is IERC3156PPFlashLender, IERC3156FlashBorrower {
*/
constructor(
IERC20[] memory assets_,
IERC20[] memory lenders_
IERC3156FlashLender[] memory lenders_
) {
require (assets_.length == lenders_.length, "Arrays must be the same length");
for (uint256 i = 0; i < assets_.length; i++) {
Expand Down Expand Up @@ -122,11 +122,9 @@ contract ERC3156Wrapper is IERC3156PPFlashLender, IERC3156FlashBorrower {
require(msg.sender == address(lenders[IERC20(token)]), "Unknown lender");
IERC3156FlashLender lender = IERC3156FlashLender(msg.sender);

// We pass the loan to the loan receiver
bytes memory result = _callFromData(IERC20(token), amount, fee, data);

// We store the callback result in storage for the the ERC3156++ flashLoan function to recover it.
_callbackResult = result;
// We pass the loan to the loan receiver and we store the callback result in storage for the the ERC3156++ flashLoan function to recover it.
_callbackResult = _callFromData(IERC20(token), amount, fee, data);
_callbackResult = abi.encode(CALLBACK_SUCCESS); // TODO: Hijacking the callback result to see where the OOG error comes from

IERC20(token).approve(address(lender), amount + fee);

Expand Down
86 changes: 86 additions & 0 deletions src/test/FlashBorrower.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

import { IERC3156PPFlashLender } from "lib/erc3156pp/src/interfaces/IERC3156PPFlashLender.sol";
import { IERC20 } from "lib/erc3156pp/src/interfaces/IERC20.sol";


contract LoanReceiver {
function retrieve(IERC20 asset) external {
asset.transfer(msg.sender, asset.balanceOf(address(this)));
}
}

contract FlashBorrower {
IERC3156PPFlashLender lender;
LoanReceiver loanReceiver;

uint256 public flashBalance;
address public flashInitiator;
IERC20 public flashAsset;
uint256 public flashAmount;
uint256 public flashFee;

constructor (IERC3156PPFlashLender lender_) {
lender = lender_;
loanReceiver = new LoanReceiver();
}

/// @dev ERC-3156++ Flash loan callback
function onFlashLoan(address initiator, address paymentReceiver, IERC20 asset, uint256 amount, uint256 fee, bytes calldata data) external returns(bytes memory) {
require(msg.sender == address(lender), "FlashBorrower: Untrusted lender");
require(initiator == address(this), "FlashBorrower: External loan initiator");
flashInitiator = initiator;
flashAsset = asset;
flashAmount = amount;
flashFee = fee;
loanReceiver.retrieve(asset);
flashBalance = IERC20(asset).balanceOf(address(this));
asset.transfer(paymentReceiver, amount + fee);

return ""; // abi.encode(data, paymentReceiver, fee); // TODO: Returning anything here causes a revert
}

function onSteal(address initiator, address paymentReceiver, IERC20 asset, uint256 amount, uint256 fee, bytes calldata data) external returns(bytes memory) {
require(msg.sender == address(lender), "FlashBorrower: Untrusted lender");
require(initiator == address(this), "FlashBorrower: External loan initiator");
flashInitiator = initiator;
flashAsset = asset;
flashAmount = amount;
flashFee = fee;

// do nothing

return abi.encode(data, paymentReceiver, fee);
}

function onReenter(address initiator, address paymentReceiver, IERC20 asset, uint256 amount, uint256 fee, bytes calldata data) external returns(bytes memory) {
require(msg.sender == address(lender), "FlashBorrower: Untrusted lender");
require(initiator == address(this), "FlashBorrower: External loan initiator");
flashInitiator = initiator;
flashAsset = asset;
loanReceiver.retrieve(asset);

flashBorrow(asset, amount * 2);

asset.transfer(paymentReceiver, amount + fee);

// flashBorrow will have initialized these
flashAmount += amount;
flashFee += fee;

return abi.encode(data, paymentReceiver, fee);
}

function flashBorrow(IERC20 asset, uint256 amount) public returns(bytes memory) {
return lender.flashLoan(address(loanReceiver), asset, amount, "", this.onFlashLoan);
}

function flashBorrowAndSteal(IERC20 asset, uint256 amount) public returns(bytes memory) {
return lender.flashLoan(address(loanReceiver), asset, amount, "", this.onSteal);
}

function flashBorrowAndReenter(IERC20 asset, uint256 amount) public returns(bytes memory) {
return lender.flashLoan(address(loanReceiver), asset, amount, "", this.onReenter);
}
}
66 changes: 66 additions & 0 deletions test/ERC3156Wrapper.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity >=0.8.19 <0.9.0;

import { PRBTest } from "@prb/test/PRBTest.sol";
import { console2 } from "forge-std/console2.sol";
import { StdCheats } from "forge-std/StdCheats.sol";
import { IERC3156FlashLender } from "lib/erc3156/contracts/interfaces/IERC3156FlashLender.sol";

import { FlashBorrower } from "../src/test/FlashBorrower.sol";
import { IERC20, ERC3156Wrapper } from "../src/erc3156/ERC3156Wrapper.sol";


/// @dev If this is your first time with Forge, read this tutorial in the Foundry Book:
/// https://book.getfoundry.sh/forge/writing-tests
contract ERC3156WrapperTest is PRBTest, StdCheats {
ERC3156Wrapper internal wrapper;
FlashBorrower internal borrower;
IERC20 internal dai;
IERC3156FlashLender internal makerFlash;

/// @dev A function invoked before each test case is run.
function setUp() public virtual {
// Revert if there is no API key.
string memory alchemyApiKey = vm.envOr("API_KEY_ALCHEMY", string(""));
if (bytes(alchemyApiKey).length == 0) {
revert("API_KEY_ALCHEMY variable missing");
}

vm.createSelectFork({ urlOrAlias: "mainnet", blockNumber: 16_428_000 });
makerFlash = IERC3156FlashLender(0x60744434d6339a6B27d73d9Eda62b6F66a0a04FA);
dai = IERC20(0x6B175474E89094C44Da98b954EedeAC495271d0F);

IERC20[] memory assets = new IERC20[](1);
assets[0] = dai;
IERC3156FlashLender[] memory lenders = new IERC3156FlashLender[](1);
lenders[0] = makerFlash;
wrapper = new ERC3156Wrapper(assets, lenders);
borrower = new FlashBorrower(wrapper);
}

/// @dev Basic test. Run it with `forge test -vvv` to see the console log.
function test_flashFee() external {
console2.log("test_flashFee");
assertEq(wrapper.flashFee(dai, 1e18), 0, "Fee not zero");
assertEq(wrapper.flashFee(dai, type(uint256).max), type(uint256).max, "Fee not max");
}

function test_flashLoan() external {
console2.log("test_flashLoan");
uint256 lenderBalance = dai.balanceOf(address(wrapper));
uint256 loan = 1e18;
uint256 fee = wrapper.flashFee(dai, loan);
dai.transfer(address(borrower), fee);
bytes memory result = borrower.flashBorrow(dai, loan);

// TODO: Temporary test to ensure we receive the correct callback result
assertEq(uint256(wrapper.CALLBACK_SUCCESS()), uint256(abi.decode(result, (bytes32))));

assertEq(borrower.flashInitiator(), address(borrower));
assertEq(address(borrower.flashAsset()), address(dai));
assertEq(borrower.flashAmount(), loan);
assertEq(borrower.flashBalance(), loan + fee); // The amount we transferred to pay for fees, plus the amount we borrowed
assertEq(borrower.flashFee(), fee);
assertEq(dai.balanceOf(address(wrapper)), lenderBalance + fee);
}
}
57 changes: 0 additions & 57 deletions test/Foo.t.sol

This file was deleted.

0 comments on commit a7c6315

Please sign in to comment.