diff --git a/src/erc3156/ERC3156Wrapper.sol b/src/erc3156/ERC3156Wrapper.sol index d2c5c2a..0e1de0d 100644 --- a/src/erc3156/ERC3156Wrapper.sol +++ b/src/erc3156/ERC3156Wrapper.sol @@ -8,6 +8,7 @@ import { RevertMsgExtractor } from "../utils/RevertMsgExtractor.sol"; import { IERC3156PPFlashLender } from "lib/erc3156pp/src/interfaces/IERC3156PPFlashLender.sol"; import { IERC20 } from "lib/erc3156pp/src/interfaces/IERC20.sol"; + library TransferHelper { /// @notice Transfers tokens from msg.sender to a recipient /// @dev Errors with the underlying revert message if transfer fails @@ -124,7 +125,6 @@ contract ERC3156Wrapper is IERC3156PPFlashLender, IERC3156FlashBorrower { // We pass the loan to the loan receiver and we store the callback result in storage for the the ERC3156++ flashLoan function to recover it. _callbackResult = _callFromData(IERC20(token), amount, fee, data); - _callbackResult = abi.encode(CALLBACK_SUCCESS); // TODO: Hijacking the callback result to see where the OOG error comes from IERC20(token).approve(address(lender), amount + fee); @@ -147,7 +147,8 @@ contract ERC3156Wrapper is IERC3156PPFlashLender, IERC3156FlashBorrower { fee, // fee userData // data )); - require(success, RevertMsgExtractor.getRevertMsg(result)); - return result; + + if(!success) revert(RevertMsgExtractor.getRevertMsg(result)); + return abi.decode(result, (bytes)); } } \ No newline at end of file diff --git a/src/test/FlashBorrower.sol b/src/test/FlashBorrower.sol index e8cbac4..8d55755 100644 --- a/src/test/FlashBorrower.sol +++ b/src/test/FlashBorrower.sol @@ -12,6 +12,7 @@ contract LoanReceiver { } contract FlashBorrower { + bytes32 public constant ERC3156PP_CALLBACK_SUCCESS = keccak256("ERC3156PP_CALLBACK_SUCCESS"); IERC3156PPFlashLender lender; LoanReceiver loanReceiver; @@ -30,6 +31,7 @@ contract FlashBorrower { function onFlashLoan(address initiator, address paymentReceiver, IERC20 asset, uint256 amount, uint256 fee, bytes calldata data) external returns(bytes memory) { require(msg.sender == address(lender), "FlashBorrower: Untrusted lender"); require(initiator == address(this), "FlashBorrower: External loan initiator"); + flashInitiator = initiator; flashAsset = asset; flashAmount = amount; @@ -38,7 +40,7 @@ contract FlashBorrower { flashBalance = IERC20(asset).balanceOf(address(this)); asset.transfer(paymentReceiver, amount + fee); - return ""; // abi.encode(data, paymentReceiver, fee); // TODO: Returning anything here causes a revert + return abi.encode(ERC3156PP_CALLBACK_SUCCESS); } function onSteal(address initiator, address paymentReceiver, IERC20 asset, uint256 amount, uint256 fee, bytes calldata data) external returns(bytes memory) { @@ -51,7 +53,7 @@ contract FlashBorrower { // do nothing - return abi.encode(data, paymentReceiver, fee); + return abi.encode(ERC3156PP_CALLBACK_SUCCESS); } function onReenter(address initiator, address paymentReceiver, IERC20 asset, uint256 amount, uint256 fee, bytes calldata data) external returns(bytes memory) { @@ -69,7 +71,7 @@ contract FlashBorrower { flashAmount += amount; flashFee += fee; - return abi.encode(data, paymentReceiver, fee); + return abi.encode(ERC3156PP_CALLBACK_SUCCESS); } function flashBorrow(IERC20 asset, uint256 amount) public returns(bytes memory) { diff --git a/test/ERC3156Wrapper.t.sol b/test/ERC3156Wrapper.t.sol index 701c43c..6e47996 100644 --- a/test/ERC3156Wrapper.t.sol +++ b/test/ERC3156Wrapper.t.sol @@ -53,9 +53,11 @@ contract ERC3156WrapperTest is PRBTest, StdCheats { dai.transfer(address(borrower), fee); bytes memory result = borrower.flashBorrow(dai, loan); - // TODO: Temporary test to ensure we receive the correct callback result - assertEq(uint256(wrapper.CALLBACK_SUCCESS()), uint256(abi.decode(result, (bytes32)))); + // Test the return values + (bytes32 callbackReturn) = abi.decode(result, (bytes32)); + assertEq(uint256(callbackReturn), uint256(borrower.ERC3156PP_CALLBACK_SUCCESS()), "Callback failed"); + // Test the borrower state assertEq(borrower.flashInitiator(), address(borrower)); assertEq(address(borrower.flashAsset()), address(dai)); assertEq(borrower.flashAmount(), loan);