diff --git a/admin_sep/src/administratable.rs b/admin_sep/src/administratable.rs index 37c8a1c..8d13dc3 100644 --- a/admin_sep/src/administratable.rs +++ b/admin_sep/src/administratable.rs @@ -6,6 +6,14 @@ use soroban_sdk::{Address, Env, Symbol, symbol_short}; pub trait Administratable { fn admin(env: &Env) -> soroban_sdk::Address; fn set_admin(env: &Env, new_admin: &soroban_sdk::Address); + + #[internal] + fn require_admin(env: &Env) { + Self::admin(env).require_auth(); + } + + #[internal] + fn init(env: &Env, admin: &soroban_sdk::Address); } pub const STORAGE_KEY: Symbol = symbol_short!("A"); @@ -22,22 +30,14 @@ impl Administratable for Admin { unsafe { get(env).unwrap_unchecked() } } fn set_admin(env: &Env, new_admin: &soroban_sdk::Address) { - if let Some(address) = get(env) { - address.require_auth(); - } + Self::require_admin(env); env.storage().instance().set(&STORAGE_KEY, &new_admin); } -} - -pub trait AdminExt { - fn require_admin(e: &Env); -} -impl AdminExt for T -where - T: Administratable, -{ - fn require_admin(e: &Env) { - T::admin(e).require_auth(); + fn init(env: &Env, admin: &soroban_sdk::Address) { + if get(env).is_some() { + panic!("Admin already initialized"); + } + env.storage().instance().set(&STORAGE_KEY, &admin); } } diff --git a/admin_sep/src/constructor.rs b/admin_sep/src/constructor.rs index 634d2ee..f893907 100644 --- a/admin_sep/src/constructor.rs +++ b/admin_sep/src/constructor.rs @@ -9,7 +9,7 @@ pub trait Constructable: crate::Administratable { #[allow(unused_variables)] fn construct(env: &Env, args: T) {} fn constructor(env: &Env, args: T) { - Self::set_admin(env, args.admin()); + Self::init(env, args.admin()); Self::construct(env, args); } } diff --git a/admin_sep/src/upgradable.rs b/admin_sep/src/upgradable.rs index 81489a2..c5e11d7 100644 --- a/admin_sep/src/upgradable.rs +++ b/admin_sep/src/upgradable.rs @@ -19,7 +19,7 @@ impl Upgradable for Upgrader { impl Upgradable for AdministratableExt { type Impl = N; fn upgrade(env: &soroban_sdk::Env, wasm_hash: soroban_sdk::BytesN<32>) { - T::admin(env).require_auth(); + T::require_admin(env); N::upgrade(env, wasm_hash); } } diff --git a/contracttrait-macro/src/contracttrait.rs b/contracttrait-macro/src/contracttrait.rs index fd9ed4d..5cc9272 100644 --- a/contracttrait-macro/src/contracttrait.rs +++ b/contracttrait-macro/src/contracttrait.rs @@ -1,13 +1,15 @@ use deluxe::HasAttributes; use proc_macro2::{Ident, TokenStream}; -use quote::{format_ident, quote}; +use quote::{format_ident, quote, ToTokens}; use syn::{ - punctuated::Punctuated, Attribute, FnArg, Item, ItemTrait, PatType, Signature, Token, TraitItem, Type + punctuated::Punctuated, Attribute, FnArg, Item, ItemTrait, PatType, Signature, Token, + TraitItem, Type, }; use crate::{ args::{InnerArgs, MyMacroArgs, MyTraitMacroArgs}, error::Error, + util::has_attr, }; pub fn generate(args: &MyTraitMacroArgs, item: &Item) -> TokenStream { @@ -20,8 +22,8 @@ pub fn derive_contract(args: &MyMacroArgs, trait_impls: &Item) -> TokenStream { fn generate_method( (trait_item, item_trait): (&syn::TraitItem, &syn::ItemTrait), -) -> Option<(TokenStream, TokenStream)> { - let syn::TraitItem::Fn(method) = trait_item else { +) -> Option<(Option, TokenStream)> { + let syn::TraitItem::Fn(mut method) = trait_item.clone() else { return None; }; let sig = &method.sig; @@ -31,9 +33,22 @@ fn generate_method( }; let args = args_to_idents(&sig.inputs); let attrs = &method.attrs; + if has_attr(attrs, "internal") { + method.attrs = method + .attrs + .into_iter() + .filter(|attr| !attr.path().is_ident("internal")) + .collect::>(); + let method_stream = if method.default.is_none() { + generate_trait_method(&method, name, &args) + } else { + method.to_token_stream() + }; + return Some((None, method_stream)); + } Some(( - generate_static_method(item_trait, sig, attrs, name, &args), - generate_trait_method(sig, attrs, name, &args), + Some(generate_static_method(item_trait, sig, attrs, name, &args)), + generate_trait_method(&method, name, &args), )) } @@ -104,20 +119,14 @@ fn transform_type_and_call(ty: &Type, arg_name: &Ident) -> (TokenStream, TokenSt } } -fn generate_trait_method( - sig: &Signature, - attrs: &[Attribute], - name: &Ident, - args: &[&Ident], -) -> TokenStream { - let inputs = sig.inputs.iter(); - let output = &sig.output; - quote! { - #(#attrs)* - fn #name(#(#inputs),*) #output { +fn generate_trait_method(method: &syn::TraitItemFn, name: &Ident, args: &[&Ident]) -> TokenStream { + let mut method = method.clone(); + method.default = Some(syn::parse_quote! { + { Self::Impl::#name(#(#args),*) } - } + }); + method.to_token_stream() } fn inner_generate( @@ -289,7 +298,11 @@ mod tests { pub trait Administratable { /// Get current admin fn admin_get(env: Env) -> soroban_sdk::Address; - fn admin_set(env: Env, new_admin: soroban_sdk::Address); + fn admin_set(env: Env, new_admin: &soroban_sdk::Address); + #[internal] + fn require_auth(env: Env) { + Self::admin_get(env).require_auth(); + } } }; let default = Some(format_ident!("Admin")); @@ -309,9 +322,12 @@ mod tests { fn admin_get(env: Env) -> soroban_sdk::Address { Self::Impl::admin_get(env) } - fn admin_set(env: Env, new_admin: soroban_sdk::Address) { + fn admin_set(env: Env, new_admin: &soroban_sdk::Address) { Self::Impl::admin_set(env, new_admin) } + fn require_auth(env: Env) { + Self::admin_get(env).require_auth(); + } } #[macro_export] macro_rules! Administratable { @@ -336,7 +352,7 @@ mod tests { } pub fn admin_set(env: Env, new_admin: soroban_sdk::Address) { - < $contract_name as Administratable >::admin_set(env, new_admin) + < $contract_name as Administratable >::admin_set(env, &new_admin) } } }; @@ -354,7 +370,7 @@ mod tests { } pub fn admin_set(env: Env, new_admin: soroban_sdk::Address) { - < $contract_name as Administratable >::admin_set(env, new_admin) + < $contract_name as Administratable >::admin_set(env, &new_admin) } } }; diff --git a/contracttrait-macro/src/util.rs b/contracttrait-macro/src/util.rs index a6281c0..4df9ce7 100644 --- a/contracttrait-macro/src/util.rs +++ b/contracttrait-macro/src/util.rs @@ -14,6 +14,11 @@ pub(crate) fn p_e(e: std::io::Error) -> std::io::Error { e } +pub(crate) fn has_attr(attrs: &[syn::Attribute], ident_str: &str) -> bool { + attrs.iter().any(|attr| attr.path().is_ident(ident_str)) +} + + /// Format the given snippet. The snippet is expected to be *complete* code. /// When we cannot parse the given snippet, this function returns `None`. #[allow(unused)]