diff --git a/src/base/types.cairo b/src/base/types.cairo index 4fb856c..04ff9d8 100644 --- a/src/base/types.cairo +++ b/src/base/types.cairo @@ -49,8 +49,9 @@ pub struct Organization { #[derive(Copy, Drop, Serde, starknet::Store)] pub struct Milestone { - pub organization: ContractAddress, pub project_id: u64, + pub milestone_id: u64, + pub organization: ContractAddress, pub milestone_description: felt252, pub milestone_amount: u256, pub created_at: u64, diff --git a/src/budgetchain/Budget.cairo b/src/budgetchain/Budget.cairo index 11ccc27..f75bdb6 100644 --- a/src/budgetchain/Budget.cairo +++ b/src/budgetchain/Budget.cairo @@ -447,6 +447,7 @@ pub mod Budget { Milestone { organization: org, project_id: project_id, + milestone_id: j.into() + 1, milestone_description: *milestone_descriptions.at(j), milestone_amount: *milestone_amounts.at(j), created_at: get_block_timestamp(), @@ -526,17 +527,19 @@ pub mod Budget { let created_at = get_block_timestamp(); + // Read the number of the current milestones the organization has + let current_milestone = self.org_milestones.read(org); + let new_milestone: Milestone = Milestone { organization: org, project_id: project_id, + milestone_id: current_milestone + 1, milestone_description: milestone_description, milestone_amount: milestone_amount, created_at: created_at, completed: false, released: false, }; - // // read the number of the current milestones the organization has - let current_milestone = self.org_milestones.read(org); self.milestones.write((project_id, current_milestone + 1), new_milestone); self.org_milestones.write(org, current_milestone + 1); diff --git a/src/budgetchain/MilestoneManager.cairo b/src/budgetchain/MilestoneManager.cairo new file mode 100644 index 0000000..0ff0552 --- /dev/null +++ b/src/budgetchain/MilestoneManager.cairo @@ -0,0 +1,201 @@ +#[starknet::contract] +pub mod MilestoneManager { + use budgetchain_contracts::base::errors::*; + use budgetchain_contracts::base::types::{ADMIN_ROLE, Milestone, ORGANIZATION_ROLE, Project}; + use budgetchain_contracts::interfaces::IMilestoneManager::IMilestoneManager; + use core::array::{Array, ArrayTrait}; + use openzeppelin::access::accesscontrol::{AccessControlComponent, DEFAULT_ADMIN_ROLE}; + use openzeppelin::introspection::src5::SRC5Component; + use starknet::storage::{ + Map, StorageMapReadAccess, StorageMapWriteAccess, StoragePointerReadAccess, + StoragePointerWriteAccess, + }; + use starknet::{ + ContractAddress, contract_address_const, get_block_timestamp, get_caller_address, + }; + component!(path: AccessControlComponent, storage: accesscontrol, event: AccessControlEvent); + component!(path: SRC5Component, storage: src5, event: SRC5Event); + + // AccessControl Mixin + #[abi(embed_v0)] + impl AccessControlImpl = + AccessControlComponent::AccessControlImpl; + impl AccessControlInternalImpl = AccessControlComponent::InternalImpl; + + // SRC5 Mixin + #[abi(embed_v0)] + impl SRC5Impl = SRC5Component::SRC5Impl; + + #[storage] + struct Storage { + admin: ContractAddress, + projects: Map, + milestones: Map<(u64, u64), Milestone>, // (project_id, milestone_id) -> Milestone + project_milestone_count: Map, // project_id -> count of milestones + is_paused: bool, + #[substorage(v0)] + accesscontrol: AccessControlComponent::Storage, + #[substorage(v0)] + src5: SRC5Component::Storage, + } + + #[event] + #[derive(Drop, starknet::Event)] + pub enum Event { + MilestoneCreated: MilestoneCreated, + MilestoneCompleted: MilestoneCompleted, + #[flat] + AccessControlEvent: AccessControlComponent::Event, + #[flat] + SRC5Event: SRC5Component::Event, + } + + #[derive(Drop, starknet::Event)] + pub struct MilestoneCreated { + pub organization: ContractAddress, + pub project_id: u64, + pub milestone_id: u64, + pub milestone_description: felt252, + pub milestone_amount: u256, + pub created_at: u64, + } + + #[derive(Drop, starknet::Event)] + pub struct MilestoneCompleted { + pub project_id: u64, + pub milestone_id: u64, + } + + #[constructor] + fn constructor(ref self: ContractState, default_admin: ContractAddress) { + assert(default_admin != contract_address_const::<0>(), ERROR_ZERO_ADDRESS); + + // Initialize access control + self.accesscontrol.initializer(); + self.accesscontrol._grant_role(DEFAULT_ADMIN_ROLE, default_admin); + self.accesscontrol._grant_role(ADMIN_ROLE, default_admin); + + // Initialize contract storage + self.admin.write(default_admin); + self.is_paused.write(false); + } + + #[abi(embed_v0)] + impl MilestoneManagerImpl of IMilestoneManager { + fn create_milestone( + ref self: ContractState, + organization: ContractAddress, + project_id: u64, + milestone_description: felt252, + milestone_amount: u256, + ) -> u64 { + // Ensure the contract is not paused + self.assert_not_paused(); + + // Verify caller's authorization + let caller = get_caller_address(); + let admin = self.admin.read(); + + assert( + caller == admin + || self.accesscontrol.has_role(ORGANIZATION_ROLE, caller) + || caller == organization, + ERROR_UNAUTHORIZED, + ); + + // Verify project exists with a valid ID + assert(project_id > 0, ERROR_INVALID_PROJECT_ID); + + // Generate new milestone ID + let milestone_id = self.project_milestone_count.read(project_id) + 1; + + // Create new milestone + let created_at = get_block_timestamp(); + let new_milestone = Milestone { + project_id, + milestone_id, + organization, + milestone_description, + milestone_amount, + created_at, + completed: false, + released: false, + }; + self.milestones.write((project_id, milestone_id), new_milestone); + self.project_milestone_count.write(project_id, milestone_id); + self + .emit( + Event::MilestoneCreated( + MilestoneCreated { + organization, + project_id, + milestone_id, + milestone_description, + milestone_amount, + created_at, + }, + ), + ); + milestone_id + } + + fn set_milestone_complete(ref self: ContractState, project_id: u64, milestone_id: u64) { + // Ensure the contract is not paused + self.assert_not_paused(); + let mut milestone = self.milestones.read((project_id, milestone_id)); + assert(milestone.project_id == project_id, ERROR_INVALID_MILESTONE); + assert(milestone.milestone_id == milestone_id, ERROR_INVALID_MILESTONE); + assert(milestone.completed != true, ERROR_MILESTONE_ALREADY_COMPLETED); + milestone.completed = true; + self.milestones.write((project_id, milestone_id), milestone); + self.emit(Event::MilestoneCompleted(MilestoneCompleted { project_id, milestone_id })); + } + + fn get_milestone(self: @ContractState, project_id: u64, milestone_id: u64) -> Milestone { + self.milestones.read((project_id, milestone_id)) + } + + fn get_project_milestones(self: @ContractState, project_id: u64) -> Array { + let mut milestones = ArrayTrait::new(); + let milestone_count = self.project_milestone_count.read(project_id); + let mut i: u64 = 1; + while i <= milestone_count { + let milestone = self.milestones.read((project_id, i)); + milestones.append(milestone); + i += 1; + }; + milestones + } + + fn get_admin(self: @ContractState) -> ContractAddress { + self.admin.read() + } + + fn is_paused(self: @ContractState) -> bool { + self.is_paused.read() + } + + fn pause_contract(ref self: ContractState) { + let caller = get_caller_address(); + assert(caller == self.admin.read(), ERROR_ONLY_ADMIN); + assert(!self.is_paused.read(), ERROR_ALREADY_PAUSED); + self.is_paused.write(true); + } + + fn unpause_contract(ref self: ContractState) { + let caller = get_caller_address(); + assert(caller == self.admin.read(), ERROR_ONLY_ADMIN); + self.is_paused.write(false); + } + } + + #[generate_trait] + pub impl Internal of InternalTrait { + // Internal view function + // - Takes `@self` as it only needs to read state + // - Can only be called by other functions within the contract + fn assert_not_paused(self: @ContractState) { + assert(!self.is_paused.read(), ERROR_CONTRACT_PAUSED); + } + } +} diff --git a/src/interfaces/IMilestoneManager.cairo b/src/interfaces/IMilestoneManager.cairo new file mode 100644 index 0000000..45b119d --- /dev/null +++ b/src/interfaces/IMilestoneManager.cairo @@ -0,0 +1,26 @@ +use budgetchain_contracts::base::types::Milestone; +use starknet::ContractAddress; + +#[starknet::interface] +pub trait IMilestoneManager { + // Milestone Management + fn create_milestone( + ref self: TContractState, + organization: ContractAddress, + project_id: u64, + milestone_description: felt252, + milestone_amount: u256, + ) -> u64; + + fn set_milestone_complete(ref self: TContractState, project_id: u64, milestone_id: u64); + + fn get_milestone(self: @TContractState, project_id: u64, milestone_id: u64) -> Milestone; + + fn get_project_milestones(self: @TContractState, project_id: u64) -> Array; + + // Admin functions + fn get_admin(self: @TContractState) -> ContractAddress; + fn is_paused(self: @TContractState) -> bool; + fn pause_contract(ref self: TContractState); + fn unpause_contract(ref self: TContractState); +} diff --git a/src/lib.cairo b/src/lib.cairo index b035bc4..6011d86 100644 --- a/src/lib.cairo +++ b/src/lib.cairo @@ -6,12 +6,16 @@ pub mod base { pub mod interfaces { pub mod IBudget; + pub mod IMilestoneManager; } pub mod budgetchain { pub mod Budget; + pub mod MilestoneManager; } // Re-export the main modules for easier access pub use budgetchain::Budget; +pub use budgetchain::MilestoneManager; pub use interfaces::IBudget; +pub use interfaces::IMilestoneManager; diff --git a/tests/test_milestone_manager.cairo b/tests/test_milestone_manager.cairo new file mode 100644 index 0000000..51ae95d --- /dev/null +++ b/tests/test_milestone_manager.cairo @@ -0,0 +1,312 @@ +use budgetchain_contracts::base::errors::*; +use budgetchain_contracts::budgetchain::MilestoneManager::*; +use budgetchain_contracts::interfaces::IMilestoneManager::{ + IMilestoneManagerDispatcher, IMilestoneManagerDispatcherTrait, +}; +use core::array::ArrayTrait; +use core::result::ResultTrait; +use core::traits::Into; +use snforge_std::{ + CheatSpan, ContractClassTrait, DeclareResultTrait, EventSpyAssertionsTrait, + cheat_caller_address, declare, spy_events, +}; +use starknet::{ContractAddress, contract_address_const}; + + +fn ADMIN() -> ContractAddress { + contract_address_const::<'ADMIN'>() +} + +fn ORGANIZATION() -> ContractAddress { + contract_address_const::<'ORGANIZATION'>() +} + +fn OTHER_ORG() -> ContractAddress { + contract_address_const::<'OTHER_ORG'>() +} + +fn NON_ORG() -> ContractAddress { + contract_address_const::<'NON_ORG'>() +} + +fn PROJECT_OWNER() -> ContractAddress { + contract_address_const::<'PROJECT_OWNER'>() +} + + +fn setup_test_data() -> (u64, u256, felt252) { + (1_u64, // project_id + 500_u256, // milestone_amount + 'Test milestone' // milestone_description + ) +} + + +fn deploy_milestone_manager( + admin: ContractAddress, +) -> (ContractAddress, IMilestoneManagerDispatcher) { + let contract_class = declare("MilestoneManager").unwrap().contract_class(); + + let mut calldata: Array = ArrayTrait::new(); + calldata.append(admin.into()); + + let (contract_address, _) = contract_class.deploy(@calldata).unwrap(); + + (contract_address, IMilestoneManagerDispatcher { contract_address }) +} + +#[test] +fn test_create_milestone() { + let admin = ADMIN(); + let org = ORGANIZATION(); + + let (contract_address, dispatcher) = deploy_milestone_manager(admin); + + // Setup test data + let (project_id, milestone_amount, milestone_description) = setup_test_data(); + + // Create milestone as admin + cheat_caller_address(contract_address, admin, CheatSpan::TargetCalls(1)); + + // Set up event spy + let mut spy = spy_events(); + + // Create milestone + let milestone_id = dispatcher + .create_milestone(org, project_id, milestone_description, milestone_amount); + + assert(milestone_id == 1, 'Incorrect milestone ID'); + + let milestone = dispatcher.get_milestone(project_id, milestone_id); + assert(milestone.project_id == project_id, 'Wrong project ID'); + assert(milestone.milestone_id == milestone_id, 'Wrong milestone ID'); + assert(milestone.organization == org, 'Wrong organization'); + assert(milestone.milestone_description == milestone_description, 'Wrong description'); + assert(milestone.milestone_amount == milestone_amount, 'Wrong amount'); + assert(milestone.completed == false, 'Should not be completed'); + assert(milestone.released == false, 'Should not be released'); + + // Verify event was emitted + spy + .assert_emitted( + @array![ + ( + contract_address, + MilestoneManager::Event::MilestoneCreated( + MilestoneManager::MilestoneCreated { + organization: org, + project_id, + milestone_id, + milestone_description, + milestone_amount, + created_at: milestone.created_at, + }, + ), + ), + ], + ); +} + +#[test] +fn test_create_multiple_milestones() { + let admin = ADMIN(); + let org = ORGANIZATION(); + + let (contract_address, dispatcher) = deploy_milestone_manager(admin); + + // Setup test data + let (project_id, milestone_amount, milestone_description) = setup_test_data(); + + // Create milestones as admin + cheat_caller_address(contract_address, admin, CheatSpan::TargetCalls(2)); + + let milestone_id1 = dispatcher + .create_milestone(org, project_id, milestone_description, milestone_amount); + + let milestone_id2 = dispatcher + .create_milestone(org, project_id, 'Second milestone', milestone_amount * 2); + + assert(milestone_id1 == 1, 'Incorrect first milestone ID'); + assert(milestone_id2 == 2, 'Incorrect second milestone ID'); + + let milestone1 = dispatcher.get_milestone(project_id, milestone_id1); + let milestone2 = dispatcher.get_milestone(project_id, milestone_id2); + + assert(milestone1.milestone_description == milestone_description, 'Wrong description 1'); + assert(milestone1.milestone_amount == milestone_amount, 'Wrong amount 1'); + + assert(milestone2.milestone_description == 'Second milestone', 'Wrong description 2'); + assert(milestone2.milestone_amount == milestone_amount * 2, 'Wrong amount 2'); + + let milestones = dispatcher.get_project_milestones(project_id); + assert(milestones.len() == 2, 'Wrong number of milestones'); +} + +#[test] +fn test_set_milestone_complete() { + let admin = ADMIN(); + let org = ORGANIZATION(); + + let (contract_address, dispatcher) = deploy_milestone_manager(admin); + + // Setup test data + let (project_id, milestone_amount, milestone_description) = setup_test_data(); + + // Create milestone as admin + cheat_caller_address(contract_address, admin, CheatSpan::TargetCalls(1)); + let milestone_id = dispatcher + .create_milestone(org, project_id, milestone_description, milestone_amount); + + // Set up event spy + let mut spy = spy_events(); + + cheat_caller_address(contract_address, admin, CheatSpan::TargetCalls(1)); + dispatcher.set_milestone_complete(project_id, milestone_id); + + let milestone = dispatcher.get_milestone(project_id, milestone_id); + assert(milestone.completed == true, 'Milestone not marked complete'); + assert(milestone.released == false, 'Released should still be false'); + + // Verify event was emitted + spy + .assert_emitted( + @array![ + ( + contract_address, + MilestoneManager::Event::MilestoneCompleted( + MilestoneManager::MilestoneCompleted { project_id, milestone_id }, + ), + ), + ], + ); +} + +#[test] +#[should_panic(expected: 'Milestone already completed')] +fn test_cannot_complete_milestone_twice() { + let admin = ADMIN(); + let org = ORGANIZATION(); + + let (contract_address, dispatcher) = deploy_milestone_manager(admin); + + // Setup test data + let (project_id, milestone_amount, milestone_description) = setup_test_data(); + + cheat_caller_address(contract_address, admin, CheatSpan::TargetCalls(2)); + let milestone_id = dispatcher + .create_milestone(org, project_id, milestone_description, milestone_amount); + dispatcher.set_milestone_complete(project_id, milestone_id); + + cheat_caller_address(contract_address, admin, CheatSpan::TargetCalls(1)); + dispatcher.set_milestone_complete(project_id, milestone_id); +} + +#[test] +#[should_panic(expected: 'Invalid milestone')] +fn test_cannot_complete_nonexistent_milestone() { + let admin = ADMIN(); + let _org = ORGANIZATION(); + + let (contract_address, dispatcher) = deploy_milestone_manager(admin); + + // Setup test data + let (project_id, _, _) = setup_test_data(); + let nonexistent_milestone_id = 999_u64; + + cheat_caller_address(contract_address, admin, CheatSpan::TargetCalls(1)); + dispatcher.set_milestone_complete(project_id, nonexistent_milestone_id); +} + +#[test] +#[should_panic(expected: 'Caller not authorized')] +fn test_unauthorized_cannot_create_milestone() { + let admin = ADMIN(); + let org = ORGANIZATION(); + let non_org = NON_ORG(); + + let (contract_address, dispatcher) = deploy_milestone_manager(admin); + + // Setup test data + let (project_id, milestone_amount, milestone_description) = setup_test_data(); + + cheat_caller_address(contract_address, non_org, CheatSpan::TargetCalls(1)); + dispatcher.create_milestone(org, project_id, milestone_description, milestone_amount); +} + +#[test] +fn test_pause_and_unpause_contract() { + let admin = ADMIN(); + let org = ORGANIZATION(); + + let (contract_address, dispatcher) = deploy_milestone_manager(admin); + + assert(dispatcher.is_paused() == false, 'Contract should not be paused'); + + // Pause contract as admin + cheat_caller_address(contract_address, admin, CheatSpan::TargetCalls(1)); + dispatcher.pause_contract(); + + assert(dispatcher.is_paused() == true, 'Contract should be paused'); + + cheat_caller_address(contract_address, admin, CheatSpan::TargetCalls(1)); + dispatcher.unpause_contract(); + + assert(dispatcher.is_paused() == false, 'Contract should be unpaused'); + + // Setup test data + let (project_id, milestone_amount, milestone_description) = setup_test_data(); + + cheat_caller_address(contract_address, admin, CheatSpan::TargetCalls(1)); + let milestone_id = dispatcher + .create_milestone(org, project_id, milestone_description, milestone_amount); + + // Verify milestone was created + assert(milestone_id == 1, 'Milestone should be created'); +} + +#[test] +#[should_panic(expected: 'Contract is paused')] +fn test_cannot_create_milestone_when_paused() { + let admin = ADMIN(); + let org = ORGANIZATION(); + + let (contract_address, dispatcher) = deploy_milestone_manager(admin); + + // Setup test data + let (project_id, milestone_amount, milestone_description) = setup_test_data(); + + // Pause contract as admin + cheat_caller_address(contract_address, admin, CheatSpan::TargetCalls(1)); + dispatcher.pause_contract(); + + cheat_caller_address(contract_address, admin, CheatSpan::TargetCalls(1)); + dispatcher.create_milestone(org, project_id, milestone_description, milestone_amount); +} + +#[test] +#[should_panic(expected: 'ONLY ADMIN')] +fn test_only_admin_can_pause_contract() { + let admin = ADMIN(); + let _non_admin = NON_ORG(); + + let (contract_address, dispatcher) = deploy_milestone_manager(admin); + + cheat_caller_address(contract_address, _non_admin, CheatSpan::TargetCalls(1)); + dispatcher.pause_contract(); +} + +#[test] +#[should_panic(expected: 'ONLY ADMIN')] +fn test_only_admin_can_unpause_contract() { + let admin = ADMIN(); + let non_admin = NON_ORG(); + + let (contract_address, dispatcher) = deploy_milestone_manager(admin); + + // Pause contract as admin + cheat_caller_address(contract_address, admin, CheatSpan::TargetCalls(1)); + dispatcher.pause_contract(); + + cheat_caller_address(contract_address, non_admin, CheatSpan::TargetCalls(1)); + dispatcher.unpause_contract(); +}