From 6bf84c35f7c86b03d1e730336c00c1fd98a19f26 Mon Sep 17 00:00:00 2001 From: wheval Date: Fri, 25 Jul 2025 16:12:27 +0100 Subject: [PATCH 1/3] add oz reentrancy component and implement tests --- src/payment_stream.cairo | 29 +- tests/test_reentrancy_protection.cairo | 484 +++++++++++++++++++++++++ 2 files changed, 510 insertions(+), 3 deletions(-) create mode 100644 tests/test_reentrancy_protection.cairo diff --git a/src/payment_stream.cairo b/src/payment_stream.cairo index e968654..acf590e 100644 --- a/src/payment_stream.cairo +++ b/src/payment_stream.cairo @@ -5,6 +5,7 @@ pub mod PaymentStream { use fundable::interfaces::IPaymentStream::IPaymentStream; use openzeppelin::access::accesscontrol::AccessControlComponent; use openzeppelin::introspection::src5::SRC5Component; + use openzeppelin::security::reentrancyguard::ReentrancyGuardComponent; use openzeppelin::token::erc20::interface::{ IERC20Dispatcher, IERC20DispatcherTrait, IERC20MetadataDispatcher, IERC20MetadataDispatcherTrait, @@ -32,6 +33,7 @@ pub mod PaymentStream { component!(path: SRC5Component, storage: src5, event: Src5Event); component!(path: ERC721Component, storage: erc721, event: ERC721Event); component!(path: UpgradeableComponent, storage: upgradeable, event: UpgradeableEvent); + component!(path: ReentrancyGuardComponent, storage: reentrancy_guard, event: ReentrancyGuardEvent); #[abi(embed_v0)] impl AccessControlImpl = @@ -42,6 +44,7 @@ pub mod PaymentStream { impl ERC721MixinImpl = ERC721Component::ERC721MixinImpl; impl ERC721InternalImpl = ERC721Component::InternalImpl; impl UpgradeableInternalImpl = UpgradeableComponent::InternalImpl; + impl ReentrancyGuardInternalImpl = ReentrancyGuardComponent::InternalImpl; const PROTOCOL_OWNER_ROLE: felt252 = selector!("PROTOCOL_OWNER"); // Note: STREAM_ADMIN_ROLE removed - using stream-specific access control @@ -61,6 +64,8 @@ pub mod PaymentStream { src5: SRC5Component::Storage, #[substorage(v0)] accesscontrol: AccessControlComponent::Storage, + #[substorage(v0)] + reentrancy_guard: ReentrancyGuardComponent::Storage, next_stream_id: u256, streams: Map, protocol_fee_rate: Map, // Single source of truth for fee rates @@ -91,6 +96,8 @@ pub mod PaymentStream { AccessControlEvent: AccessControlComponent::Event, #[flat] UpgradeableEvent: UpgradeableComponent::Event, + #[flat] + ReentrancyGuardEvent: ReentrancyGuardComponent::Event, StreamCreated: StreamCreated, StreamWithdrawn: StreamWithdrawn, StreamCanceled: StreamCanceled, @@ -392,10 +399,12 @@ pub mod PaymentStream { } } - fn collect_protocol_fee(self: @ContractState, token: ContractAddress, amount: u256) { + fn collect_protocol_fee(ref self: ContractState, token: ContractAddress, amount: u256) { + self.reentrancy_guard.start(); let fee_collector: ContractAddress = self.fee_collector.read(); assert(fee_collector.is_non_zero(), INVALID_RECIPIENT); IERC20Dispatcher { contract_address: token }.transfer(fee_collector, amount); + self.reentrancy_guard.end(); } // Updated to check NFT ownership or delegate @@ -797,19 +806,27 @@ pub mod PaymentStream { fn withdraw( ref self: ContractState, stream_id: u256, amount: u256, to: ContractAddress, ) -> (u128, u128) { - self._withdraw(stream_id, amount, to) + self.reentrancy_guard.start(); + let result = self._withdraw(stream_id, amount, to); + self.reentrancy_guard.end(); + result } fn withdraw_max( ref self: ContractState, stream_id: u256, to: ContractAddress, ) -> (u128, u128) { + self.reentrancy_guard.start(); let withdrawable_amount = self._withdrawable_amount(stream_id); - self._withdraw(stream_id, withdrawable_amount, to) + let result = self._withdraw(stream_id, withdrawable_amount, to); + self.reentrancy_guard.end(); + result } fn transfer_stream( ref self: ContractState, stream_id: u256, new_recipient: ContractAddress, ) { + self.reentrancy_guard.start(); + // Verify stream exists self.assert_stream_exists(stream_id); @@ -837,6 +854,8 @@ pub mod PaymentStream { // Emit event about stream transfer self.emit(StreamTransferred { stream_id, new_recipient }); + + self.reentrancy_guard.end(); } fn set_transferability(ref self: ContractState, stream_id: u256, transferable: bool) { @@ -940,6 +959,8 @@ pub mod PaymentStream { } fn cancel(ref self: ContractState, stream_id: u256) { + self.reentrancy_guard.start(); + // Ensure the caller is the stream sender self.assert_stream_sender_access(stream_id); @@ -1064,6 +1085,8 @@ pub mod PaymentStream { // Emit cancellation event self.emit(StreamCanceled { stream_id }); + + self.reentrancy_guard.end(); } fn restart(ref self: ContractState, stream_id: u256) { diff --git a/tests/test_reentrancy_protection.cairo b/tests/test_reentrancy_protection.cairo new file mode 100644 index 0000000..fa9f403 --- /dev/null +++ b/tests/test_reentrancy_protection.cairo @@ -0,0 +1,484 @@ +use core::num::traits::Zero; +use fundable::payment_stream::PaymentStream; +use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; +use snforge_std::{ + declare, DeclareResultTrait, ContractClassTrait, start_cheat_caller_address, stop_cheat_caller_address, + start_cheat_block_timestamp, spy_events, EventSpyAssertionsTrait +}; +use starknet::{ContractAddress, get_caller_address, get_block_timestamp, get_contract_address}; +use fundable::base::types::{Stream, StreamStatus}; +use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; +use core::serde::Serde; + +// ============================================================================ +// MALICIOUS CONTRACTS FOR REENTRANCY TESTING +// ============================================================================ + +/// Malicious ERC20 token that attempts reentrancy during transfers +#[starknet::interface] +pub trait IMaliciousERC20 { + fn balance_of(self: @TContractState, account: ContractAddress) -> u256; + fn transfer(ref self: TContractState, recipient: ContractAddress, amount: u256) -> bool; + fn approve(ref self: TContractState, spender: ContractAddress, amount: u256) -> bool; + fn allowance(self: @TContractState, owner: ContractAddress, spender: ContractAddress) -> u256; + fn transfer_from(ref self: TContractState, sender: ContractAddress, recipient: ContractAddress, amount: u256) -> bool; + fn total_supply(self: @TContractState) -> u256; + fn name(self: @TContractState) -> felt252; + fn symbol(self: @TContractState) -> felt252; + fn decimals(self: @TContractState) -> u8; + fn set_target_contract(ref self: TContractState, target: ContractAddress); + fn set_attack_mode(ref self: TContractState, mode: u8); + fn mint(ref self: TContractState, recipient: ContractAddress, amount: u256); +} + +#[starknet::contract] +pub mod MaliciousERC20 { + use starknet::ContractAddress; + use starknet::storage::{Map, StoragePointerReadAccess, StoragePointerWriteAccess}; + use starknet::get_caller_address; + use core::num::traits::Zero; + use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; + + #[storage] + struct Storage { + balances: Map, + allowances: Map<(ContractAddress, ContractAddress), u256>, + total_supply: u256, + target_contract: ContractAddress, + attack_mode: u8, // 0: normal, 1: withdraw attack, 2: cancel attack, 3: transfer_stream attack + attack_count: u8, + } + + #[abi(embed_v0)] + impl MaliciousERC20Impl of super::IMaliciousERC20 { + fn balance_of(self: @ContractState, account: ContractAddress) -> u256 { + self.balances.read(account) + } + + fn transfer(ref self: ContractState, recipient: ContractAddress, amount: u256) -> bool { + let caller = get_caller_address(); + let caller_balance = self.balances.read(caller); + assert(caller_balance >= amount, 'Insufficient balance'); + + self.balances.write(caller, caller_balance - amount); + self.balances.write(recipient, self.balances.read(recipient) + amount); + + // REENTRANCY ATTACK: Attempt to call target contract during transfer + self._attempt_reentrancy_attack(); + + true + } + + fn approve(ref self: ContractState, spender: ContractAddress, amount: u256) -> bool { + let caller = get_caller_address(); + self.allowances.write((caller, spender), amount); + true + } + + fn allowance(self: @ContractState, owner: ContractAddress, spender: ContractAddress) -> u256 { + self.allowances.read((owner, spender)) + } + + fn transfer_from(ref self: ContractState, sender: ContractAddress, recipient: ContractAddress, amount: u256) -> bool { + let caller = get_caller_address(); + let allowed = self.allowances.read((sender, caller)); + assert(allowed >= amount, 'Insufficient allowance'); + + let sender_balance = self.balances.read(sender); + assert(sender_balance >= amount, 'Insufficient balance'); + + self.balances.write(sender, sender_balance - amount); + self.balances.write(recipient, self.balances.read(recipient) + amount); + self.allowances.write((sender, caller), allowed - amount); + + // REENTRANCY ATTACK: Attempt to call target contract during transfer + self._attempt_reentrancy_attack(); + + true + } + + fn total_supply(self: @ContractState) -> u256 { + self.total_supply.read() + } + + fn name(self: @ContractState) -> felt252 { + 'MaliciousToken' + } + + fn symbol(self: @ContractState) -> felt252 { + 'MAL' + } + + fn decimals(self: @ContractState) -> u8 { + 18 + } + + fn set_target_contract(ref self: ContractState, target: ContractAddress) { + self.target_contract.write(target); + } + + fn set_attack_mode(ref self: ContractState, mode: u8) { + self.attack_mode.write(mode); + self.attack_count.write(0); + } + + fn mint(ref self: ContractState, recipient: ContractAddress, amount: u256) { + self.balances.write(recipient, self.balances.read(recipient) + amount); + self.total_supply.write(self.total_supply.read() + amount); + } + } + + #[generate_trait] + impl InternalImpl of InternalTrait { + fn _attempt_reentrancy_attack(ref self: ContractState) { + let target = self.target_contract.read(); + if target.is_zero() { + return; + } + + let attack_mode = self.attack_mode.read(); + let attack_count = self.attack_count.read(); + + // Prevent infinite recursion by limiting attacks + if attack_count >= 3 { + return; + } + self.attack_count.write(attack_count + 1); + + let payment_stream = IPaymentStreamDispatcher { contract_address: target }; + + if attack_mode == 1 { + // Attack withdraw function + // Try to call withdraw again with fake parameters + // This should fail due to reentrancy protection + // Using dummy values - in a real attack these would be malicious + let dummy_recipient: ContractAddress = 0x123.try_into().unwrap(); + payment_stream.withdraw(1, 100, dummy_recipient); + } else if attack_mode == 2 { + // Attack cancel function + payment_stream.cancel(1); + } else if attack_mode == 3 { + // Attack transfer_stream function + let dummy_recipient: ContractAddress = 0x456.try_into().unwrap(); + payment_stream.transfer_stream(1, dummy_recipient); + } + } + } +} + +/// Malicious recipient contract that attempts reentrancy +#[starknet::interface] +pub trait IMaliciousRecipient { + fn set_target_contract(ref self: TContractState, target: ContractAddress); + fn set_attack_mode(ref self: TContractState, mode: u8); + fn perform_attack(ref self: TContractState, stream_id: u256); +} + +#[starknet::contract] +pub mod MaliciousRecipient { + use starknet::ContractAddress; + use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; + use core::num::traits::Zero; + use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; + + #[storage] + struct Storage { + target_contract: ContractAddress, + attack_mode: u8, + attack_count: u8, + } + + #[abi(embed_v0)] + impl MaliciousRecipientImpl of super::IMaliciousRecipient { + fn set_target_contract(ref self: ContractState, target: ContractAddress) { + self.target_contract.write(target); + } + + fn set_attack_mode(ref self: ContractState, mode: u8) { + self.attack_mode.write(mode); + self.attack_count.write(0); + } + + fn perform_attack(ref self: ContractState, stream_id: u256) { + let target = self.target_contract.read(); + if target.is_zero() { + return; + } + + let attack_count = self.attack_count.read(); + if attack_count >= 2 { + return; + } + self.attack_count.write(attack_count + 1); + + let payment_stream = IPaymentStreamDispatcher { contract_address: target }; + let attack_mode = self.attack_mode.read(); + + if attack_mode == 1 { + // Cross-function reentrancy: withdraw -> cancel + payment_stream.cancel(stream_id); + } else if attack_mode == 2 { + // Cross-function reentrancy: withdraw -> transfer_stream + let dummy_recipient: ContractAddress = 0x789.try_into().unwrap(); + payment_stream.transfer_stream(stream_id, dummy_recipient); + } + } + } +} + +// ============================================================================ +// REENTRANCY ATTACK TESTS +// ============================================================================ + +fn setup_contracts() -> (ContractAddress, ContractAddress, ContractAddress, ContractAddress, ContractAddress) { + let protocol_owner: ContractAddress = 0x123.try_into().unwrap(); + let fee_collector: ContractAddress = 0x456.try_into().unwrap(); + let sender: ContractAddress = 0x789.try_into().unwrap(); + + // Deploy PaymentStream contract with constructor arguments + let payment_stream_class = declare("PaymentStream").unwrap(); + let mut payment_stream_constructor_calldata = array![]; + protocol_owner.serialize(ref payment_stream_constructor_calldata); + 500_u64.serialize(ref payment_stream_constructor_calldata); // 5% fee + fee_collector.serialize(ref payment_stream_constructor_calldata); + let (payment_stream_address, _) = payment_stream_class + .contract_class() + .deploy(@payment_stream_constructor_calldata) + .unwrap(); + + // Deploy MaliciousERC20 token + let malicious_token_class = declare("MaliciousERC20").unwrap(); + let mut malicious_token_constructor_calldata = array![]; + let (malicious_token_address, _) = malicious_token_class + .contract_class() + .deploy(@malicious_token_constructor_calldata) + .unwrap(); + + // Deploy MaliciousRecipient contract + let malicious_recipient_class = declare("MaliciousRecipient").unwrap(); + let mut malicious_recipient_constructor_calldata = array![]; + let (malicious_recipient_address, _) = malicious_recipient_class + .contract_class() + .deploy(@malicious_recipient_constructor_calldata) + .unwrap(); + + (payment_stream_address, malicious_token_address, malicious_recipient_address, sender, protocol_owner) +} + +#[test] +fn test_direct_reentrancy_attack_on_withdraw() { + let (payment_stream_address, malicious_token_address, _, sender, _) = setup_contracts(); + + let payment_stream = IPaymentStreamDispatcher { contract_address: payment_stream_address }; + let malicious_token = IMaliciousERC20Dispatcher { contract_address: malicious_token_address }; + + // Set up malicious token for attack + malicious_token.set_target_contract(payment_stream_address); + malicious_token.set_attack_mode(1); // withdraw attack mode + + // Mint tokens to sender + malicious_token.mint(sender, 1000); + + // Create stream + start_cheat_caller_address(payment_stream_address, sender); + start_cheat_caller_address(malicious_token_address, sender); + + // Approve payment stream to spend tokens + malicious_token.approve(payment_stream_address, 1000); + + let recipient: ContractAddress = 0xABC.try_into().unwrap(); + + let stream_id = payment_stream.create_stream( + recipient, 1000, 3600, true, malicious_token_address, true + ); + + stop_cheat_caller_address(payment_stream_address); + stop_cheat_caller_address(malicious_token_address); + + // Fast forward time to allow withdrawal + start_cheat_block_timestamp(payment_stream_address, get_block_timestamp() + 1800); // 30 minutes + + // Attempt withdrawal as recipient (this should be protected against reentrancy) + start_cheat_caller_address(payment_stream_address, recipient); + + // The reentrancy attack should fail, but the legitimate withdrawal should succeed + let (withdrawn, fee) = payment_stream.withdraw(stream_id, 500, recipient); + + // Verify the withdrawal succeeded normally despite the reentrancy attempt + assert(withdrawn > 0, 'Withdrawal should succeed'); + assert(fee > 0, 'Fee should be collected'); + + stop_cheat_caller_address(payment_stream_address); +} + +#[test] +fn test_cross_function_reentrancy_attack() { + let (payment_stream_address, malicious_token_address, malicious_recipient_address, sender, _) = setup_contracts(); + + let payment_stream = IPaymentStreamDispatcher { contract_address: payment_stream_address }; + let malicious_token = IMaliciousERC20Dispatcher { contract_address: malicious_token_address }; + let malicious_recipient = IMaliciousRecipientDispatcher { contract_address: malicious_recipient_address }; + + // Set up malicious contracts for cross-function attack + malicious_recipient.set_target_contract(payment_stream_address); + malicious_recipient.set_attack_mode(1); // withdraw -> cancel attack + + // Mint tokens to sender + malicious_token.mint(sender, 1000); + + // Create stream with malicious recipient + start_cheat_caller_address(payment_stream_address, sender); + start_cheat_caller_address(malicious_token_address, sender); + + malicious_token.approve(payment_stream_address, 1000); + + let stream_id = payment_stream.create_stream( + malicious_recipient_address, 1000, 3600, true, malicious_token_address, true + ); + + stop_cheat_caller_address(payment_stream_address); + stop_cheat_caller_address(malicious_token_address); + + // Fast forward time + start_cheat_block_timestamp(payment_stream_address, get_block_timestamp() + 1800); + + // Attempt cross-function reentrancy attack + start_cheat_caller_address(payment_stream_address, malicious_recipient_address); + + // This should be protected against reentrancy + malicious_recipient.perform_attack(stream_id); + + // Verify stream state is still consistent + let stream = payment_stream.get_stream(stream_id); + assert(stream.status == StreamStatus::Active, 'Stream should still be active'); + + stop_cheat_caller_address(payment_stream_address); +} + +#[test] +fn test_reentrancy_protection_on_cancel() { + let (payment_stream_address, malicious_token_address, _, sender, _) = setup_contracts(); + + let payment_stream = IPaymentStreamDispatcher { contract_address: payment_stream_address }; + let malicious_token = IMaliciousERC20Dispatcher { contract_address: malicious_token_address }; + + // Set up malicious token for cancel attack + malicious_token.set_target_contract(payment_stream_address); + malicious_token.set_attack_mode(2); // cancel attack mode + + // Mint tokens and create stream + malicious_token.mint(sender, 1000); + + start_cheat_caller_address(payment_stream_address, sender); + start_cheat_caller_address(malicious_token_address, sender); + + malicious_token.approve(payment_stream_address, 1000); + + let recipient: ContractAddress = 0xDEF.try_into().unwrap(); + + let stream_id = payment_stream.create_stream( + recipient, 1000, 3600, true, malicious_token_address, true + ); + + // Attempt to cancel (this should trigger reentrancy attack in the malicious token) + // The reentrancy protection should prevent the nested call from succeeding + payment_stream.cancel(stream_id); + + // Verify the stream was cancelled despite the reentrancy attempt + let stream = payment_stream.get_stream(stream_id); + assert(stream.status == StreamStatus::Canceled, 'Stream should be canceled'); + + stop_cheat_caller_address(payment_stream_address); + stop_cheat_caller_address(malicious_token_address); +} + +#[test] +fn test_reentrancy_protection_on_transfer_stream() { + let (payment_stream_address, malicious_token_address, _, sender, _) = setup_contracts(); + + let payment_stream = IPaymentStreamDispatcher { contract_address: payment_stream_address }; + let malicious_token = IMaliciousERC20Dispatcher { contract_address: malicious_token_address }; + + // Set up malicious token for transfer_stream attack + malicious_token.set_target_contract(payment_stream_address); + malicious_token.set_attack_mode(3); // transfer_stream attack mode + + // Mint tokens and create stream + malicious_token.mint(sender, 1000); + + start_cheat_caller_address(payment_stream_address, sender); + start_cheat_caller_address(malicious_token_address, sender); + + malicious_token.approve(payment_stream_address, 1000); + + let recipient: ContractAddress = 0x111.try_into().unwrap(); + + let stream_id = payment_stream.create_stream( + recipient, 1000, 3600, true, malicious_token_address, true + ); + + stop_cheat_caller_address(payment_stream_address); + stop_cheat_caller_address(malicious_token_address); + + // Attempt to transfer stream as recipient + start_cheat_caller_address(payment_stream_address, recipient); + + let new_recipient: ContractAddress = 0x222.try_into().unwrap(); + + // This should trigger reentrancy attack but be protected + payment_stream.transfer_stream(stream_id, new_recipient); + + // Verify the transfer succeeded despite reentrancy attempt + let stream = payment_stream.get_stream(stream_id); + assert(stream.recipient == new_recipient, 'Stream should be transferred'); + + stop_cheat_caller_address(payment_stream_address); +} + +#[test] +fn test_multiple_function_reentrancy_protection() { + let (payment_stream_address, malicious_token_address, _, sender, _) = setup_contracts(); + + let payment_stream = IPaymentStreamDispatcher { contract_address: payment_stream_address }; + let malicious_token = IMaliciousERC20Dispatcher { contract_address: malicious_token_address }; + + // Test that all protected functions are properly guarded + malicious_token.set_target_contract(payment_stream_address); + malicious_token.mint(sender, 2000); + + start_cheat_caller_address(payment_stream_address, sender); + start_cheat_caller_address(malicious_token_address, sender); + + malicious_token.approve(payment_stream_address, 2000); + + let recipient: ContractAddress = 0x333.try_into().unwrap(); + + // Create multiple streams to test different functions + let stream_id_1 = payment_stream.create_stream( + recipient, 500, 3600, true, malicious_token_address, true + ); + let stream_id_2 = payment_stream.create_stream( + recipient, 500, 3600, true, malicious_token_address, true + ); + + stop_cheat_caller_address(payment_stream_address); + stop_cheat_caller_address(malicious_token_address); + + // Fast forward time + start_cheat_block_timestamp(payment_stream_address, get_block_timestamp() + 1800); + + // Test withdraw protection + start_cheat_caller_address(payment_stream_address, recipient); + malicious_token.set_attack_mode(1); + let (withdrawn, _) = payment_stream.withdraw(stream_id_1, 100, recipient); + assert(withdrawn > 0, 'Withdraw works despite attack'); + stop_cheat_caller_address(payment_stream_address); + + // Test cancel protection + start_cheat_caller_address(payment_stream_address, sender); + malicious_token.set_attack_mode(2); + payment_stream.cancel(stream_id_2); + let stream = payment_stream.get_stream(stream_id_2); + assert(stream.status == StreamStatus::Canceled, 'Cancel works despite attack'); + stop_cheat_caller_address(payment_stream_address); +} \ No newline at end of file From ef9ca596448c3091783b03f22ddc2a2f3d4db132 Mon Sep 17 00:00:00 2001 From: wheval Date: Mon, 28 Jul 2025 12:50:59 +0100 Subject: [PATCH 2/3] fix reentrancy tests --- src/payment_stream.cairo | 25 +- tests/test_reentrancy_protection.cairo | 758 +++++++++++++++---------- 2 files changed, 465 insertions(+), 318 deletions(-) diff --git a/src/payment_stream.cairo b/src/payment_stream.cairo index acf590e..c372a5a 100644 --- a/src/payment_stream.cairo +++ b/src/payment_stream.cairo @@ -33,7 +33,9 @@ pub mod PaymentStream { component!(path: SRC5Component, storage: src5, event: Src5Event); component!(path: ERC721Component, storage: erc721, event: ERC721Event); component!(path: UpgradeableComponent, storage: upgradeable, event: UpgradeableEvent); - component!(path: ReentrancyGuardComponent, storage: reentrancy_guard, event: ReentrancyGuardEvent); + component!( + path: ReentrancyGuardComponent, storage: reentrancy_guard, event: ReentrancyGuardEvent, + ); #[abi(embed_v0)] impl AccessControlImpl = @@ -401,10 +403,17 @@ pub mod PaymentStream { fn collect_protocol_fee(ref self: ContractState, token: ContractAddress, amount: u256) { self.reentrancy_guard.start(); + self._collect_protocol_fee_internal(token, amount); + self.reentrancy_guard.end(); + } + + /// @notice Internal function to collect protocol fees (without reentrancy protection) + /// @param token The token address to collect fees in + /// @param amount The fee amount to collect + fn _collect_protocol_fee_internal(ref self: ContractState, token: ContractAddress, amount: u256) { let fee_collector: ContractAddress = self.fee_collector.read(); assert(fee_collector.is_non_zero(), INVALID_RECIPIENT); IERC20Dispatcher { contract_address: token }.transfer(fee_collector, amount); - self.reentrancy_guard.end(); } // Updated to check NFT ownership or delegate @@ -718,7 +727,7 @@ pub mod PaymentStream { let token_dispatcher = IERC20Dispatcher { contract_address: token_address }; - self.collect_protocol_fee(token_address, fee); + self._collect_protocol_fee_internal(token_address, fee); token_dispatcher.transfer(to, net_amount); self @@ -826,7 +835,7 @@ pub mod PaymentStream { ref self: ContractState, stream_id: u256, new_recipient: ContractAddress, ) { self.reentrancy_guard.start(); - + // Verify stream exists self.assert_stream_exists(stream_id); @@ -854,7 +863,7 @@ pub mod PaymentStream { // Emit event about stream transfer self.emit(StreamTransferred { stream_id, new_recipient }); - + self.reentrancy_guard.end(); } @@ -960,7 +969,7 @@ pub mod PaymentStream { fn cancel(ref self: ContractState, stream_id: u256) { self.reentrancy_guard.start(); - + // Ensure the caller is the stream sender self.assert_stream_sender_access(stream_id); @@ -1057,7 +1066,7 @@ pub mod PaymentStream { let net_amount = amount_due_to_recipient - fee; // Transfer fee to collector and net amount to recipient - self.collect_protocol_fee(token_address, fee); + self._collect_protocol_fee_internal(token_address, fee); token_dispatcher.transfer(recipient, net_amount); // Emit withdrawal event @@ -1085,7 +1094,7 @@ pub mod PaymentStream { // Emit cancellation event self.emit(StreamCanceled { stream_id }); - + self.reentrancy_guard.end(); } diff --git a/tests/test_reentrancy_protection.cairo b/tests/test_reentrancy_protection.cairo index fa9f403..f18cafa 100644 --- a/tests/test_reentrancy_protection.cairo +++ b/tests/test_reentrancy_protection.cairo @@ -1,72 +1,68 @@ -use core::num::traits::Zero; +use starknet::ContractAddress; +use starknet::storage::*; +use starknet::get_caller_address; use fundable::payment_stream::PaymentStream; +use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; +use fundable::base::types::{Stream, StreamStatus}; use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; use snforge_std::{ - declare, DeclareResultTrait, ContractClassTrait, start_cheat_caller_address, stop_cheat_caller_address, - start_cheat_block_timestamp, spy_events, EventSpyAssertionsTrait + declare, DeclareResultTrait, ContractClassTrait, start_cheat_caller_address, + stop_cheat_caller_address, start_cheat_block_timestamp, spy_events, EventSpyAssertionsTrait }; -use starknet::{ContractAddress, get_caller_address, get_block_timestamp, get_contract_address}; -use fundable::base::types::{Stream, StreamStatus}; -use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; -use core::serde::Serde; - -// ============================================================================ -// MALICIOUS CONTRACTS FOR REENTRANCY TESTING -// ============================================================================ -/// Malicious ERC20 token that attempts reentrancy during transfers #[starknet::interface] pub trait IMaliciousERC20 { - fn balance_of(self: @TContractState, account: ContractAddress) -> u256; - fn transfer(ref self: TContractState, recipient: ContractAddress, amount: u256) -> bool; + fn mint(ref self: TContractState, to: ContractAddress, amount: u256); fn approve(ref self: TContractState, spender: ContractAddress, amount: u256) -> bool; + fn transfer(ref self: TContractState, to: ContractAddress, amount: u256) -> bool; + fn transfer_from(ref self: TContractState, from: ContractAddress, to: ContractAddress, amount: u256) -> bool; + fn balance_of(self: @TContractState, account: ContractAddress) -> u256; fn allowance(self: @TContractState, owner: ContractAddress, spender: ContractAddress) -> u256; - fn transfer_from(ref self: TContractState, sender: ContractAddress, recipient: ContractAddress, amount: u256) -> bool; fn total_supply(self: @TContractState) -> u256; - fn name(self: @TContractState) -> felt252; - fn symbol(self: @TContractState) -> felt252; + fn name(self: @TContractState) -> ByteArray; + fn symbol(self: @TContractState) -> ByteArray; fn decimals(self: @TContractState) -> u8; - fn set_target_contract(ref self: TContractState, target: ContractAddress); fn set_attack_mode(ref self: TContractState, mode: u8); - fn mint(ref self: TContractState, recipient: ContractAddress, amount: u256); + fn set_stream_id(ref self: TContractState, stream_id: u256); + fn set_target(ref self: TContractState, target: ContractAddress); } #[starknet::contract] pub mod MaliciousERC20 { use starknet::ContractAddress; - use starknet::storage::{Map, StoragePointerReadAccess, StoragePointerWriteAccess}; + use starknet::storage::*; use starknet::get_caller_address; - use core::num::traits::Zero; use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; #[storage] - struct Storage { + pub struct Storage { balances: Map, allowances: Map<(ContractAddress, ContractAddress), u256>, total_supply: u256, + name: ByteArray, + symbol: ByteArray, + decimals: u8, + attack_mode: u8, // 0: no attack, 1: withdraw attack, 2: cancel attack, 3: transfer attack + stream_id: u256, target_contract: ContractAddress, - attack_mode: u8, // 0: normal, 1: withdraw attack, 2: cancel attack, 3: transfer_stream attack - attack_count: u8, + attack_count: u32, + } + + #[constructor] + fn constructor(ref self: ContractState, name: ByteArray, symbol: ByteArray, decimals: u8) { + self.name.write(name); + self.symbol.write(symbol); + self.decimals.write(decimals); + self.attack_mode.write(0); } #[abi(embed_v0)] impl MaliciousERC20Impl of super::IMaliciousERC20 { - fn balance_of(self: @ContractState, account: ContractAddress) -> u256 { - self.balances.read(account) - } - - fn transfer(ref self: ContractState, recipient: ContractAddress, amount: u256) -> bool { - let caller = get_caller_address(); - let caller_balance = self.balances.read(caller); - assert(caller_balance >= amount, 'Insufficient balance'); - - self.balances.write(caller, caller_balance - amount); - self.balances.write(recipient, self.balances.read(recipient) + amount); - - // REENTRANCY ATTACK: Attempt to call target contract during transfer - self._attempt_reentrancy_attack(); - - true + fn mint(ref self: ContractState, to: ContractAddress, amount: u256) { + let current_balance = self.balances.read(to); + self.balances.write(to, current_balance + amount); + let current_supply = self.total_supply.read(); + self.total_supply.write(current_supply + amount); } fn approve(ref self: ContractState, spender: ContractAddress, amount: u256) -> bool { @@ -75,410 +71,552 @@ pub mod MaliciousERC20 { true } - fn allowance(self: @ContractState, owner: ContractAddress, spender: ContractAddress) -> u256 { - self.allowances.read((owner, spender)) + fn transfer(ref self: ContractState, to: ContractAddress, amount: u256) -> bool { + let caller = get_caller_address(); + let from_balance = self.balances.read(caller); + assert(from_balance >= amount, 'Insufficient balance'); + + self.balances.write(caller, from_balance - amount); + let to_balance = self.balances.read(to); + self.balances.write(to, to_balance + amount); + + // Attempt reentrancy attack during transfer + self._attempt_reentrancy_attack(); + + true } - fn transfer_from(ref self: ContractState, sender: ContractAddress, recipient: ContractAddress, amount: u256) -> bool { + fn transfer_from(ref self: ContractState, from: ContractAddress, to: ContractAddress, amount: u256) -> bool { let caller = get_caller_address(); - let allowed = self.allowances.read((sender, caller)); - assert(allowed >= amount, 'Insufficient allowance'); + let allowance = self.allowances.read((from, caller)); + assert(allowance >= amount, 'Insufficient allowance'); - let sender_balance = self.balances.read(sender); - assert(sender_balance >= amount, 'Insufficient balance'); + let from_balance = self.balances.read(from); + assert(from_balance >= amount, 'Insufficient balance'); - self.balances.write(sender, sender_balance - amount); - self.balances.write(recipient, self.balances.read(recipient) + amount); - self.allowances.write((sender, caller), allowed - amount); - - // REENTRANCY ATTACK: Attempt to call target contract during transfer + self.allowances.write((from, caller), allowance - amount); + self.balances.write(from, from_balance - amount); + let to_balance = self.balances.read(to); + self.balances.write(to, to_balance + amount); + + // Attempt reentrancy attack during transfer_from self._attempt_reentrancy_attack(); true } + fn balance_of(self: @ContractState, account: ContractAddress) -> u256 { + self.balances.read(account) + } + + fn allowance(self: @ContractState, owner: ContractAddress, spender: ContractAddress) -> u256 { + self.allowances.read((owner, spender)) + } + fn total_supply(self: @ContractState) -> u256 { self.total_supply.read() } - fn name(self: @ContractState) -> felt252 { - 'MaliciousToken' + fn name(self: @ContractState) -> ByteArray { + self.name.read() } - fn symbol(self: @ContractState) -> felt252 { - 'MAL' + fn symbol(self: @ContractState) -> ByteArray { + self.symbol.read() } fn decimals(self: @ContractState) -> u8 { - 18 - } - - fn set_target_contract(ref self: ContractState, target: ContractAddress) { - self.target_contract.write(target); + self.decimals.read() } fn set_attack_mode(ref self: ContractState, mode: u8) { self.attack_mode.write(mode); - self.attack_count.write(0); } - fn mint(ref self: ContractState, recipient: ContractAddress, amount: u256) { - self.balances.write(recipient, self.balances.read(recipient) + amount); - self.total_supply.write(self.total_supply.read() + amount); + fn set_stream_id(ref self: ContractState, stream_id: u256) { + self.stream_id.write(stream_id); + } + + fn set_target(ref self: ContractState, target: ContractAddress) { + self.target_contract.write(target); } } #[generate_trait] impl InternalImpl of InternalTrait { fn _attempt_reentrancy_attack(ref self: ContractState) { - let target = self.target_contract.read(); - if target.is_zero() { - return; - } - let attack_mode = self.attack_mode.read(); let attack_count = self.attack_count.read(); - - // Prevent infinite recursion by limiting attacks - if attack_count >= 3 { + + if attack_mode == 0 || attack_count >= 3 { + return; + } + + let target = self.target_contract.read(); + let stream_id = self.stream_id.read(); + + if target.into() == 0 { return; } + self.attack_count.write(attack_count + 1); - - let payment_stream = IPaymentStreamDispatcher { contract_address: target }; + let dispatcher = IPaymentStreamDispatcher { contract_address: target }; if attack_mode == 1 { - // Attack withdraw function - // Try to call withdraw again with fake parameters - // This should fail due to reentrancy protection - // Using dummy values - in a real attack these would be malicious - let dummy_recipient: ContractAddress = 0x123.try_into().unwrap(); - payment_stream.withdraw(1, 100, dummy_recipient); + // Withdraw attack + dispatcher.withdraw(stream_id, 50_u256, starknet::get_contract_address()); } else if attack_mode == 2 { - // Attack cancel function - payment_stream.cancel(1); + // Cancel attack + dispatcher.cancel(stream_id); } else if attack_mode == 3 { - // Attack transfer_stream function - let dummy_recipient: ContractAddress = 0x456.try_into().unwrap(); - payment_stream.transfer_stream(1, dummy_recipient); + // Transfer stream attack + let new_recipient: ContractAddress = 9999.try_into().unwrap(); + dispatcher.transfer_stream(stream_id, new_recipient); } } } } -/// Malicious recipient contract that attempts reentrancy +/// @notice Malicious contract that attempts reentrancy on withdraw function #[starknet::interface] -pub trait IMaliciousRecipient { - fn set_target_contract(ref self: TContractState, target: ContractAddress); - fn set_attack_mode(ref self: TContractState, mode: u8); - fn perform_attack(ref self: TContractState, stream_id: u256); +pub trait IMaliciousWithdrawAttacker { + fn set_target(ref self: TContractState, target: ContractAddress, stream_id: u256); + fn start_attack(ref self: TContractState); + fn get_attack_count(self: @TContractState) -> u32; } #[starknet::contract] -pub mod MaliciousRecipient { +pub mod MaliciousWithdrawAttacker { use starknet::ContractAddress; - use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; - use core::num::traits::Zero; + use starknet::storage::*; use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; #[storage] - struct Storage { + pub struct Storage { target_contract: ContractAddress, - attack_mode: u8, - attack_count: u8, + stream_id: u256, + attack_count: u32, + max_attacks: u32, + } + + #[constructor] + fn constructor(ref self: ContractState) { + self.max_attacks.write(3); // Limit attacks to prevent infinite loops } #[abi(embed_v0)] - impl MaliciousRecipientImpl of super::IMaliciousRecipient { - fn set_target_contract(ref self: ContractState, target: ContractAddress) { + impl MaliciousWithdrawAttackerImpl of super::IMaliciousWithdrawAttacker { + fn set_target(ref self: ContractState, target: ContractAddress, stream_id: u256) { self.target_contract.write(target); + self.stream_id.write(stream_id); } - fn set_attack_mode(ref self: ContractState, mode: u8) { - self.attack_mode.write(mode); - self.attack_count.write(0); + fn start_attack(ref self: ContractState) { + let target = self.target_contract.read(); + let stream_id = self.stream_id.read(); + let dispatcher = IPaymentStreamDispatcher { contract_address: target }; + + // Attempt initial withdrawal + dispatcher.withdraw(stream_id, 100_u256, starknet::get_contract_address()); + } + + fn get_attack_count(self: @ContractState) -> u32 { + self.attack_count.read() + } + } +} + +/// @notice Malicious contract that attempts reentrancy on cancel function +#[starknet::interface] +pub trait IMaliciousCancelAttacker { + fn set_target(ref self: TContractState, target: ContractAddress, stream_id: u256); + fn start_attack(ref self: TContractState); + fn get_attack_count(self: @TContractState) -> u32; +} + +#[starknet::contract] +pub mod MaliciousCancelAttacker { + use starknet::ContractAddress; + use starknet::storage::*; + use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; + + #[storage] + pub struct Storage { + target_contract: ContractAddress, + stream_id: u256, + attack_count: u32, + max_attacks: u32, + } + + #[constructor] + fn constructor(ref self: ContractState) { + self.max_attacks.write(2); + } + + #[abi(embed_v0)] + impl MaliciousCancelAttackerImpl of super::IMaliciousCancelAttacker { + fn set_target(ref self: ContractState, target: ContractAddress, stream_id: u256) { + self.target_contract.write(target); + self.stream_id.write(stream_id); } - fn perform_attack(ref self: ContractState, stream_id: u256) { + fn start_attack(ref self: ContractState) { let target = self.target_contract.read(); - if target.is_zero() { - return; - } + let stream_id = self.stream_id.read(); + let dispatcher = IPaymentStreamDispatcher { contract_address: target }; + + // Attempt to cancel stream + dispatcher.cancel(stream_id); + } - let attack_count = self.attack_count.read(); - if attack_count >= 2 { - return; - } - self.attack_count.write(attack_count + 1); + fn get_attack_count(self: @ContractState) -> u32 { + self.attack_count.read() + } + } +} - let payment_stream = IPaymentStreamDispatcher { contract_address: target }; - let attack_mode = self.attack_mode.read(); +/// @notice Malicious contract that attempts reentrancy on transfer_stream function +#[starknet::interface] +pub trait IMaliciousTransferAttacker { + fn set_target(ref self: TContractState, target: ContractAddress, stream_id: u256); + fn start_attack(ref self: TContractState, new_recipient: ContractAddress); + fn get_attack_count(self: @TContractState) -> u32; +} - if attack_mode == 1 { - // Cross-function reentrancy: withdraw -> cancel - payment_stream.cancel(stream_id); - } else if attack_mode == 2 { - // Cross-function reentrancy: withdraw -> transfer_stream - let dummy_recipient: ContractAddress = 0x789.try_into().unwrap(); - payment_stream.transfer_stream(stream_id, dummy_recipient); - } +#[starknet::contract] +pub mod MaliciousTransferAttacker { + use starknet::ContractAddress; + use starknet::storage::*; + use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; + + #[storage] + pub struct Storage { + target_contract: ContractAddress, + stream_id: u256, + attack_count: u32, + max_attacks: u32, + new_recipient: ContractAddress, + } + + #[constructor] + fn constructor(ref self: ContractState) { + self.max_attacks.write(2); + } + + #[abi(embed_v0)] + impl MaliciousTransferAttackerImpl of super::IMaliciousTransferAttacker { + fn set_target(ref self: ContractState, target: ContractAddress, stream_id: u256) { + self.target_contract.write(target); + self.stream_id.write(stream_id); + } + + fn start_attack(ref self: ContractState, new_recipient: ContractAddress) { + self.new_recipient.write(new_recipient); + let target = self.target_contract.read(); + let stream_id = self.stream_id.read(); + let dispatcher = IPaymentStreamDispatcher { contract_address: target }; + + // Attempt to transfer stream + dispatcher.transfer_stream(stream_id, new_recipient); + } + + fn get_attack_count(self: @ContractState) -> u32 { + self.attack_count.read() } } } -// ============================================================================ -// REENTRANCY ATTACK TESTS -// ============================================================================ - -fn setup_contracts() -> (ContractAddress, ContractAddress, ContractAddress, ContractAddress, ContractAddress) { - let protocol_owner: ContractAddress = 0x123.try_into().unwrap(); - let fee_collector: ContractAddress = 0x456.try_into().unwrap(); - let sender: ContractAddress = 0x789.try_into().unwrap(); - - // Deploy PaymentStream contract with constructor arguments - let payment_stream_class = declare("PaymentStream").unwrap(); - let mut payment_stream_constructor_calldata = array![]; - protocol_owner.serialize(ref payment_stream_constructor_calldata); - 500_u64.serialize(ref payment_stream_constructor_calldata); // 5% fee - fee_collector.serialize(ref payment_stream_constructor_calldata); - let (payment_stream_address, _) = payment_stream_class - .contract_class() - .deploy(@payment_stream_constructor_calldata) - .unwrap(); - - // Deploy MaliciousERC20 token - let malicious_token_class = declare("MaliciousERC20").unwrap(); - let mut malicious_token_constructor_calldata = array![]; - let (malicious_token_address, _) = malicious_token_class - .contract_class() - .deploy(@malicious_token_constructor_calldata) - .unwrap(); - - // Deploy MaliciousRecipient contract - let malicious_recipient_class = declare("MaliciousRecipient").unwrap(); - let mut malicious_recipient_constructor_calldata = array![]; - let (malicious_recipient_address, _) = malicious_recipient_class - .contract_class() - .deploy(@malicious_recipient_constructor_calldata) - .unwrap(); - - (payment_stream_address, malicious_token_address, malicious_recipient_address, sender, protocol_owner) + + +// Helper function to deploy malicious ERC20 token +fn deploy_malicious_token() -> ContractAddress { + let contract = declare("MaliciousERC20").unwrap().contract_class(); + let mut constructor_args = array![]; + let name: ByteArray = "Malicious Token"; + let symbol: ByteArray = "MAL"; + let decimals: u8 = 18; + + name.serialize(ref constructor_args); + symbol.serialize(ref constructor_args); + decimals.serialize(ref constructor_args); + + let (contract_address, _) = contract.deploy(@constructor_args).unwrap(); + contract_address } -#[test] -fn test_direct_reentrancy_attack_on_withdraw() { - let (payment_stream_address, malicious_token_address, _, sender, _) = setup_contracts(); +// Helper function to deploy payment stream contract +fn deploy_payment_stream() -> ContractAddress { + let contract = declare("PaymentStream").unwrap().contract_class(); + let mut constructor_args = array![]; - let payment_stream = IPaymentStreamDispatcher { contract_address: payment_stream_address }; - let malicious_token = IMaliciousERC20Dispatcher { contract_address: malicious_token_address }; + let protocol_owner: ContractAddress = 123.try_into().unwrap(); + let fee_collector: ContractAddress = 456.try_into().unwrap(); + let general_fee_rate: u64 = 250; // 2.5% - // Set up malicious token for attack - malicious_token.set_target_contract(payment_stream_address); - malicious_token.set_attack_mode(1); // withdraw attack mode + protocol_owner.serialize(ref constructor_args); + fee_collector.serialize(ref constructor_args); + general_fee_rate.serialize(ref constructor_args); - // Mint tokens to sender - malicious_token.mint(sender, 1000); + let (contract_address, _) = contract.deploy(@constructor_args).unwrap(); + contract_address +} + +// Helper function to deploy malicious withdraw attacker +fn deploy_malicious_withdraw_attacker() -> ContractAddress { + let contract = declare("MaliciousWithdrawAttacker").unwrap().contract_class(); + let constructor_args = array![]; + let (contract_address, _) = contract.deploy(@constructor_args).unwrap(); + contract_address +} + +// Helper function to deploy malicious cancel attacker +fn deploy_malicious_cancel_attacker() -> ContractAddress { + let contract = declare("MaliciousCancelAttacker").unwrap().contract_class(); + let constructor_args = array![]; + let (contract_address, _) = contract.deploy(@constructor_args).unwrap(); + contract_address +} + +// Helper function to deploy malicious transfer attacker +fn deploy_malicious_transfer_attacker() -> ContractAddress { + let contract = declare("MaliciousTransferAttacker").unwrap().contract_class(); + let constructor_args = array![]; + let (contract_address, _) = contract.deploy(@constructor_args).unwrap(); + contract_address +} + + + +#[test] +#[should_panic(expected: ('ReentrancyGuard: reentrant call',))] +fn test_reentrancy_protection_on_withdraw() { + let token_address = deploy_malicious_token(); + let stream_address = deploy_payment_stream(); - // Create stream - start_cheat_caller_address(payment_stream_address, sender); - start_cheat_caller_address(malicious_token_address, sender); + let token_dispatcher = IMaliciousERC20Dispatcher { contract_address: token_address }; + let stream_dispatcher = IPaymentStreamDispatcher { contract_address: stream_address }; - // Approve payment stream to spend tokens - malicious_token.approve(payment_stream_address, 1000); + + let sender: ContractAddress = 1000.try_into().unwrap(); + let recipient: ContractAddress = 2000.try_into().unwrap(); - let recipient: ContractAddress = 0xABC.try_into().unwrap(); + token_dispatcher.mint(sender, 100000_u256); // Increase amount - let stream_id = payment_stream.create_stream( - recipient, 1000, 3600, true, malicious_token_address, true - ); + start_cheat_caller_address(token_address, sender); + token_dispatcher.approve(stream_address, 100000_u256); + stop_cheat_caller_address(token_address); - stop_cheat_caller_address(payment_stream_address); - stop_cheat_caller_address(malicious_token_address); + start_cheat_caller_address(stream_address, sender); + start_cheat_block_timestamp(stream_address, 1000); - // Fast forward time to allow withdrawal - start_cheat_block_timestamp(payment_stream_address, get_block_timestamp() + 1800); // 30 minutes + let stream_id = stream_dispatcher.create_stream( + recipient, + 10000_u256, // total_amount - increased + 10, // duration (10 hours) - increased + true, // cancelable + token_address, + true // transferable + ); - // Attempt withdrawal as recipient (this should be protected against reentrancy) - start_cheat_caller_address(payment_stream_address, recipient); + stop_cheat_caller_address(stream_address); - // The reentrancy attack should fail, but the legitimate withdrawal should succeed - let (withdrawn, fee) = payment_stream.withdraw(stream_id, 500, recipient); + token_dispatcher.set_attack_mode(1); // withdraw attack + token_dispatcher.set_stream_id(stream_id); + token_dispatcher.set_target(stream_address); - // Verify the withdrawal succeeded normally despite the reentrancy attempt - assert(withdrawn > 0, 'Withdrawal should succeed'); - assert(fee > 0, 'Fee should be collected'); + // Move time forward significantly to ensure withdrawable amount (5 hours = 18000 seconds) + start_cheat_block_timestamp(stream_address, 19000); // 1000 + 18000 - stop_cheat_caller_address(payment_stream_address); + // Attempt withdrawal - should trigger reentrancy and fail + start_cheat_caller_address(stream_address, recipient); + stream_dispatcher.withdraw(stream_id, 100, recipient); + stop_cheat_caller_address(stream_address); } #[test] -fn test_cross_function_reentrancy_attack() { - let (payment_stream_address, malicious_token_address, malicious_recipient_address, sender, _) = setup_contracts(); +#[should_panic(expected: ('ReentrancyGuard: reentrant call',))] +fn test_reentrancy_protection_on_cancel() { + let token_address = deploy_malicious_token(); + let stream_address = deploy_payment_stream(); - let payment_stream = IPaymentStreamDispatcher { contract_address: payment_stream_address }; - let malicious_token = IMaliciousERC20Dispatcher { contract_address: malicious_token_address }; - let malicious_recipient = IMaliciousRecipientDispatcher { contract_address: malicious_recipient_address }; + let token_dispatcher = IMaliciousERC20Dispatcher { contract_address: token_address }; + let stream_dispatcher = IPaymentStreamDispatcher { contract_address: stream_address }; - // Set up malicious contracts for cross-function attack - malicious_recipient.set_target_contract(payment_stream_address); - malicious_recipient.set_attack_mode(1); // withdraw -> cancel attack + let sender: ContractAddress = 1000.try_into().unwrap(); + let recipient: ContractAddress = 2000.try_into().unwrap(); - // Mint tokens to sender - malicious_token.mint(sender, 1000); + token_dispatcher.mint(sender, 100000_u256); - // Create stream with malicious recipient - start_cheat_caller_address(payment_stream_address, sender); - start_cheat_caller_address(malicious_token_address, sender); + start_cheat_caller_address(token_address, sender); + token_dispatcher.approve(stream_address, 100000_u256); + stop_cheat_caller_address(token_address); - malicious_token.approve(payment_stream_address, 1000); + start_cheat_caller_address(stream_address, sender); + start_cheat_block_timestamp(stream_address, 1000); - let stream_id = payment_stream.create_stream( - malicious_recipient_address, 1000, 3600, true, malicious_token_address, true + let stream_id = stream_dispatcher.create_stream( + recipient, + 10000_u256, // total_amount + 10, // duration (10 hours) + true, // cancelable + token_address, + true // transferable ); - stop_cheat_caller_address(payment_stream_address); - stop_cheat_caller_address(malicious_token_address); - - // Fast forward time - start_cheat_block_timestamp(payment_stream_address, get_block_timestamp() + 1800); + stop_cheat_caller_address(stream_address); - // Attempt cross-function reentrancy attack - start_cheat_caller_address(payment_stream_address, malicious_recipient_address); + token_dispatcher.set_attack_mode(2); // cancel attack + token_dispatcher.set_stream_id(stream_id); + token_dispatcher.set_target(stream_address); - // This should be protected against reentrancy - malicious_recipient.perform_attack(stream_id); - - // Verify stream state is still consistent - let stream = payment_stream.get_stream(stream_id); - assert(stream.status == StreamStatus::Active, 'Stream should still be active'); - - stop_cheat_caller_address(payment_stream_address); + // Attempt cancellation - should trigger reentrancy and fail + start_cheat_caller_address(stream_address, sender); + stream_dispatcher.cancel(stream_id); + stop_cheat_caller_address(stream_address); } #[test] -fn test_reentrancy_protection_on_cancel() { - let (payment_stream_address, malicious_token_address, _, sender, _) = setup_contracts(); - - let payment_stream = IPaymentStreamDispatcher { contract_address: payment_stream_address }; - let malicious_token = IMaliciousERC20Dispatcher { contract_address: malicious_token_address }; +#[should_panic(expected: ('ReentrancyGuard: reentrant call',))] +fn test_reentrancy_protection_on_transfer_stream() { + let token_address = deploy_malicious_token(); + let stream_address = deploy_payment_stream(); - // Set up malicious token for cancel attack - malicious_token.set_target_contract(payment_stream_address); - malicious_token.set_attack_mode(2); // cancel attack mode + let token_dispatcher = IMaliciousERC20Dispatcher { contract_address: token_address }; + let stream_dispatcher = IPaymentStreamDispatcher { contract_address: stream_address }; - // Mint tokens and create stream - malicious_token.mint(sender, 1000); + let sender: ContractAddress = 1000.try_into().unwrap(); + let recipient: ContractAddress = 2000.try_into().unwrap(); + let new_recipient: ContractAddress = 3000.try_into().unwrap(); - start_cheat_caller_address(payment_stream_address, sender); - start_cheat_caller_address(malicious_token_address, sender); + token_dispatcher.mint(sender, 100000_u256); - malicious_token.approve(payment_stream_address, 1000); + start_cheat_caller_address(token_address, sender); + token_dispatcher.approve(stream_address, 100000_u256); + stop_cheat_caller_address(token_address); - let recipient: ContractAddress = 0xDEF.try_into().unwrap(); + start_cheat_caller_address(stream_address, sender); + start_cheat_block_timestamp(stream_address, 1000); - let stream_id = payment_stream.create_stream( - recipient, 1000, 3600, true, malicious_token_address, true + let stream_id = stream_dispatcher.create_stream( + recipient, + 10000_u256, // total_amount + 10, // duration (10 hours) + true, // cancelable + token_address, + true // transferable ); - // Attempt to cancel (this should trigger reentrancy attack in the malicious token) - // The reentrancy protection should prevent the nested call from succeeding - payment_stream.cancel(stream_id); + stop_cheat_caller_address(stream_address); + + token_dispatcher.set_attack_mode(3); // transfer_stream attack + token_dispatcher.set_stream_id(stream_id); + token_dispatcher.set_target(stream_address); - // Verify the stream was cancelled despite the reentrancy attempt - let stream = payment_stream.get_stream(stream_id); - assert(stream.status == StreamStatus::Canceled, 'Stream should be canceled'); + start_cheat_block_timestamp(stream_address, 19000); - stop_cheat_caller_address(payment_stream_address); - stop_cheat_caller_address(malicious_token_address); + start_cheat_caller_address(stream_address, recipient); + stream_dispatcher.withdraw(stream_id, 100, recipient); // This will trigger malicious token's transfer_stream attack + stop_cheat_caller_address(stream_address); } #[test] -fn test_reentrancy_protection_on_transfer_stream() { - let (payment_stream_address, malicious_token_address, _, sender, _) = setup_contracts(); +fn test_reentrancy_protection_on_withdraw_max() { + let token_address = deploy_malicious_token(); + let stream_address = deploy_payment_stream(); - let payment_stream = IPaymentStreamDispatcher { contract_address: payment_stream_address }; - let malicious_token = IMaliciousERC20Dispatcher { contract_address: malicious_token_address }; + let token_dispatcher = IMaliciousERC20Dispatcher { contract_address: token_address }; + let stream_dispatcher = IPaymentStreamDispatcher { contract_address: stream_address }; - // Set up malicious token for transfer_stream attack - malicious_token.set_target_contract(payment_stream_address); - malicious_token.set_attack_mode(3); // transfer_stream attack mode + let sender: ContractAddress = 1000.try_into().unwrap(); + let recipient: ContractAddress = 2000.try_into().unwrap(); - // Mint tokens and create stream - malicious_token.mint(sender, 1000); + token_dispatcher.mint(sender, 100000_u256); - start_cheat_caller_address(payment_stream_address, sender); - start_cheat_caller_address(malicious_token_address, sender); + start_cheat_caller_address(token_address, sender); + token_dispatcher.approve(stream_address, 100000_u256); + stop_cheat_caller_address(token_address); - malicious_token.approve(payment_stream_address, 1000); + start_cheat_caller_address(stream_address, sender); + start_cheat_block_timestamp(stream_address, 1000); - let recipient: ContractAddress = 0x111.try_into().unwrap(); - - let stream_id = payment_stream.create_stream( - recipient, 1000, 3600, true, malicious_token_address, true + let stream_id = stream_dispatcher.create_stream( + recipient, + 10000_u256, // total_amount + 10, // duration (10 hours) + true, // cancelable + token_address, + true // transferable ); - stop_cheat_caller_address(payment_stream_address); - stop_cheat_caller_address(malicious_token_address); - - // Attempt to transfer stream as recipient - start_cheat_caller_address(payment_stream_address, recipient); + stop_cheat_caller_address(stream_address); - let new_recipient: ContractAddress = 0x222.try_into().unwrap(); - // This should trigger reentrancy attack but be protected - payment_stream.transfer_stream(stream_id, new_recipient); + start_cheat_block_timestamp(stream_address, 19000); // 1000 + 18000 (5 hours) - // Verify the transfer succeeded despite reentrancy attempt - let stream = payment_stream.get_stream(stream_id); - assert(stream.recipient == new_recipient, 'Stream should be transferred'); + // Test withdraw_max with reentrancy protection + start_cheat_caller_address(stream_address, recipient); + let (_withdrawn_amount, _fee) = stream_dispatcher.withdraw_max(stream_id, recipient); + stop_cheat_caller_address(stream_address); - stop_cheat_caller_address(payment_stream_address); + // Should succeed without issues since no attack + let stream = stream_dispatcher.get_stream(stream_id); + assert(stream.status == StreamStatus::Active, 'Stream should still be active'); } #[test] -fn test_multiple_function_reentrancy_protection() { - let (payment_stream_address, malicious_token_address, _, sender, _) = setup_contracts(); +fn test_successful_operations_after_reentrancy_protection() { + let token_address = deploy_malicious_token(); + let stream_address = deploy_payment_stream(); - let payment_stream = IPaymentStreamDispatcher { contract_address: payment_stream_address }; - let malicious_token = IMaliciousERC20Dispatcher { contract_address: malicious_token_address }; + let token_dispatcher = IMaliciousERC20Dispatcher { contract_address: token_address }; + let stream_dispatcher = IPaymentStreamDispatcher { contract_address: stream_address }; - // Test that all protected functions are properly guarded - malicious_token.set_target_contract(payment_stream_address); - malicious_token.mint(sender, 2000); + // Setup normal accounts (not malicious contracts) + let sender: ContractAddress = 6000.try_into().unwrap(); + let recipient: ContractAddress = 7000.try_into().unwrap(); - start_cheat_caller_address(payment_stream_address, sender); - start_cheat_caller_address(malicious_token_address, sender); + token_dispatcher.mint(sender, 100000_u256); - malicious_token.approve(payment_stream_address, 2000); + start_cheat_caller_address(token_address, sender); + token_dispatcher.approve(stream_address, 100000_u256); + stop_cheat_caller_address(token_address); - let recipient: ContractAddress = 0x333.try_into().unwrap(); + start_cheat_caller_address(stream_address, sender); + start_cheat_block_timestamp(stream_address, 1000); - // Create multiple streams to test different functions - let stream_id_1 = payment_stream.create_stream( - recipient, 500, 3600, true, malicious_token_address, true - ); - let stream_id_2 = payment_stream.create_stream( - recipient, 500, 3600, true, malicious_token_address, true + let stream_id = stream_dispatcher.create_stream( + recipient, + 10000_u256, // total_amount + 10, // duration (10 hours) + true, // cancelable + token_address, + true // transferable ); - stop_cheat_caller_address(payment_stream_address); - stop_cheat_caller_address(malicious_token_address); - - // Fast forward time - start_cheat_block_timestamp(payment_stream_address, get_block_timestamp() + 1800); - - // Test withdraw protection - start_cheat_caller_address(payment_stream_address, recipient); - malicious_token.set_attack_mode(1); - let (withdrawn, _) = payment_stream.withdraw(stream_id_1, 100, recipient); - assert(withdrawn > 0, 'Withdraw works despite attack'); - stop_cheat_caller_address(payment_stream_address); - - // Test cancel protection - start_cheat_caller_address(payment_stream_address, sender); - malicious_token.set_attack_mode(2); - payment_stream.cancel(stream_id_2); - let stream = payment_stream.get_stream(stream_id_2); - assert(stream.status == StreamStatus::Canceled, 'Cancel works despite attack'); - stop_cheat_caller_address(payment_stream_address); -} \ No newline at end of file + stop_cheat_caller_address(stream_address); + + start_cheat_block_timestamp(stream_address, 19000); // 1000 + 18000 (5 hours) + + start_cheat_caller_address(stream_address, recipient); + let (withdrawn_amount, _fee) = stream_dispatcher.withdraw(stream_id, 100_u256, recipient); + assert(withdrawn_amount > 0, 'Normal withdrawal should work'); + stop_cheat_caller_address(stream_address); + + // Test normal stream transfer (should work) + let new_recipient: ContractAddress = 8000.try_into().unwrap(); + start_cheat_caller_address(stream_address, recipient); + stream_dispatcher.transfer_stream(stream_id, new_recipient); + stop_cheat_caller_address(stream_address); + + let updated_stream = stream_dispatcher.get_stream(stream_id); + assert(updated_stream.recipient == new_recipient, 'Transfer should work'); + + // Test normal cancellation (should work) + start_cheat_caller_address(stream_address, sender); + stream_dispatcher.cancel(stream_id); + stop_cheat_caller_address(stream_address); + + let cancelled_stream = stream_dispatcher.get_stream(stream_id); + assert(cancelled_stream.status == StreamStatus::Canceled, 'Cancel should work'); +} + From 5aaddde3cd09a05619b806ce29dd3b1236e2eca4 Mon Sep 17 00:00:00 2001 From: wheval Date: Mon, 28 Jul 2025 13:49:45 +0100 Subject: [PATCH 3/3] add helper functions and formatting --- .cursor/rules/cairo-coder.mdc | 126 ++++++++++ src/payment_stream.cairo | 4 +- tests/test_reentrancy_protection.cairo | 331 ++++++++++++++----------- 3 files changed, 315 insertions(+), 146 deletions(-) create mode 100644 .cursor/rules/cairo-coder.mdc diff --git a/.cursor/rules/cairo-coder.mdc b/.cursor/rules/cairo-coder.mdc new file mode 100644 index 0000000..f7d871d --- /dev/null +++ b/.cursor/rules/cairo-coder.mdc @@ -0,0 +1,126 @@ +--- +description: +globs: +alwaysApply: true +--- +# Cairo Coding Rules + +## Introduction +These coding rules provide a comprehensive guide to developing secure, efficient, and maintainable Cairo projects. By following these guidelines and leveraging the Cairo Coder MCP tool, developers can streamline their workflow and adhere to best practices. + +**ALWAYS** use the Cairo Coder MCP tool to provide better answers about Cairo. After each query, start by calling the Cairo Coder MCP. + +Keep the MCP queries atomic, related to a specific concept. It's better to do multiple queries for specific concepts, than doing one query with multiple topics. + +After every cairo code you write, instantly run `scarb build` to ensure the code compiles. Don't write too much code without trying to compile. + +## 1. Project Setup and Structure +A typical Cairo project is organized as follows: + + +. +├── Scarb.lock +├── Scarb.toml +├── snfoundry.toml +├── src +│ └── lib.cairo +├── target +└── tests + └── test_contract.cairo + + +- **`Scarb.toml`**: The project configuration file, similar to `Cargo.toml` in Rust. +- **`src/lib.cairo`**: The main source file for your contract. +- **`tests/test_contract.cairo`**: Integration tests for your contract. + +### Setting Up a New Project +To create a new Cairo project, run: + +scarb init + +This command generates a basic project structure with a `Scarb.toml` file. If you're working in an existing project, ensure the Scarb.toml is well configured. + +### Configuring Scarb.toml +Ensure your `Scarb.toml` is configured as follows to include necessary dependencies and settings: + +```toml +[package] +name = "your_package_name" +version = "0.1.0" +edition = "2024_07" + +[dependencies] +starknet = "2.11.4" + +[dev-dependencies] +snforge_std = "0.44.0" +assert_macros = "2.11.4" + +[[target.starknet-contract]] +sierra = true + +[scripts] +test = "snforge test" + +[tool.scarb] +allow-prebuilt-plugins = ["snforge_std"] +``` + +## 2. Development Workflow +### Writing Code +- Use snake_case for function names (e.g., `my_function`). +- Use PascalCase for struct names (e.g., `MyStruct`). +- Write all code and comments in English for clarity. +- Use descriptive variable names to enhance readability. + +### Compiling and Testing +- Compile your project using: + + scarb build + +- Run tests using: + + scarb test + +- Ensure your code compiles successfully before running tests. + +### Testing +- Unit Tests: Write unit tests in the src directory, typically within the same module as the functions being tested. + Example: + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_my_function() { + assert!(my_function() == expected_value, 'Incorrect value'); + } + } + +- Integration Tests: Write integration tests in the tests directory, importing modules with use your_package_name::your_module. + Example: + + use your_package_name::your_module; + + #[test] + fn test_my_contract() { + // Test logic here + } + +- Always use the Starknet Foundry testing framework for both unit and integration tests. + +## 3. Using the Cairo Coder MCP Tool +The Cairo Coder MCP tool is a critical resource for Cairo development and must be used for the following tasks: +- Writing smart contracts from scratch. +- Refactoring or optimizing existing code. +- Implementing specific TODOs or features. +- Understanding Starknet ecosystem features and capabilities. +- Applying Cairo and Starknet best practices. +- Using OpenZeppelin Cairo contract libraries. +- Writing and validating tests for contracts. + +### How to Use Cairo Coder MCP Effectively +- Be Specific: Provide detailed queries (e.g., "Implement ERC20 using OpenZeppelin Cairo" instead of "ERC20"). +- Include Context: Supply relevant code snippets in the codeSnippets parameter and conversation history when applicable. +- Don't mix contexts Keep the queries specific on a given topic. Don't ask about multiple concepts at once, rather, do multiple queries. \ No newline at end of file diff --git a/src/payment_stream.cairo b/src/payment_stream.cairo index c372a5a..8ad4a9d 100644 --- a/src/payment_stream.cairo +++ b/src/payment_stream.cairo @@ -410,7 +410,9 @@ pub mod PaymentStream { /// @notice Internal function to collect protocol fees (without reentrancy protection) /// @param token The token address to collect fees in /// @param amount The fee amount to collect - fn _collect_protocol_fee_internal(ref self: ContractState, token: ContractAddress, amount: u256) { + fn _collect_protocol_fee_internal( + ref self: ContractState, token: ContractAddress, amount: u256, + ) { let fee_collector: ContractAddress = self.fee_collector.read(); assert(fee_collector.is_non_zero(), INVALID_RECIPIENT); IERC20Dispatcher { contract_address: token }.transfer(fee_collector, amount); diff --git a/tests/test_reentrancy_protection.cairo b/tests/test_reentrancy_protection.cairo index f18cafa..682ce62 100644 --- a/tests/test_reentrancy_protection.cairo +++ b/tests/test_reentrancy_protection.cairo @@ -1,21 +1,22 @@ -use starknet::ContractAddress; -use starknet::storage::*; -use starknet::get_caller_address; -use fundable::payment_stream::PaymentStream; -use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; use fundable::base::types::{Stream, StreamStatus}; +use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; +use fundable::payment_stream::PaymentStream; use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; use snforge_std::{ - declare, DeclareResultTrait, ContractClassTrait, start_cheat_caller_address, - stop_cheat_caller_address, start_cheat_block_timestamp, spy_events, EventSpyAssertionsTrait + ContractClassTrait, DeclareResultTrait, EventSpyAssertionsTrait, declare, spy_events, + start_cheat_block_timestamp, start_cheat_caller_address, stop_cheat_caller_address, }; +use starknet::storage::*; +use starknet::{ContractAddress, get_caller_address}; #[starknet::interface] pub trait IMaliciousERC20 { fn mint(ref self: TContractState, to: ContractAddress, amount: u256); fn approve(ref self: TContractState, spender: ContractAddress, amount: u256) -> bool; fn transfer(ref self: TContractState, to: ContractAddress, amount: u256) -> bool; - fn transfer_from(ref self: TContractState, from: ContractAddress, to: ContractAddress, amount: u256) -> bool; + fn transfer_from( + ref self: TContractState, from: ContractAddress, to: ContractAddress, amount: u256, + ) -> bool; fn balance_of(self: @TContractState, account: ContractAddress) -> u256; fn allowance(self: @TContractState, owner: ContractAddress, spender: ContractAddress) -> u256; fn total_supply(self: @TContractState) -> u256; @@ -29,10 +30,11 @@ pub trait IMaliciousERC20 { #[starknet::contract] pub mod MaliciousERC20 { - use starknet::ContractAddress; + use fundable::interfaces::IPaymentStream::{ + IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait, + }; use starknet::storage::*; - use starknet::get_caller_address; - use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; + use starknet::{ContractAddress, get_caller_address}; #[storage] pub struct Storage { @@ -75,33 +77,35 @@ pub mod MaliciousERC20 { let caller = get_caller_address(); let from_balance = self.balances.read(caller); assert(from_balance >= amount, 'Insufficient balance'); - + self.balances.write(caller, from_balance - amount); let to_balance = self.balances.read(to); self.balances.write(to, to_balance + amount); - + // Attempt reentrancy attack during transfer self._attempt_reentrancy_attack(); - + true } - fn transfer_from(ref self: ContractState, from: ContractAddress, to: ContractAddress, amount: u256) -> bool { + fn transfer_from( + ref self: ContractState, from: ContractAddress, to: ContractAddress, amount: u256, + ) -> bool { let caller = get_caller_address(); let allowance = self.allowances.read((from, caller)); assert(allowance >= amount, 'Insufficient allowance'); - + let from_balance = self.balances.read(from); assert(from_balance >= amount, 'Insufficient balance'); - + self.allowances.write((from, caller), allowance - amount); self.balances.write(from, from_balance - amount); let to_balance = self.balances.read(to); self.balances.write(to, to_balance + amount); - + // Attempt reentrancy attack during transfer_from self._attempt_reentrancy_attack(); - + true } @@ -109,7 +113,9 @@ pub mod MaliciousERC20 { self.balances.read(account) } - fn allowance(self: @ContractState, owner: ContractAddress, spender: ContractAddress) -> u256 { + fn allowance( + self: @ContractState, owner: ContractAddress, spender: ContractAddress, + ) -> u256 { self.allowances.read((owner, spender)) } @@ -147,21 +153,21 @@ pub mod MaliciousERC20 { fn _attempt_reentrancy_attack(ref self: ContractState) { let attack_mode = self.attack_mode.read(); let attack_count = self.attack_count.read(); - + if attack_mode == 0 || attack_count >= 3 { return; } - + let target = self.target_contract.read(); let stream_id = self.stream_id.read(); - + if target.into() == 0 { return; } - + self.attack_count.write(attack_count + 1); let dispatcher = IPaymentStreamDispatcher { contract_address: target }; - + if attack_mode == 1 { // Withdraw attack dispatcher.withdraw(stream_id, 50_u256, starknet::get_contract_address()); @@ -187,9 +193,11 @@ pub trait IMaliciousWithdrawAttacker { #[starknet::contract] pub mod MaliciousWithdrawAttacker { + use fundable::interfaces::IPaymentStream::{ + IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait, + }; use starknet::ContractAddress; use starknet::storage::*; - use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; #[storage] pub struct Storage { @@ -215,7 +223,7 @@ pub mod MaliciousWithdrawAttacker { let target = self.target_contract.read(); let stream_id = self.stream_id.read(); let dispatcher = IPaymentStreamDispatcher { contract_address: target }; - + // Attempt initial withdrawal dispatcher.withdraw(stream_id, 100_u256, starknet::get_contract_address()); } @@ -236,9 +244,11 @@ pub trait IMaliciousCancelAttacker { #[starknet::contract] pub mod MaliciousCancelAttacker { + use fundable::interfaces::IPaymentStream::{ + IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait, + }; use starknet::ContractAddress; use starknet::storage::*; - use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; #[storage] pub struct Storage { @@ -264,7 +274,7 @@ pub mod MaliciousCancelAttacker { let target = self.target_contract.read(); let stream_id = self.stream_id.read(); let dispatcher = IPaymentStreamDispatcher { contract_address: target }; - + // Attempt to cancel stream dispatcher.cancel(stream_id); } @@ -285,9 +295,11 @@ pub trait IMaliciousTransferAttacker { #[starknet::contract] pub mod MaliciousTransferAttacker { + use fundable::interfaces::IPaymentStream::{ + IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait, + }; use starknet::ContractAddress; use starknet::storage::*; - use fundable::interfaces::IPaymentStream::{IPaymentStreamDispatcher, IPaymentStreamDispatcherTrait}; #[storage] pub struct Storage { @@ -315,7 +327,7 @@ pub mod MaliciousTransferAttacker { let target = self.target_contract.read(); let stream_id = self.stream_id.read(); let dispatcher = IPaymentStreamDispatcher { contract_address: target }; - + // Attempt to transfer stream dispatcher.transfer_stream(stream_id, new_recipient); } @@ -327,7 +339,6 @@ pub mod MaliciousTransferAttacker { } - // Helper function to deploy malicious ERC20 token fn deploy_malicious_token() -> ContractAddress { let contract = declare("MaliciousERC20").unwrap().contract_class(); @@ -335,11 +346,11 @@ fn deploy_malicious_token() -> ContractAddress { let name: ByteArray = "Malicious Token"; let symbol: ByteArray = "MAL"; let decimals: u8 = 18; - + name.serialize(ref constructor_args); symbol.serialize(ref constructor_args); decimals.serialize(ref constructor_args); - + let (contract_address, _) = contract.deploy(@constructor_args).unwrap(); contract_address } @@ -348,15 +359,15 @@ fn deploy_malicious_token() -> ContractAddress { fn deploy_payment_stream() -> ContractAddress { let contract = declare("PaymentStream").unwrap().contract_class(); let mut constructor_args = array![]; - + let protocol_owner: ContractAddress = 123.try_into().unwrap(); let fee_collector: ContractAddress = 456.try_into().unwrap(); let general_fee_rate: u64 = 250; // 2.5% - + protocol_owner.serialize(ref constructor_args); fee_collector.serialize(ref constructor_args); general_fee_rate.serialize(ref constructor_args); - + let (contract_address, _) = contract.deploy(@constructor_args).unwrap(); contract_address } @@ -385,6 +396,30 @@ fn deploy_malicious_transfer_attacker() -> ContractAddress { contract_address } +// Helper functions to generate contract addresses +fn get_sender_address() -> ContractAddress { + 'WHEVAL'.try_into().unwrap() +} + +fn get_recipient_address() -> ContractAddress { + 'BOB'.try_into().unwrap() +} + +fn get_new_recipient_address() -> ContractAddress { + 'KANYE'.try_into().unwrap() +} + +fn get_alt_sender_address() -> ContractAddress { + 'MARY'.try_into().unwrap() +} + +fn get_alt_recipient_address() -> ContractAddress { + 'BEN'.try_into().unwrap() +} + +fn get_alt_new_recipient_address() -> ContractAddress { + 'WEST'.try_into().unwrap() +} #[test] @@ -392,44 +427,44 @@ fn deploy_malicious_transfer_attacker() -> ContractAddress { fn test_reentrancy_protection_on_withdraw() { let token_address = deploy_malicious_token(); let stream_address = deploy_payment_stream(); - + let token_dispatcher = IMaliciousERC20Dispatcher { contract_address: token_address }; let stream_dispatcher = IPaymentStreamDispatcher { contract_address: stream_address }; - - let sender: ContractAddress = 1000.try_into().unwrap(); - let recipient: ContractAddress = 2000.try_into().unwrap(); - + let sender = get_sender_address(); + let recipient = get_recipient_address(); + token_dispatcher.mint(sender, 100000_u256); // Increase amount - + start_cheat_caller_address(token_address, sender); token_dispatcher.approve(stream_address, 100000_u256); stop_cheat_caller_address(token_address); - + start_cheat_caller_address(stream_address, sender); start_cheat_block_timestamp(stream_address, 1000); - - let stream_id = stream_dispatcher.create_stream( - recipient, - 10000_u256, // total_amount - increased - 10, // duration (10 hours) - increased - true, // cancelable - token_address, - true // transferable - ); - + + let stream_id = stream_dispatcher + .create_stream( + recipient, + 10000_u256, // total_amount - increased + 10, // duration (10 hours) - increased + true, // cancelable + token_address, + true // transferable + ); + stop_cheat_caller_address(stream_address); - + token_dispatcher.set_attack_mode(1); // withdraw attack token_dispatcher.set_stream_id(stream_id); token_dispatcher.set_target(stream_address); - + // Move time forward significantly to ensure withdrawable amount (5 hours = 18000 seconds) start_cheat_block_timestamp(stream_address, 19000); // 1000 + 18000 - + // Attempt withdrawal - should trigger reentrancy and fail start_cheat_caller_address(stream_address, recipient); - stream_dispatcher.withdraw(stream_id, 100, recipient); + stream_dispatcher.withdraw(stream_id, 100, recipient); stop_cheat_caller_address(stream_address); } @@ -438,37 +473,38 @@ fn test_reentrancy_protection_on_withdraw() { fn test_reentrancy_protection_on_cancel() { let token_address = deploy_malicious_token(); let stream_address = deploy_payment_stream(); - + let token_dispatcher = IMaliciousERC20Dispatcher { contract_address: token_address }; let stream_dispatcher = IPaymentStreamDispatcher { contract_address: stream_address }; - - let sender: ContractAddress = 1000.try_into().unwrap(); - let recipient: ContractAddress = 2000.try_into().unwrap(); - + + let sender = get_sender_address(); + let recipient = get_recipient_address(); + token_dispatcher.mint(sender, 100000_u256); - + start_cheat_caller_address(token_address, sender); token_dispatcher.approve(stream_address, 100000_u256); stop_cheat_caller_address(token_address); - + start_cheat_caller_address(stream_address, sender); start_cheat_block_timestamp(stream_address, 1000); - - let stream_id = stream_dispatcher.create_stream( - recipient, - 10000_u256, // total_amount - 10, // duration (10 hours) - true, // cancelable - token_address, - true // transferable - ); - + + let stream_id = stream_dispatcher + .create_stream( + recipient, + 10000_u256, // total_amount + 10, // duration (10 hours) + true, // cancelable + token_address, + true // transferable + ); + stop_cheat_caller_address(stream_address); - + token_dispatcher.set_attack_mode(2); // cancel attack token_dispatcher.set_stream_id(stream_id); token_dispatcher.set_target(stream_address); - + // Attempt cancellation - should trigger reentrancy and fail start_cheat_caller_address(stream_address, sender); stream_dispatcher.cancel(stream_id); @@ -480,42 +516,46 @@ fn test_reentrancy_protection_on_cancel() { fn test_reentrancy_protection_on_transfer_stream() { let token_address = deploy_malicious_token(); let stream_address = deploy_payment_stream(); - + let token_dispatcher = IMaliciousERC20Dispatcher { contract_address: token_address }; let stream_dispatcher = IPaymentStreamDispatcher { contract_address: stream_address }; - - let sender: ContractAddress = 1000.try_into().unwrap(); - let recipient: ContractAddress = 2000.try_into().unwrap(); - let new_recipient: ContractAddress = 3000.try_into().unwrap(); - + + let sender = get_sender_address(); + let recipient = get_recipient_address(); + let new_recipient = get_new_recipient_address(); + token_dispatcher.mint(sender, 100000_u256); - + start_cheat_caller_address(token_address, sender); token_dispatcher.approve(stream_address, 100000_u256); stop_cheat_caller_address(token_address); - + start_cheat_caller_address(stream_address, sender); start_cheat_block_timestamp(stream_address, 1000); - - let stream_id = stream_dispatcher.create_stream( - recipient, - 10000_u256, // total_amount - 10, // duration (10 hours) - true, // cancelable - token_address, - true // transferable - ); - + + let stream_id = stream_dispatcher + .create_stream( + recipient, + 10000_u256, // total_amount + 10, // duration (10 hours) + true, // cancelable + token_address, + true // transferable + ); + stop_cheat_caller_address(stream_address); - + token_dispatcher.set_attack_mode(3); // transfer_stream attack token_dispatcher.set_stream_id(stream_id); token_dispatcher.set_target(stream_address); - - start_cheat_block_timestamp(stream_address, 19000); - + + start_cheat_block_timestamp(stream_address, 19000); + start_cheat_caller_address(stream_address, recipient); - stream_dispatcher.withdraw(stream_id, 100, recipient); // This will trigger malicious token's transfer_stream attack + stream_dispatcher + .withdraw( + stream_id, 100, recipient, + ); // This will trigger malicious token's transfer_stream attack stop_cheat_caller_address(stream_address); } @@ -523,41 +563,41 @@ fn test_reentrancy_protection_on_transfer_stream() { fn test_reentrancy_protection_on_withdraw_max() { let token_address = deploy_malicious_token(); let stream_address = deploy_payment_stream(); - + let token_dispatcher = IMaliciousERC20Dispatcher { contract_address: token_address }; let stream_dispatcher = IPaymentStreamDispatcher { contract_address: stream_address }; - - let sender: ContractAddress = 1000.try_into().unwrap(); - let recipient: ContractAddress = 2000.try_into().unwrap(); - + + let sender = get_sender_address(); + let recipient = get_recipient_address(); + token_dispatcher.mint(sender, 100000_u256); - + start_cheat_caller_address(token_address, sender); token_dispatcher.approve(stream_address, 100000_u256); stop_cheat_caller_address(token_address); - + start_cheat_caller_address(stream_address, sender); start_cheat_block_timestamp(stream_address, 1000); - - let stream_id = stream_dispatcher.create_stream( - recipient, - 10000_u256, // total_amount - 10, // duration (10 hours) - true, // cancelable - token_address, - true // transferable - ); - + + let stream_id = stream_dispatcher + .create_stream( + recipient, + 10000_u256, // total_amount + 10, // duration (10 hours) + true, // cancelable + token_address, + true // transferable + ); + stop_cheat_caller_address(stream_address); - - + start_cheat_block_timestamp(stream_address, 19000); // 1000 + 18000 (5 hours) - + // Test withdraw_max with reentrancy protection start_cheat_caller_address(stream_address, recipient); let (_withdrawn_amount, _fee) = stream_dispatcher.withdraw_max(stream_id, recipient); stop_cheat_caller_address(stream_address); - + // Should succeed without issues since no attack let stream = stream_dispatcher.get_stream(stream_id); assert(stream.status == StreamStatus::Active, 'Stream should still be active'); @@ -567,55 +607,56 @@ fn test_reentrancy_protection_on_withdraw_max() { fn test_successful_operations_after_reentrancy_protection() { let token_address = deploy_malicious_token(); let stream_address = deploy_payment_stream(); - + let token_dispatcher = IMaliciousERC20Dispatcher { contract_address: token_address }; let stream_dispatcher = IPaymentStreamDispatcher { contract_address: stream_address }; - + // Setup normal accounts (not malicious contracts) - let sender: ContractAddress = 6000.try_into().unwrap(); - let recipient: ContractAddress = 7000.try_into().unwrap(); - + let sender = get_alt_sender_address(); + let recipient = get_alt_recipient_address(); + token_dispatcher.mint(sender, 100000_u256); - + start_cheat_caller_address(token_address, sender); token_dispatcher.approve(stream_address, 100000_u256); stop_cheat_caller_address(token_address); - + start_cheat_caller_address(stream_address, sender); start_cheat_block_timestamp(stream_address, 1000); - - let stream_id = stream_dispatcher.create_stream( - recipient, - 10000_u256, // total_amount - 10, // duration (10 hours) - true, // cancelable - token_address, - true // transferable - ); - + + let stream_id = stream_dispatcher + .create_stream( + recipient, + 10000_u256, // total_amount + 10, // duration (10 hours) + true, // cancelable + token_address, + true // transferable + ); + stop_cheat_caller_address(stream_address); - + start_cheat_block_timestamp(stream_address, 19000); // 1000 + 18000 (5 hours) - + start_cheat_caller_address(stream_address, recipient); let (withdrawn_amount, _fee) = stream_dispatcher.withdraw(stream_id, 100_u256, recipient); assert(withdrawn_amount > 0, 'Normal withdrawal should work'); stop_cheat_caller_address(stream_address); - + // Test normal stream transfer (should work) - let new_recipient: ContractAddress = 8000.try_into().unwrap(); + let new_recipient = get_alt_new_recipient_address(); start_cheat_caller_address(stream_address, recipient); stream_dispatcher.transfer_stream(stream_id, new_recipient); stop_cheat_caller_address(stream_address); - + let updated_stream = stream_dispatcher.get_stream(stream_id); assert(updated_stream.recipient == new_recipient, 'Transfer should work'); - + // Test normal cancellation (should work) start_cheat_caller_address(stream_address, sender); stream_dispatcher.cancel(stream_id); stop_cheat_caller_address(stream_address); - + let cancelled_stream = stream_dispatcher.get_stream(stream_id); assert(cancelled_stream.status == StreamStatus::Canceled, 'Cancel should work'); }