From 1ec8ad3e9c9a9cac3d00bfe8197364a51724a1e7 Mon Sep 17 00:00:00 2001 From: r4bbit <445106+0x-r4bbit@users.noreply.github.com> Date: Thu, 11 Jan 2024 14:54:39 +0100 Subject: [PATCH] fix(StakeVault): make unstaking actually work Unstaking didn't actually work because it was using `transferFrom()` on the `StakeVault` with the `from` address being the vault itself. This would result in an approval error because the vault isn't creating any approvals to spend its own funds. The solution is to use `transfer` instead and ensuring the return value is checked. --- contracts/StakeVault.sol | 14 ++++++++++++-- test/StakeManager.t.sol | 18 ++++++++++++++++++ test/StakeVault.t.sol | 23 +++++++++++++++++++++++ test/mocks/BrokenERC20.s.sol | 20 ++++++++++++++++++++ test/script/DeployBroken.s.sol | 20 ++++++++++++++++++++ 5 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 test/mocks/BrokenERC20.s.sol create mode 100644 test/script/DeployBroken.s.sol diff --git a/contracts/StakeVault.sol b/contracts/StakeVault.sol index 2c16df3..0b5ddbd 100644 --- a/contracts/StakeVault.sol +++ b/contracts/StakeVault.sol @@ -14,6 +14,10 @@ import { StakeManager } from "./StakeManager.sol"; contract StakeVault is Ownable { error StakeVault__MigrationNotAvailable(); + error StakeVault__StakingFailed(); + + error StakeVault__UnstakingFailed(); + StakeManager private stakeManager; ERC20 private immutable STAKED_TOKEN; @@ -27,7 +31,10 @@ contract StakeVault is Ownable { } function stake(uint256 _amount, uint256 _time) external onlyOwner { - STAKED_TOKEN.transferFrom(msg.sender, address(this), _amount); + bool success = STAKED_TOKEN.transferFrom(msg.sender, address(this), _amount); + if (!success) { + revert StakeVault__StakingFailed(); + } stakeManager.stake(_amount, _time); emit Staked(msg.sender, address(this), _amount, _time); @@ -39,7 +46,10 @@ contract StakeVault is Ownable { function unstake(uint256 _amount) external onlyOwner { stakeManager.unstake(_amount); - STAKED_TOKEN.transferFrom(address(this), msg.sender, _amount); + bool success = STAKED_TOKEN.transfer(msg.sender, _amount); + if (!success) { + revert StakeVault__UnstakingFailed(); + } } function leave() external onlyOwner { diff --git a/test/StakeManager.t.sol b/test/StakeManager.t.sol index d5745b9..97322b4 100644 --- a/test/StakeManager.t.sol +++ b/test/StakeManager.t.sol @@ -97,6 +97,24 @@ contract UnstakeTest is StakeManagerTest { vm.expectRevert(StakeManager.StakeManager__FundsLocked.selector); userVault.unstake(100); } + + function test_UnstakeShouldReturnFunds() public { + // ensure user has funds + deal(stakeToken, testUser, 1000); + StakeVault userVault = _createTestVault(testUser); + + vm.startPrank(testUser); + ERC20(stakeToken).approve(address(userVault), 100); + + userVault.stake(100, 0); + assertEq(ERC20(stakeToken).balanceOf(testUser), 900); + + userVault.unstake(100); + + assertEq(stakeManager.stakeSupply(), 0); + assertEq(ERC20(stakeToken).balanceOf(address(userVault)), 0); + assertEq(ERC20(stakeToken).balanceOf(testUser), 1000); + } } contract LockTest is StakeManagerTest { diff --git a/test/StakeVault.t.sol b/test/StakeVault.t.sol index 441e45a..bd7841b 100644 --- a/test/StakeVault.t.sol +++ b/test/StakeVault.t.sol @@ -1,8 +1,11 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.19; +import { ERC20 } from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + import { Test } from "forge-std/Test.sol"; import { Deploy } from "../script/Deploy.s.sol"; +import { DeployBroken } from "./script/DeployBroken.s.sol"; import { DeploymentConfig } from "../script/DeploymentConfig.s.sol"; import { StakeManager } from "../contracts/StakeManager.sol"; import { StakeVault } from "../contracts/StakeVault.sol"; @@ -42,3 +45,23 @@ contract StakedTokenTest is StakeVaultTest { assertEq(address(stakeVault.stakedToken()), stakeToken); } } + +contract StakeTest is StakeVaultTest { + function setUp() public override { + DeployBroken deployment = new DeployBroken(); + (vaultFactory, stakeManager, stakeToken) = deployment.run(); + + vm.prank(testUser); + stakeVault = vaultFactory.createVault(); + } + + function test_RevertWhen_StakeTokenTransferFails() public { + // ensure user has funds + deal(stakeToken, testUser, 1000); + + vm.startPrank(address(testUser)); + ERC20(stakeToken).approve(address(stakeVault), 100); + vm.expectRevert(StakeVault.StakeVault__StakingFailed.selector); + stakeVault.stake(100, 0); + } +} diff --git a/test/mocks/BrokenERC20.s.sol b/test/mocks/BrokenERC20.s.sol new file mode 100644 index 0000000..3837bfe --- /dev/null +++ b/test/mocks/BrokenERC20.s.sol @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import { ERC20 } from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +contract BrokenERC20 is ERC20 { + constructor() ERC20("Mock SNT", "SNT") { + _mint(msg.sender, 1_000_000_000_000_000_000); + } + + // solhint-disable-next-line no-unused-vars + function transferFrom(address sender, address recipient, uint256 amount) public override returns (bool) { + return false; + } + + // solhint-disable-next-line no-unused-vars + function transfer(address recipient, uint256 amount) public override returns (bool) { + return false; + } +} diff --git a/test/script/DeployBroken.s.sol b/test/script/DeployBroken.s.sol new file mode 100644 index 0000000..e2105e1 --- /dev/null +++ b/test/script/DeployBroken.s.sol @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import { BaseScript } from "../../script/Base.s.sol"; +import { StakeManager } from "../../contracts/StakeManager.sol"; +import { VaultFactory } from "../../contracts/VaultFactory.sol"; +import { BrokenERC20 } from "../mocks/BrokenERC20.s.sol"; + +contract DeployBroken is BaseScript { + function run() public returns (VaultFactory, StakeManager, address) { + BrokenERC20 token = new BrokenERC20(); + + vm.startBroadcast(broadcaster); + StakeManager stakeManager = new StakeManager(address(token), address(0)); + VaultFactory vaultFactory = new VaultFactory(address(stakeManager)); + vm.stopBroadcast(); + + return (vaultFactory, stakeManager, address(token)); + } +}