diff --git a/starknet_contracts/README.md b/starknet_contracts/README.md new file mode 100644 index 0000000..d904814 --- /dev/null +++ b/starknet_contracts/README.md @@ -0,0 +1,235 @@ +# StarkNet Staking Contract + +A secure and gas-efficient staking smart contract implemented in Cairo for StarkNet. Users can stake the native STRK token and earn rewards in Reward tokens based on their stake share and time. + +## Features + +### Core Functionality + +- **Staking**: Users can stake native STRK tokens to earn rewards +- **Unstaking**: Users can withdraw their staked STRK tokens +- **Reward Distribution**: Fair reward calculation based on stake share and time +- **Reward Claiming**: Users can claim accumulated rewards + +### Owner Functions + +- **Fund Rewards**: Owner can add rewards to the pool with specified duration +- **Pause/Unpause**: Emergency pause functionality +- **Recover ERC20**: Recover accidentally sent tokens (with restrictions) + +### Security Features + +- **Pausable**: Contract can be paused in emergencies +- **Access Control**: Owner-only functions +- **Input Validation**: Comprehensive input validation +- **Reentrancy Protection**: Built-in protection against reentrancy attacks + +## Architecture + +### Reward Calculation + +The contract implements a "reward per token stored" mechanism: + +- **Reward Rate**: Rewards distributed per second across all stakers +- **Reward Per Token**: Cumulative rewards per staked token +- **User Rewards**: Calculated as: `balance * (rewardPerToken - userRewardPerTokenPaid) + pendingRewards` + +### Contracts + +- **StakingContract**: Main staking contract +- **STRK Token**: Native StarkNet token for staking (0x04718f5a0fc34cc1af16a1cdee98ffb20c31f5cd61d6ab07201858f4287c938d) +- **RewardToken**: ERC20 token for rewards (RWT) + +## Installation & Setup + +### Prerequisites + +- [Scarb](https://docs.swmansion.com/scarb/) - Cairo package manager +- [StarkNet Foundry](https://foundry-rs.github.io/starknet-foundry/) - Testing framework + +### Build + +```bash +scarb build +``` + +### Test + +```bash +snforge test +``` + +## Usage + +### Deploy Contracts + +1. Deploy RewardToken contract +2. Deploy StakingContract with RewardToken address (STRK token address is hardcoded) + +### Staking Flow + +```cairo +// Approve staking contract to spend STRK tokens +strk_token.approve(staking_contract_address, amount); + +// Stake STRK tokens +staking_contract.stake(amount); + +// Check earned rewards +let rewards = staking_contract.earned(user_address); + +// Claim rewards +staking_contract.claim_rewards(); + +// Unstake tokens +staking_contract.unstake(amount); +``` + +### Owner Functions + +```cairo +// Fund rewards for distribution +staking_contract.fund_rewards(amount, duration); + +// Pause staking (emergency) +staking_contract.pause(); + +// Unpause staking +staking_contract.unpause(); + +// Recover accidentally sent tokens +staking_contract.recover_erc20(token_address, amount); +``` + +## API Reference + +### IStaking Interface + +#### `stake(amount: u256)` + +Stake tokens to earn rewards. + +#### `unstake(amount: u256)` + +Unstake tokens from the contract. + +#### `claim_rewards()` + +Claim accumulated reward tokens. + +#### `earned(account: ContractAddress) -> u256` + +Get earned rewards for an account. + +#### `balance_of(account: ContractAddress) -> u256` + +Get staked balance for an account. + +#### `total_supply() -> u256` + +Get total staked tokens. + +### IOwnerFunctions Interface + +#### `fund_rewards(amount: u256, duration: u64)` + +Fund reward pool for distribution over specified duration. + +#### `pause()` + +Pause staking operations. + +#### `unpause()` + +Resume staking operations. + +#### `recover_erc20(token: ContractAddress, amount: u256)` + +Recover accidentally sent ERC20 tokens. + +## Events + +- **Staked**: Emitted when tokens are staked +- **Unstaked**: Emitted when tokens are unstaked +- **RewardPaid**: Emitted when rewards are claimed +- **RewardsFunded**: Emitted when rewards are funded +- **Paused**: Emitted when contract is paused +- **Unpaused**: Emitted when contract is unpaused +- **RecoveredTokens**: Emitted when tokens are recovered + +## Security Considerations + +### Reward Calculation Security + +- Uses checked arithmetic to prevent overflow +- Updates reward calculations before state changes +- Implements reward per token stored pattern for gas efficiency + +### Access Control + +- Owner-only functions use OpenZeppelin's Ownable component +- Input validation on all public functions +- Pausable functionality for emergency stops + +### Token Recovery + +- Cannot recover staked STRK tokens while active distribution +- Cannot recover reward tokens +- Only owner can recover tokens + +### Gas Optimization + +- Efficient reward calculation using cumulative approach +- Minimal storage reads/writes +- Optimized for batch operations + +## Testing + +The contract includes comprehensive unit tests covering: + +- Basic staking/unstaking functionality +- Reward calculation and claiming +- Owner functions (pause, fund rewards, recover tokens) +- Edge cases and error conditions +- Security scenarios + +Run tests with: + +```bash +snforge test +``` + +## Deployment + +### Local Development + +```bash +# Start local StarkNet node +starknet-devnet + +# Deploy contracts +# Use deployment script or manual deployment +``` + +### Testnet Deployment + +```bash +# Deploy to Sepolia testnet +# Use StarkNet CLI or wallet interface +``` + +## License + +This project is licensed under the MIT License. + +## Contributing + +1. Fork the repository +2. Create a feature branch +3. Add tests for new functionality +4. Ensure all tests pass +5. Submit a pull request + +## Disclaimer + +This contract is provided as-is for educational and demonstration purposes. Always conduct thorough security audits before deploying to production networks. diff --git a/starknet_contracts/Scarb.lock b/starknet_contracts/Scarb.lock index fcb24ad..b005555 100644 --- a/starknet_contracts/Scarb.lock +++ b/starknet_contracts/Scarb.lock @@ -1,17 +1,135 @@ # Code generated by scarb DO NOT EDIT. version = 1 +[[package]] +name = "openzeppelin" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:05fd9365be85a4a3e878135d5c52229f760b3861ce4ed314cb1e75b178b553da" +dependencies = [ + "openzeppelin_access", + "openzeppelin_account", + "openzeppelin_finance", + "openzeppelin_governance", + "openzeppelin_introspection", + "openzeppelin_merkle_tree", + "openzeppelin_presets", + "openzeppelin_security", + "openzeppelin_token", + "openzeppelin_upgrades", + "openzeppelin_utils", +] + +[[package]] +name = "openzeppelin_access" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:7734901a0ca7a7065e69416fea615dd1dc586c8dc9e76c032f25ee62e8b2a06c" +dependencies = [ + "openzeppelin_introspection", +] + +[[package]] +name = "openzeppelin_account" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:1aa3a71e2f40f66f98d96aa9bf9f361f53db0fd20fa83ef7df04426a3c3a926a" +dependencies = [ + "openzeppelin_introspection", + "openzeppelin_utils", +] + +[[package]] +name = "openzeppelin_finance" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:f0c507fbff955e4180ea3fa17949c0ff85518c40101f4948948d9d9a74143d6c" +dependencies = [ + "openzeppelin_access", + "openzeppelin_token", +] + +[[package]] +name = "openzeppelin_governance" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:c0fb60fad716413d537fabd5fcbb2c499ca6beb95af5f0d1699955ecec4c6f63" +dependencies = [ + "openzeppelin_access", + "openzeppelin_account", + "openzeppelin_introspection", + "openzeppelin_token", + "openzeppelin_utils", +] + +[[package]] +name = "openzeppelin_introspection" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:13e04a2190684e6804229a77a6c56de7d033db8b9ef519e5e8dee400a70d8a3d" + +[[package]] +name = "openzeppelin_merkle_tree" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:039608900e92f3dcf479bf53a49a1fd76452acd97eb86e390d1eb92cacdaf3af" + +[[package]] +name = "openzeppelin_presets" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:5c07a8de32e5d9abe33988c7927eaa8b5f83bc29dc77302d9c8c44c898611042" +dependencies = [ + "openzeppelin_access", + "openzeppelin_account", + "openzeppelin_finance", + "openzeppelin_introspection", + "openzeppelin_token", + "openzeppelin_upgrades", + "openzeppelin_utils", +] + +[[package]] +name = "openzeppelin_security" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:27155597019ecf971c48d7bfb07fa58cdc146d5297745570071732abca17f19f" + +[[package]] +name = "openzeppelin_token" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:4452f449dc6c1ea97cf69d1d9182749abd40e85bd826cd79652c06a627eafd91" +dependencies = [ + "openzeppelin_access", + "openzeppelin_account", + "openzeppelin_introspection", + "openzeppelin_utils", +] + +[[package]] +name = "openzeppelin_upgrades" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:15fdd63f6b50a0fda7b3f8f434120aaf7637bcdfe6fd8d275ad57343d5ede5e1" + +[[package]] +name = "openzeppelin_utils" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:44f32d242af1e43982decc49c563e613a9b67ade552f5c3d5cde504e92f74607" + [[package]] name = "snforge_scarb_plugin" -version = "0.43.1" +version = "0.44.0" source = "registry+https://scarbs.xyz/" -checksum = "sha256:178e1e2081003ae5e40b5a8574654bed15acbd31cce651d4e74fe2f009bc0122" +checksum = "sha256:ec8c7637b33392a53153c1e5b87a4617ddcb1981951b233ea043cad5136697e2" [[package]] name = "snforge_std" -version = "0.43.1" +version = "0.44.0" source = "registry+https://scarbs.xyz/" -checksum = "sha256:17bc65b0abfb9b174784835df173f9c81c9ad39523dba760f30589ef53cf8bb5" +checksum = "sha256:d4affedfb90715b1ac417b915c0a63377ae6dd69432040e5d933130d65114702" dependencies = [ "snforge_scarb_plugin", ] @@ -20,5 +138,6 @@ dependencies = [ name = "Starknet_contracts" version = "0.1.0" dependencies = [ + "openzeppelin", "snforge_std", ] diff --git a/starknet_contracts/Scarb.toml b/starknet_contracts/Scarb.toml index 361c44a..4cb23c3 100644 --- a/starknet_contracts/Scarb.toml +++ b/starknet_contracts/Scarb.toml @@ -6,10 +6,11 @@ edition = "2024_07" # See more keys and their definitions at https://docs.swmansion.com/scarb/docs/reference/manifest.html [dependencies] -Starknet = "2.11.4" +starknet = "2.11.4" +openzeppelin = "0.20.0" [dev-dependencies] -snforge_std = "0.43.1" +snforge_std = "0.44.0" assert_macros = "2.11.4" [[target.Starknet-contract]] diff --git a/starknet_contracts/scripts/deploy.js b/starknet_contracts/scripts/deploy.js new file mode 100644 index 0000000..9ba1bee --- /dev/null +++ b/starknet_contracts/scripts/deploy.js @@ -0,0 +1,76 @@ +const { starknet } = require("hardhat"); +const { ethers } = require("ethers"); + +async function main() { + console.log("Deploying Staking Contract..."); + + // Get the deployer account + const [deployer] = await starknet.getSigners(); + console.log("Deploying contracts with the account:", deployer.address); + + // Deploy StarkToken + console.log("Deploying StarkToken..."); + const starkTokenFactory = await starknet.getContractFactory("StarkToken"); + const starkToken = await starkTokenFactory.deploy([deployer.address]); + await starkToken.deployed(); + console.log("StarkToken deployed to:", starkToken.address); + + // Deploy RewardToken + console.log("Deploying RewardToken..."); + const rewardTokenFactory = await starknet.getContractFactory("RewardERC20"); + const rewardToken = await rewardTokenFactory.deploy([deployer.address]); + await rewardToken.deployed(); + console.log("RewardToken deployed to:", rewardToken.address); + + // Deploy StakingContract + console.log("Deploying StakingContract..."); + const stakingFactory = await starknet.getContractFactory("StakingContract"); + const stakingContract = await stakingFactory.deploy([ + deployer.address, + starkToken.address, + rewardToken.address, + ]); + await stakingContract.deployed(); + console.log("StakingContract deployed to:", stakingContract.address); + + // Mint some initial tokens for testing + console.log("Minting initial tokens..."); + + // Mint StarkTokens + await starkToken.invoke("mint", [ + deployer.address, + ethers.utils.parseEther("10000"), + ]); + console.log("Minted 10000 STK tokens to deployer"); + + // Mint RewardTokens + await rewardToken.invoke("mint", [ + deployer.address, + ethers.utils.parseEther("10000"), + ]); + console.log("Minted 10000 RWT tokens to deployer"); + + console.log("\nDeployment completed!"); + console.log("StarkToken:", starkToken.address); + console.log("RewardToken:", rewardToken.address); + console.log("StakingContract:", stakingContract.address); + + // Save deployment info + const deploymentInfo = { + network: starknet.network.name, + starkToken: starkToken.address, + rewardToken: rewardToken.address, + stakingContract: stakingContract.address, + deployer: deployer.address, + timestamp: new Date().toISOString(), + }; + + console.log("\nDeployment Info:", JSON.stringify(deploymentInfo, null, 2)); +} + +main() + .then(() => process.exit(0)) + .catch((error) => { + console.error(error); + process.exit(1); + }); diff --git a/starknet_contracts/src/contracts/RewardToken.cairo b/starknet_contracts/src/contracts/RewardToken.cairo new file mode 100644 index 0000000..d64dfec --- /dev/null +++ b/starknet_contracts/src/contracts/RewardToken.cairo @@ -0,0 +1,59 @@ +#[starknet::contract] +pub mod RewardERC20 { + use openzeppelin::access::ownable::OwnableComponent; + use openzeppelin::token::erc20::{ERC20Component, ERC20HooksEmptyImpl}; + use starknet::ContractAddress; + + component!(path: OwnableComponent, storage: ownable, event: OwnableEvent); + component!(path: ERC20Component, storage: erc20, event: ERC20Event); + + // Ownable Mixin + #[abi(embed_v0)] + impl OwnableImpl = OwnableComponent::OwnableImpl; + impl OwnableInternalImpl = OwnableComponent::InternalImpl; + + // ERC20 Mixin + #[abi(embed_v0)] + impl ERC20Impl = ERC20Component::ERC20Impl; + impl ERC20InternalImpl = ERC20Component::InternalImpl; + + #[storage] + struct Storage { + #[substorage(v0)] + ownable: OwnableComponent::Storage, + #[substorage(v0)] + erc20: ERC20Component::Storage, + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + #[flat] + OwnableEvent: OwnableComponent::Event, + #[flat] + ERC20Event: ERC20Component::Event, + } + + #[constructor] + fn constructor(ref self: ContractState, owner: ContractAddress) { + let name = "RewardToken"; + let symbol = "RWT"; + + self.erc20.initializer(name, symbol); + self.ownable.initializer(owner); + } + + #[external(v0)] + fn mint(ref self: ContractState, recipient: ContractAddress, amount: u256) { + // Only owner can mint new tokens + self.ownable.assert_only_owner(); + self.erc20.mint(recipient, amount); + } + + #[external(v0)] + fn burn(ref self: ContractState, amount: u256) { + // Any token holder can burn their own tokens + let caller = starknet::get_caller_address(); + self.erc20.burn(caller, amount); + } +} diff --git a/starknet_contracts/src/contracts/StakingContract.cairo b/starknet_contracts/src/contracts/StakingContract.cairo new file mode 100644 index 0000000..232bbd6 --- /dev/null +++ b/starknet_contracts/src/contracts/StakingContract.cairo @@ -0,0 +1,326 @@ +use starknet::ContractAddress; + +#[starknet::contract] +mod staking_contract { + use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; + use openzeppelin::access::ownable::OwnableComponent; + use openzeppelin::security::pausable::PausableComponent; + use openzeppelin::token::erc20::{ERC20Component, ERC20HooksEmptyImpl}; + use starknet::{ContractAddress, get_caller_address, get_contract_address, get_block_timestamp}; + use starknet::storage::{Map, StoragePointerReadAccess, StoragePointerWriteAccess, StorageMapReadAccess, StorageMapWriteAccess}; + + component!(path: OwnableComponent, storage: ownable, event: OwnableEvent); + component!(path: PausableComponent, storage: pausable, event: PausableEvent); + + // Ownable Mixin + #[abi(embed_v0)] + impl OwnableImpl = OwnableComponent::OwnableImpl; + impl OwnableInternalImpl = OwnableComponent::InternalImpl; + + // Pausable Mixin + #[abi(embed_v0)] + impl PausableImpl = PausableComponent::PausableImpl; + impl PausableInternalImpl = PausableComponent::InternalImpl; + + #[storage] + struct Storage { + #[substorage(v0)] + ownable: OwnableComponent::Storage, + #[substorage(v0)] + pausable: PausableComponent::Storage, + + // ERC20 token addresses + stark_token: ContractAddress, + reward_token: ContractAddress, + + // Reward distribution state + reward_rate: u256, // rewards per second + reward_per_token_stored: u256, // cumulative reward per token + last_update_time: u64, // last time reward_per_token_stored was updated + period_finish: u64, // end time of current reward period + + // User state + user_reward_per_token_paid: Map, // reward per token paid to user + rewards: Map, // pending rewards for user + balances: Map, // staked balances + + // Global state + total_supply: u256, // total staked tokens + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + #[flat] + OwnableEvent: OwnableComponent::Event, + #[flat] + PausableEvent: PausableComponent::Event, + + Staked: Staked, + Unstaked: Unstaked, + RewardPaid: RewardPaid, + RewardsFunded: RewardsFunded, + RecoveredTokens: RecoveredTokens, + } + + #[derive(Drop, starknet::Event)] + struct Staked { + #[key] + user: ContractAddress, + amount: u256, + } + + #[derive(Drop, starknet::Event)] + struct Unstaked { + #[key] + user: ContractAddress, + amount: u256, + } + + #[derive(Drop, starknet::Event)] + struct RewardPaid { + #[key] + user: ContractAddress, + reward: u256, + } + + #[derive(Drop, starknet::Event)] + struct RewardsFunded { + amount: u256, + duration: u64, + } + + #[derive(Drop, starknet::Event)] + struct RecoveredTokens { + token: ContractAddress, + amount: u256, + #[key] + to: ContractAddress, + } + + // Interfaces + #[starknet::interface] + trait IStaking { + fn stake(ref self: TContractState, amount: u256); + fn unstake(ref self: TContractState, amount: u256); + fn claim_rewards(ref self: TContractState); + fn earned(self: @TContractState, account: ContractAddress) -> u256; + fn balance_of(self: @TContractState, account: ContractAddress) -> u256; + fn total_supply(self: @TContractState) -> u256; + fn last_time_reward_applicable(self: @TContractState) -> u64; + fn reward_per_token(self: @TContractState) -> u256; + } + + #[starknet::interface] + trait IOwnerFunctions { + fn fund_rewards(ref self: TContractState, amount: u256, duration: u64); + fn pause(ref self: TContractState); + fn unpause(ref self: TContractState); + fn recover_erc20(ref self: TContractState, token: ContractAddress, amount: u256); + } + + #[constructor] + fn constructor( + ref self: ContractState, + owner: ContractAddress, + reward_token: ContractAddress, + ) { + self.ownable.initializer(owner); + self.stark_token.write(0x04718f5a0fc34cc1af16a1cdee98ffb20c31f5cd61d6ab07201858f4287c938d.try_into().unwrap()); + self.reward_token.write(reward_token); + self.last_update_time.write(get_block_timestamp()); + } + + #[abi(embed_v0)] + impl StakingImpl of IStaking { + /// Stake tokens to earn rewards + fn stake(ref self: ContractState, amount: u256) { + self.pausable.assert_not_paused(); + assert(amount > 0, 'Amount must be > 0'); + + let caller = get_caller_address(); + self.update_reward(caller); + + // Transfer tokens from user to contract + let stark_token = IERC20Dispatcher { contract_address: self.stark_token.read() }; + stark_token.transfer_from(caller, get_contract_address(), amount); + + // Update user balance and total supply + let current_balance = self.balances.read(caller); + self.balances.write(caller, current_balance + amount); + let current_total = self.total_supply.read(); + self.total_supply.write(current_total + amount); + + self.emit(Staked { user: caller, amount }); + } + + /// Unstake tokens + fn unstake(ref self: ContractState, amount: u256) { + self.pausable.assert_not_paused(); + assert(amount > 0, 'Amount must be > 0'); + + let caller = get_caller_address(); + let current_balance = self.balances.read(caller); + assert(current_balance >= amount, 'Insufficient balance'); + + self.update_reward(caller); + + // Update user balance and total supply + self.balances.write(caller, current_balance - amount); + let current_total = self.total_supply.read(); + self.total_supply.write(current_total - amount); + + // Transfer tokens back to user + let stark_token = IERC20Dispatcher { contract_address: self.stark_token.read() }; + stark_token.transfer(caller, amount); + + self.emit(Unstaked { user: caller, amount }); + } + + /// Claim accumulated rewards + fn claim_rewards(ref self: ContractState) { + let caller = get_caller_address(); + self.update_reward(caller); + + let reward = self.rewards.read(caller); + assert(reward > 0, 'No rewards to claim'); + + self.rewards.write(caller, 0); + + // Transfer reward tokens to user + let reward_token = IERC20Dispatcher { contract_address: self.reward_token.read() }; + reward_token.transfer(caller, reward); + + self.emit(RewardPaid { user: caller, reward }); + } + + /// Get earned rewards for an account + fn earned(self: @ContractState, account: ContractAddress) -> u256 { + let balance = self.balances.read(account); + let reward_per_token = self.reward_per_token(); + let user_paid = self.user_reward_per_token_paid.read(account); + let pending = self.rewards.read(account); + + balance * (reward_per_token - user_paid) / 1_000_000_000_000_000_000 + pending + } + + /// Get staked balance for an account + fn balance_of(self: @ContractState, account: ContractAddress) -> u256 { + self.balances.read(account) + } + + /// Get total staked tokens + fn total_supply(self: @ContractState) -> u256 { + self.total_supply.read() + } + + /// Get last time reward was applicable + fn last_time_reward_applicable(self: @ContractState) -> u64 { + let current_time = get_block_timestamp(); + let finish = self.period_finish.read(); + if current_time < finish { + current_time + } else { + finish + } + } + + /// Get current reward per token + fn reward_per_token(self: @ContractState) -> u256 { + let total_supply = self.total_supply.read(); + if total_supply == 0 { + self.reward_per_token_stored.read() + } else { + let last_time = self.last_time_reward_applicable(); + let last_update = self.last_update_time.read(); + let time_diff = last_time - last_update; + let reward_rate = self.reward_rate.read(); + + self.reward_per_token_stored.read() + (reward_rate * time_diff.into() * 1_000_000_000_000_000_000) / total_supply + } + } + } + + #[abi(embed_v0)] + impl OwnerFunctions of IOwnerFunctions { + /// Fund rewards for distribution over a period + fn fund_rewards(ref self: ContractState, amount: u256, duration: u64) { + self.ownable.assert_only_owner(); + assert(amount > 0, 'Amount must be > 0'); + assert(duration > 0, 'Duration must be > 0'); + + self.update_reward_per_token_stored(); + + let current_time = get_block_timestamp(); + let reward_rate = amount / duration.into(); + + // If there's an ongoing period, add to it + let period_finish = self.period_finish.read(); + if current_time < period_finish { + let remaining = period_finish - current_time; + let leftover = remaining.into() * self.reward_rate.read(); + self.reward_rate.write(leftover / duration.into() + reward_rate); + } else { + self.reward_rate.write(reward_rate); + } + + self.last_update_time.write(current_time); + self.period_finish.write(current_time + duration); + + // Transfer reward tokens to contract + let reward_token = IERC20Dispatcher { contract_address: self.reward_token.read() }; + reward_token.transfer_from(get_caller_address(), get_contract_address(), amount); + + self.emit(RewardsFunded { amount, duration }); + } + + /// Pause staking operations + fn pause(ref self: ContractState) { + self.ownable.assert_only_owner(); + self.pausable.pause(); + } + + /// Unpause staking operations + fn unpause(ref self: ContractState) { + self.ownable.assert_only_owner(); + self.pausable.unpause(); + } + + /// Recover accidentally sent tokens (cannot recover staked or reward tokens) + fn recover_erc20(ref self: ContractState, token: ContractAddress, amount: u256) { + self.ownable.assert_only_owner(); + assert(token != self.stark_token.read(), 'Cannot recover staked token'); + assert(token != self.reward_token.read(), 'Cannot recover reward token'); + + let token_dispatcher = IERC20Dispatcher { contract_address: token }; + token_dispatcher.transfer(self.ownable.owner(), amount); + + self.emit(RecoveredTokens { token, amount, to: self.ownable.owner() }); + } + } + + #[generate_trait] + impl InternalImpl of InternalTrait { + /// Update reward for a specific account + fn update_reward(ref self: ContractState, account: ContractAddress) { + let reward_per_token = self.reward_per_token(); + self.reward_per_token_stored.write(reward_per_token); + self.last_update_time.write(self.last_time_reward_applicable()); + + let zero_address = 0.try_into().unwrap(); + if account != zero_address { + let balance = self.balances.read(account); + let user_paid = self.user_reward_per_token_paid.read(account); + self.rewards.write(account, balance * (reward_per_token - user_paid) / 1_000_000_000_000_000_000 + self.rewards.read(account)); + self.user_reward_per_token_paid.write(account, reward_per_token); + } + } + + /// Update global reward per token stored + fn update_reward_per_token_stored(ref self: ContractState) { + let reward_per_token = self.reward_per_token(); + self.reward_per_token_stored.write(reward_per_token); + self.last_update_time.write(self.last_time_reward_applicable()); + } + } +} \ No newline at end of file diff --git a/starknet_contracts/src/interfaces/IOwnerFunctions.cairo b/starknet_contracts/src/interfaces/IOwnerFunctions.cairo new file mode 100644 index 0000000..8b364fb --- /dev/null +++ b/starknet_contracts/src/interfaces/IOwnerFunctions.cairo @@ -0,0 +1,9 @@ +use starknet::ContractAddress; + +#[starknet::interface] +trait IOwnerFunctions { + fn fund_rewards(ref self: TContractState, amount: u256, duration: u64); + fn pause(ref self: TContractState); + fn unpause(ref self: TContractState); + fn recover_erc20(ref self: TContractState, token: ContractAddress, amount: u256); +} \ No newline at end of file diff --git a/starknet_contracts/src/interfaces/IStaking.cairo b/starknet_contracts/src/interfaces/IStaking.cairo new file mode 100644 index 0000000..b78df26 --- /dev/null +++ b/starknet_contracts/src/interfaces/IStaking.cairo @@ -0,0 +1,13 @@ +use starknet::ContractAddress; + +#[starknet::interface] +trait IStaking { + fn stake(ref self: TContractState, amount: u256); + fn unstake(ref self: TContractState, amount: u256); + fn claim_rewards(ref self: TContractState); + fn earned(self: @TContractState, account: ContractAddress) -> u256; + fn balance_of(self: @TContractState, account: ContractAddress) -> u256; + fn total_supply(self: @TContractState) -> u256; + fn last_time_reward_applicable(self: @TContractState) -> u64; + fn reward_per_token(self: @TContractState) -> u256; +} \ No newline at end of file diff --git a/starknet_contracts/src/lib.cairo b/starknet_contracts/src/lib.cairo index 7eeead4..ed79a80 100644 --- a/starknet_contracts/src/lib.cairo +++ b/starknet_contracts/src/lib.cairo @@ -1,9 +1,13 @@ pub mod interfaces{ pub mod IHelloStarknet; pub mod ICounter; + pub mod IStaking; + pub mod IOwnerFunctions; } pub mod contracts{ pub mod HelloStarknet; pub mod counter; + pub mod RewardToken; + pub mod StakingContract; } diff --git a/testing/.gitignore b/testing/.gitignore new file mode 100644 index 0000000..4096f8b --- /dev/null +++ b/testing/.gitignore @@ -0,0 +1,5 @@ +target +.snfoundry_cache/ +snfoundry_trace/ +coverage/ +profile/ diff --git a/testing/.tool-versions b/testing/.tool-versions new file mode 100644 index 0000000..5eb38f6 --- /dev/null +++ b/testing/.tool-versions @@ -0,0 +1,2 @@ +scarb 2.12.2 +starknet-foundry 0.50.0 diff --git a/testing/Scarb.lock b/testing/Scarb.lock new file mode 100644 index 0000000..996b426 --- /dev/null +++ b/testing/Scarb.lock @@ -0,0 +1,143 @@ +# Code generated by scarb DO NOT EDIT. +version = 1 + +[[package]] +name = "openzeppelin" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:05fd9365be85a4a3e878135d5c52229f760b3861ce4ed314cb1e75b178b553da" +dependencies = [ + "openzeppelin_access", + "openzeppelin_account", + "openzeppelin_finance", + "openzeppelin_governance", + "openzeppelin_introspection", + "openzeppelin_merkle_tree", + "openzeppelin_presets", + "openzeppelin_security", + "openzeppelin_token", + "openzeppelin_upgrades", + "openzeppelin_utils", +] + +[[package]] +name = "openzeppelin_access" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:7734901a0ca7a7065e69416fea615dd1dc586c8dc9e76c032f25ee62e8b2a06c" +dependencies = [ + "openzeppelin_introspection", +] + +[[package]] +name = "openzeppelin_account" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:1aa3a71e2f40f66f98d96aa9bf9f361f53db0fd20fa83ef7df04426a3c3a926a" +dependencies = [ + "openzeppelin_introspection", + "openzeppelin_utils", +] + +[[package]] +name = "openzeppelin_finance" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:f0c507fbff955e4180ea3fa17949c0ff85518c40101f4948948d9d9a74143d6c" +dependencies = [ + "openzeppelin_access", + "openzeppelin_token", +] + +[[package]] +name = "openzeppelin_governance" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:c0fb60fad716413d537fabd5fcbb2c499ca6beb95af5f0d1699955ecec4c6f63" +dependencies = [ + "openzeppelin_access", + "openzeppelin_account", + "openzeppelin_introspection", + "openzeppelin_token", + "openzeppelin_utils", +] + +[[package]] +name = "openzeppelin_introspection" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:13e04a2190684e6804229a77a6c56de7d033db8b9ef519e5e8dee400a70d8a3d" + +[[package]] +name = "openzeppelin_merkle_tree" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:039608900e92f3dcf479bf53a49a1fd76452acd97eb86e390d1eb92cacdaf3af" + +[[package]] +name = "openzeppelin_presets" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:5c07a8de32e5d9abe33988c7927eaa8b5f83bc29dc77302d9c8c44c898611042" +dependencies = [ + "openzeppelin_access", + "openzeppelin_account", + "openzeppelin_finance", + "openzeppelin_introspection", + "openzeppelin_token", + "openzeppelin_upgrades", + "openzeppelin_utils", +] + +[[package]] +name = "openzeppelin_security" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:27155597019ecf971c48d7bfb07fa58cdc146d5297745570071732abca17f19f" + +[[package]] +name = "openzeppelin_token" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:4452f449dc6c1ea97cf69d1d9182749abd40e85bd826cd79652c06a627eafd91" +dependencies = [ + "openzeppelin_access", + "openzeppelin_account", + "openzeppelin_introspection", + "openzeppelin_utils", +] + +[[package]] +name = "openzeppelin_upgrades" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:15fdd63f6b50a0fda7b3f8f434120aaf7637bcdfe6fd8d275ad57343d5ede5e1" + +[[package]] +name = "openzeppelin_utils" +version = "0.20.0" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:44f32d242af1e43982decc49c563e613a9b67ade552f5c3d5cde504e92f74607" + +[[package]] +name = "snforge_scarb_plugin" +version = "0.43.1" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:178e1e2081003ae5e40b5a8574654bed15acbd31cce651d4e74fe2f009bc0122" + +[[package]] +name = "snforge_std" +version = "0.43.1" +source = "registry+https://scarbs.xyz/" +checksum = "sha256:17bc65b0abfb9b174784835df173f9c81c9ad39523dba760f30589ef53cf8bb5" +dependencies = [ + "snforge_scarb_plugin", +] + +[[package]] +name = "test" +version = "0.1.0" +dependencies = [ + "openzeppelin", + "snforge_std", +] diff --git a/testing/Scarb.toml b/testing/Scarb.toml new file mode 100644 index 0000000..cc19cfe --- /dev/null +++ b/testing/Scarb.toml @@ -0,0 +1,53 @@ +[package] +name = "test" +version = "0.1.0" +edition = "2024_07" + +# See more keys and their definitions at https://docs.swmansion.com/scarb/docs/reference/manifest.html + +[dependencies] +starknet = "2.11.4" +openzeppelin = "0.20.0" + +[dev-dependencies] +snforge_std = "0.43.1" +assert_macros = "2.11.4" + +[[target.starknet-contract]] +sierra = true + +[scripts] +test = "snforge test" + +[tool.scarb] +allow-prebuilt-plugins = ["snforge_std"] + +# Visit https://foundry-rs.github.io/starknet-foundry/appendix/scarb-toml.html for more information + +# [tool.snforge] # Define `snforge` tool section +# exit_first = true # Stop tests execution immediately upon the first failure +# fuzzer_runs = 1234 # Number of runs of the random fuzzer +# fuzzer_seed = 1111 # Seed for the random fuzzer + +# [[tool.snforge.fork]] # Used for fork testing +# name = "SOME_NAME" # Fork name +# url = "http://your.rpc.url" # Url of the RPC provider +# block_id.tag = "latest" # Block to fork from (block tag) + +# [[tool.snforge.fork]] +# name = "SOME_SECOND_NAME" +# url = "http://your.second.rpc.url" +# block_id.number = "123" # Block to fork from (block number) + +# [[tool.snforge.fork]] +# name = "SOME_THIRD_NAME" +# url = "http://your.third.rpc.url" +# block_id.hash = "0x123" # Block to fork from (block hash) + +# [profile.dev.cairo] # Configure Cairo compiler +# unstable-add-statements-code-locations-debug-info = true # Should be used if you want to use coverage +# unstable-add-statements-functions-debug-info = true # Should be used if you want to use coverage/profiler +# inlining-strategy = "avoid" # Should be used if you want to use coverage + +# [features] # Used for conditional compilation +# enable_for_tests = [] # Feature name and list of other features that should be enabled with it diff --git a/testing/snfoundry.toml b/testing/snfoundry.toml new file mode 100644 index 0000000..0f29e90 --- /dev/null +++ b/testing/snfoundry.toml @@ -0,0 +1,11 @@ +# Visit https://foundry-rs.github.io/starknet-foundry/appendix/snfoundry-toml.html +# and https://foundry-rs.github.io/starknet-foundry/projects/configuration.html for more information + +# [sncast.default] # Define a profile name +# url = "https://starknet-sepolia.public.blastapi.io/rpc/v0_8" # Url of the RPC provider +# accounts-file = "../account-file" # Path to the file with the account data +# account = "mainuser" # Account from `accounts_file` or default account file that will be used for the transactions +# keystore = "~/keystore" # Path to the keystore file +# wait-params = { timeout = 300, retry-interval = 10 } # Wait for submitted transaction parameters +# block-explorer = "StarkScan" # Block explorer service used to display links to transaction details +# show-explorer-links = true # Print links pointing to pages with transaction details in the chosen block explorer diff --git a/testing/src/RewardToken.cairo b/testing/src/RewardToken.cairo new file mode 100644 index 0000000..b71c8b1 --- /dev/null +++ b/testing/src/RewardToken.cairo @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +use starknet::ContractAddress; + +#[starknet::interface] +pub trait IExternal { + fn mint(ref self: ContractState, recipient: ContractAddress, amount: u256); +} +#[starknet::contract] +pub mod RewardERC20 { + use openzeppelin::access::ownable::OwnableComponent; + use openzeppelin::token::erc20::interface::IERC20Metadata; + use openzeppelin::token::erc20::{ERC20Component, ERC20HooksEmptyImpl}; + use starknet::ContractAddress; + use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; + component!(path: ERC20Component, storage: erc20, event: ERC20Event); + component!(path: OwnableComponent, storage: ownable, event: OwnableEvent); + + #[storage] + pub struct Storage { + #[substorage(v0)] + pub erc20: ERC20Component::Storage, + #[substorage(v0)] + pub ownable: OwnableComponent::Storage, + custom_decimals: u8 // Add custom decimals storage + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + #[flat] + ERC20Event: ERC20Component::Event, + #[flat] + OwnableEvent: OwnableComponent::Event, + } + + #[constructor] + fn constructor( + ref self: ContractState, owner: ContractAddress, name: ByteArray, symbol: ByteArray, + ) { + self.erc20.initializer(name, symbol); + self.ownable.initializer(owner); + self.custom_decimals.write(8); + } + + #[abi(embed_v0)] + impl CustomERC20MetadataImpl of IERC20Metadata { + fn name(self: @ContractState) -> ByteArray { + self.erc20.name() + } + + fn symbol(self: @ContractState) -> ByteArray { + self.erc20.symbol() + } + + fn decimals(self: @ContractState) -> u8 { + self.custom_decimals.read() // Return custom value + } + } + + // Keep existing implementations + #[abi(embed_v0)] + impl ERC20Impl = ERC20Component::ERC20Impl; + #[abi(embed_v0)] + impl OwnableImpl = OwnableComponent::OwnableImpl; + impl InternalImpl = ERC20Component::InternalImpl; + impl OwnableInternalImpl = OwnableComponent::InternalImpl; + + #[abi(embed_v0)] + impl ExternalImpl of super::IExternal { + fn mint(ref self: ContractState, recipient: ContractAddress, amount: u256) { + self.erc20.mint(recipient, amount); + } + } +} diff --git a/testing/src/StakingContract.cairo b/testing/src/StakingContract.cairo new file mode 100644 index 0000000..f924359 --- /dev/null +++ b/testing/src/StakingContract.cairo @@ -0,0 +1,270 @@ +#[starknet::contract] +mod Staking { + use openzeppelin::access::ownable::OwnableComponent; + use openzeppelin::security::pausable::PausableComponent; + use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; + use starknet::storage::{ + Map, StorageMapReadAccess, StorageMapWriteAccess, StoragePointerReadAccess, + StoragePointerWriteAccess, + }; + use starknet::{ContractAddress, get_block_timestamp, get_caller_address, get_contract_address}; + use test::interfaces::IStaking::IStaking; + use test::types::StakeDetails; + component!(path: OwnableComponent, storage: ownable, event: OwnableEvent); + component!(path: PausableComponent, storage: pausable, event: PausableEvent); + + // Ownable Mixin + #[abi(embed_v0)] + impl OwnableImpl = OwnableComponent::OwnableImpl; + impl OwnableInternalImpl = OwnableComponent::InternalImpl; + + // Pausable Mixin + #[abi(embed_v0)] + impl PausableImpl = PausableComponent::PausableImpl; + impl PausableInternalImpl = PausableComponent::InternalImpl; + + + #[storage] + struct Storage { + #[substorage(v0)] + ownable: OwnableComponent::Storage, + #[substorage(v0)] + pausable: PausableComponent::Storage, + // ERC20 token addresses + stark_token: ContractAddress, + reward_token: ContractAddress, + duration: u64, + // Reward distribution state + reward_rate: u256, // rewards per second + reward_per_token_stored: u256, // cumulative reward per token + last_update_time: u64, // last time reward_per_token_stored was updated + period_finish: u64, // end time of current reward period + // User state + user_reward_per_token_paid: Map, // reward per token paid to user + rewards: Map, // pending rewards for user + balances: Map, // staked balances + stake_count: u256, + stakes: Map, + // Global state + total_supply: u256 // total staked tokens + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + #[flat] + OwnableEvent: OwnableComponent::Event, + #[flat] + PausableEvent: PausableComponent::Event, + Staked: Staked, + Unstaked: Unstaked, + RewardPaid: RewardPaid, + RewardsFunded: RewardsFunded, + RecoveredTokens: RecoveredTokens, + } + + + #[derive(Drop, starknet::Event)] + struct Staked { + #[key] + user: ContractAddress, + amount: u256, + } + + #[derive(Drop, starknet::Event)] + struct Unstaked { + #[key] + user: ContractAddress, + amount: u256, + } + + #[derive(Drop, starknet::Event)] + struct RewardPaid { + #[key] + user: ContractAddress, + reward: u256, + } + + #[derive(Drop, starknet::Event)] + struct RewardsFunded { + amount: u256, + duration: u64, + } + + #[derive(Drop, starknet::Event)] + struct RecoveredTokens { + token: ContractAddress, + amount: u256, + #[key] + to: ContractAddress, + } + + + #[constructor] + fn constructor( + ref self: ContractState, reward_token: ContractAddress, stark_token: ContractAddress, + ) { + self.stark_token.write(stark_token); + self.reward_token.write(reward_token); + } + + #[abi(embed_v0)] + impl StakingImpl of IStaking { + /// Stake tokens to earn rewards + fn stake(ref self: ContractState, amount: u256, duration: u64) -> u256 { + assert(amount > 0, 'Amount must be > 0'); + + let caller = get_caller_address(); + + let id = self.stake_count.read() + 1; + + // Transfer tokenget_block_timestamps from user to contract + let stark_token = IERC20Dispatcher { contract_address: self.stark_token.read() }; + stark_token.transfer_from(caller, get_contract_address(), amount); + + // Update user balance and total supply + let current_balance = self.balances.read(caller); + self.balances.write(caller, current_balance + amount); + + let stake_details = StakeDetails { id, owner: caller, duration, amount, valid: true }; + + self.stakes.write(id, stake_details); + self.stake_count.write(id); + + self.emit(Staked { user: caller, amount }); + + id + } + + fn get_stake_details(self: @ContractState, id: u256) -> StakeDetails { + let stake = self.stakes.read(id); + stake + } + + /// Unstake tokens + fn unstake(ref self: ContractState, amount: u256) { + self.pausable.assert_not_paused(); + assert(amount > 0, 'Amount must be > 0'); + + let caller = get_caller_address(); + let current_balance = self.balances.read(caller); + assert(current_balance >= amount, 'Insufficient balance'); + + self.update_reward(caller); + + // Update user balance and total supply + self.balances.write(caller, current_balance - amount); + let current_total = self.total_supply.read(); + self.total_supply.write(current_total - amount); + + // Transfer tokens back to user + let stark_token = IERC20Dispatcher { contract_address: self.stark_token.read() }; + stark_token.transfer(caller, amount); + + self.emit(Unstaked { user: caller, amount }); + } + + fn get_strk_address(self: @ContractState) -> ContractAddress { + self.stark_token.read() + } + fn get_reward_address(self: @ContractState) -> ContractAddress { + self.reward_token.read() + } + + + /// Claim accumulated rewards + fn claim_rewards(ref self: ContractState) { + let caller = get_caller_address(); + + let reward = self.rewards.read(caller); + assert(reward > 0, 'No rewards to claim'); + + self.rewards.write(caller, 0); + + // Transfer reward tokens to user + let reward_token = IERC20Dispatcher { contract_address: self.reward_token.read() }; + reward_token.transfer(caller, reward); + + self.emit(RewardPaid { user: caller, reward }); + } + + /// Get earned rewards for an account + fn earned(self: @ContractState, account: ContractAddress) -> u256 { + let balance = self.balances.read(account); + let reward_per_token = self.reward_per_token(); + let user_paid = self.user_reward_per_token_paid.read(account); + let pending = self.rewards.read(account); + + balance * (reward_per_token - user_paid) / 1_000_000_000_000_000_000 + pending + } + + /// Get staked balance for an account + fn balance_of(self: @ContractState, account: ContractAddress) -> u256 { + self.balances.read(account) + } + + /// Get total staked tokens + fn total_supply(self: @ContractState) -> u256 { + self.total_supply.read() + } + + /// Get last time reward was applicable + fn last_time_reward_applicable(self: @ContractState) -> u64 { + let current_time = get_block_timestamp(); + let finish = self.period_finish.read(); + if current_time < finish { + current_time + } else { + finish + } + } + + /// Get current reward per token + fn reward_per_token(self: @ContractState) -> u256 { + let total_supply = self.total_supply.read(); + if total_supply == 0 { + self.reward_per_token_stored.read() + } else { + let last_time = self.last_time_reward_applicable(); + let last_update = self.last_update_time.read(); + let time_diff = last_time - last_update; + let reward_rate = self.reward_rate.read(); + + self.reward_per_token_stored.read() + + (reward_rate * time_diff.into() * 1_000_000_000_000_000_000) / total_supply + } + } + } + + + #[generate_trait] + impl InternalImpl of InternalTrait { + /// Update reward for a specific account + fn update_reward(ref self: ContractState, account: ContractAddress) { + let reward_per_token = self.reward_per_token(); + self.reward_per_token_stored.write(reward_per_token); + self.last_update_time.write(self.last_time_reward_applicable()); + + let zero_address = 0.try_into().unwrap(); + if account != zero_address { + let balance = self.balances.read(account); + let user_paid = self.user_reward_per_token_paid.read(account); + self + .rewards + .write( + account, + balance * (reward_per_token - user_paid) / 1_000_000_000_000_000_000 + + self.rewards.read(account), + ); + self.user_reward_per_token_paid.write(account, reward_per_token); + } + } + + /// Update global reward per token stored + fn update_reward_per_token_stored(ref self: ContractState) { + let reward_per_token = self.reward_per_token(); + self.reward_per_token_stored.write(reward_per_token); + self.last_update_time.write(self.last_time_reward_applicable()); + } + } +} diff --git a/testing/src/interfaces/ICounter.cairo b/testing/src/interfaces/ICounter.cairo new file mode 100644 index 0000000..dcd920e --- /dev/null +++ b/testing/src/interfaces/ICounter.cairo @@ -0,0 +1,6 @@ +#[Starknet::interface] +pub trait ICounter { + fn get_count(self: @TContractState) -> u32; + fn increment(ref self: TContractState); + fn decrement(ref self: TContractState); +} diff --git a/testing/src/interfaces/IHelloStarknet.cairo b/testing/src/interfaces/IHelloStarknet.cairo new file mode 100644 index 0000000..251c45a --- /dev/null +++ b/testing/src/interfaces/IHelloStarknet.cairo @@ -0,0 +1,9 @@ +/// Interface representing `HelloContract`. +/// This interface allows modification and retrieval of the contract balance. +#[Starknet::interface] +pub trait IHelloStarknet { + /// Increase contract balance. + fn increase_balance(ref self: TContractState, amount: felt252); + /// Retrieve contract balance. + fn get_balance(self: @TContractState) -> felt252; +} diff --git a/testing/src/interfaces/IOwnerFunctions.cairo b/testing/src/interfaces/IOwnerFunctions.cairo new file mode 100644 index 0000000..15b167d --- /dev/null +++ b/testing/src/interfaces/IOwnerFunctions.cairo @@ -0,0 +1,9 @@ +use starknet::ContractAddress; + +#[starknet::interface] +trait IOwnerFunctions { + fn fund_rewards(ref self: TContractState, amount: u256, duration: u64); + fn pause(ref self: TContractState); + fn unpause(ref self: TContractState); + fn recover_erc20(ref self: TContractState, token: ContractAddress, amount: u256); +} diff --git a/testing/src/interfaces/IStaking.cairo b/testing/src/interfaces/IStaking.cairo new file mode 100644 index 0000000..3184c78 --- /dev/null +++ b/testing/src/interfaces/IStaking.cairo @@ -0,0 +1,17 @@ +use starknet::ContractAddress; +use test::types::StakeDetails; +// Interfaces +#[starknet::interface] +pub trait IStaking { + fn stake(ref self: TContractState, amount: u256, duration: u64) -> u256; + fn unstake(ref self: TContractState, amount: u256); + fn claim_rewards(ref self: TContractState); + fn earned(self: @TContractState, account: ContractAddress) -> u256; + fn balance_of(self: @TContractState, account: ContractAddress) -> u256; + fn total_supply(self: @TContractState) -> u256; + fn last_time_reward_applicable(self: @TContractState) -> u64; + fn reward_per_token(self: @TContractState) -> u256; + fn get_stake_details(self: @TContractState, id: u256) -> StakeDetails; + fn get_strk_address(self: @TContractState) -> ContractAddress; + fn get_reward_address(self: @TContractState) -> ContractAddress; +} diff --git a/testing/src/lib.cairo b/testing/src/lib.cairo new file mode 100644 index 0000000..cf61f93 --- /dev/null +++ b/testing/src/lib.cairo @@ -0,0 +1,8 @@ +pub mod RewardToken; +pub mod StakingContract; + +pub mod interfaces { + pub mod IStaking; +} + +pub mod types; diff --git a/testing/src/types.cairo b/testing/src/types.cairo new file mode 100644 index 0000000..4ef7be2 --- /dev/null +++ b/testing/src/types.cairo @@ -0,0 +1,10 @@ +use starknet::ContractAddress; + +#[derive(Drop, Serde, starknet::Store)] +pub struct StakeDetails { + pub id: u256, + pub owner: ContractAddress, + pub duration: u64, + pub amount: u256, + pub valid: bool, +} diff --git a/testing/tests/test_contract.cairo b/testing/tests/test_contract.cairo new file mode 100644 index 0000000..484927b --- /dev/null +++ b/testing/tests/test_contract.cairo @@ -0,0 +1,166 @@ +use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; +use snforge_std::{ + ContractClassTrait, DeclareResultTrait, declare, start_cheat_caller_address, + stop_cheat_caller_address, +}; +use starknet::{ContractAddress, contract_address_const}; +use test::RewardToken::{IExternalDispatcher, IExternalDispatcherTrait}; +use test::interfaces::IStaking::{IStakingDispatcher, IStakingDispatcherTrait}; + +fn deploy_contract() -> (IStakingDispatcher, ContractAddress, ContractAddress) { + let contract = declare("Staking").unwrap().contract_class(); + // Define constructor calldata + let (strk_address, reward_address) = deploy_erc20(); + let mut constructor_args = array![reward_address.into(), strk_address.into()]; + + let (contract_address, _) = contract.deploy(@constructor_args).unwrap(); + + (IStakingDispatcher { contract_address }, strk_address, reward_address) +} + +fn deploy_erc20() -> (ContractAddress, ContractAddress) { + let owner: ContractAddress = contract_address_const::<'aji'>(); + let name: ByteArray = "STRK"; + let sym: ByteArray = "Sym"; + let reward: ByteArray = "Reward"; + let reward_sym: ByteArray = "RWD"; + // Deploy mock ERC20 + let erc20_class = declare("RewardERC20").unwrap().contract_class(); + + // Pass ByteArray directly for name and symbol + let mut calldata = ArrayTrait::new(); + owner.serialize(ref calldata); + name.serialize(ref calldata); + sym.serialize(ref calldata); + + let (strk_address, _) = erc20_class.deploy(@calldata).unwrap(); + + let mut usdc_calldata = ArrayTrait::new(); + owner.serialize(ref usdc_calldata); + reward.serialize(ref usdc_calldata); + reward_sym.serialize(ref usdc_calldata); + let (reward_address, _) = erc20_class.deploy(@usdc_calldata).unwrap(); + + (strk_address, reward_address) +} + + +#[test] +fn test_deployment() { + let (dispatcher, strk_address, reward_address) = deploy_contract(); + let s_address = dispatcher.get_strk_address(); + let r_address = dispatcher.get_reward_address(); + + assert(s_address == strk_address, 'invalid strk address'); + assert(r_address == reward_address, 'invalid reward address'); +} + +#[test] +fn test_stake() { + let (dispatcher, strk_address, _) = deploy_contract(); + let caller: ContractAddress = contract_address_const::<'aji'>(); + let stake_amount: u256 = 1000; + let stake_duration: u64 = 60 * 60 * 24 * 7; // 1 week + + // Mint some STRK to caller + let strk_mint = IExternalDispatcher { contract_address: strk_address }; + strk_mint.mint(caller, 10000); + + let strk = IERC20Dispatcher { contract_address: strk_address }; + let initial_balance = strk.balance_of(caller); + + start_cheat_caller_address(strk_address, caller); + // Approve staking contract to spend caller's STRK + strk.approve(dispatcher.contract_address, stake_amount); + let allowance = strk.allowance(caller, dispatcher.contract_address); + stop_cheat_caller_address(strk_address); + + println!("Allowance: {}", allowance); + println!("Initial Balance: {}", initial_balance); + + start_cheat_caller_address(dispatcher.contract_address, caller); + // Stake tokens + let stake_id = dispatcher.stake(stake_amount, stake_duration); + let post_stake_balance = strk.balance_of(caller); + + let p_allowance = strk.allowance(caller, dispatcher.contract_address); + println!("Allowance after stake: {}", p_allowance); + println!("Post stake Balance: {}", post_stake_balance); + + assert(post_stake_balance == initial_balance - stake_amount, 'stake failed'); + let contract_balance = strk.balance_of(dispatcher.contract_address); + assert(contract_balance == stake_amount, 'contract balance incorrect'); + + let staked_balance = dispatcher.balance_of(caller); + assert(staked_balance == stake_amount, 'staked balance incorrect'); + + // Get stake details + let stake_details = dispatcher.get_stake_details(stake_id); + assert(stake_details.owner == caller, 'stake owner incorrect'); + assert(stake_details.amount == stake_amount, 'stake amount incorrect'); + assert(stake_details.duration == stake_duration, 'stake duration incorrect'); + assert(stake_details.valid, 'stake valid incorrect'); +} + +#[test] +fn test_Unstake() { + let (dispatcher, strk_address, _) = deploy_contract(); + let caller: ContractAddress = contract_address_const::<'aji'>(); + let stake_amount: u256 = 1000; + let unstake_amount: u256 = 500; + let stake_duration: u64 = 60 * 60 * 24 * 7; // 1 week + + // Mint some STRK to caller + let strk_mint = IExternalDispatcher { contract_address: strk_address }; + strk_mint.mint(caller, 10000); + + let strk = IERC20Dispatcher { contract_address: strk_address }; + let initial_balance = strk.balance_of(caller); + + start_cheat_caller_address(strk_address, caller); + // Approve staking contract to spend caller's STRK + strk.approve(dispatcher.contract_address, stake_amount); + stop_cheat_caller_address(strk_address); + + start_cheat_caller_address(dispatcher.contract_address, caller); + // Stake tokens + let _ = dispatcher.stake(stake_amount, stake_duration); + let post_stake_balance = strk.balance_of(caller); + assert(post_stake_balance == initial_balance - stake_amount, 'stake failed'); + + // Now unstake + dispatcher.unstake(unstake_amount); + let post_unstake_balance = strk.balance_of(caller); + stop_cheat_caller_address(dispatcher.contract_address); + + // Check user balance increased by unstake_amount + assert(post_unstake_balance == post_stake_balance + unstake_amount, 'unstake failed'); + + // Check contract balance decreased + let contract_balance = strk.balance_of(dispatcher.contract_address); + assert(contract_balance == stake_amount - unstake_amount, 'contract balance incorrect'); + + // Check staked balance decreased + let staked_balance = dispatcher.balance_of(caller); + assert(staked_balance == stake_amount - unstake_amount, 'staked balance incorrect'); +} + +#[test] +fn test_last_time_reward_applicable() { + let (dispatcher, _, _) = deploy_contract(); + + let ltra = dispatcher.last_time_reward_applicable(); + // Since period_finish is 0, and current_time > 0, should return 0 + assert(ltra == 0, 'ltra incorrect'); +} + +#[test] +#[should_panic(expected: ('No rewards to claim',))] +fn test_claim_rewards_no_rewards() { + let (dispatcher, _, _) = deploy_contract(); + let caller: ContractAddress = contract_address_const::<'aji'>(); + + start_cheat_caller_address(dispatcher.contract_address, caller); + dispatcher.claim_rewards(); + stop_cheat_caller_address(dispatcher.contract_address); +} \ No newline at end of file