Skip to content

Commit

Permalink
feat: static call (#36)
Browse files Browse the repository at this point in the history
Co-authored-by: Leo <leo@powdrlabs.com>
  • Loading branch information
0xrusowsky and leonardoalt authored Feb 17, 2025
1 parent 3b90551 commit d288031
Show file tree
Hide file tree
Showing 13 changed files with 403 additions and 283 deletions.
308 changes: 150 additions & 158 deletions contract-derive/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ use std::error::Error;

use alloy_core::primitives::keccak256;
use alloy_dyn_abi::DynSolType;
use alloy_sol_types::SolType;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{
parse::{Parse, ParseStream},
spanned::Spanned,
FnArg, Ident, ImplItemMethod, Meta, ReturnType, TraitItemMethod, Type,
parse::{Parse, ParseStream}, FnArg, Ident, ImplItemMethod, LitStr, ReturnType, TraitItemMethod, Type
};

// Unified method info from `ImplItemMethod` and `TraitItemMethod`
#[derive(Clone)]
pub struct MethodInfo<'a> {
name: &'a Ident,
args: Vec<syn::FnArg>,
Expand All @@ -37,6 +36,16 @@ impl<'a> From<&'a TraitItemMethod> for MethodInfo<'a> {
}
}

impl<'a> MethodInfo<'a> {
pub fn is_mutable(&self) -> bool {
match self.args.first() {
Some(FnArg::Receiver(receiver)) => receiver.mutability.is_some(),
Some(FnArg::Typed(_)) => panic!("First argument must be self"),
None => panic!("Expected `self` as the first arg"),
}
}
}

// Helper function to get the parameter names + types of a method
fn get_arg_props<'a>(
skip_first_arg: bool,
Expand Down Expand Up @@ -65,97 +74,36 @@ pub fn get_arg_props_all<'a>(method: &'a MethodInfo<'a>) -> (Vec<Ident>, Vec<&sy
get_arg_props(false, method)
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum InterfaceCompilationTarget {
R55,
EVM,
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum InterfaceNamingStyle {
CamelCase,
}

pub struct InterfaceArgs {
pub target: InterfaceCompilationTarget,
pub rename: Option<InterfaceNamingStyle>,
}

impl Parse for InterfaceArgs {
fn parse(input: ParseStream) -> Result<Self, syn::Error> {
// default arg values if uninformed
let mut target = InterfaceCompilationTarget::R55;
let mut rename_style = None;

// Parse all attributes
let mut first = true;
while !input.is_empty() {
if !first {
input.parse::<syn::Token![,]>()?;
}
first = false;

let key: syn::Ident = input.parse()?;

match key.to_string().as_str() {
"rename" => {
input.parse::<syn::Token![=]>()?;

// Handle both string literals and identifiers
let value = if input.peek(syn::LitStr) {
let lit: syn::LitStr = input.parse()?;
lit.value()
} else {
let ident: syn::Ident = input.parse()?;
ident.to_string()
};

rename_style = Some(match value.as_str() {
"camelCase" => InterfaceNamingStyle::CamelCase,
invalid => {
return Err(syn::Error::new(
key.span(),
format!("unsupported rename style: {}", invalid),
))
}
});
}
"target" => {
input.parse::<syn::Token![=]>()?;

// Handle both string literals and identifiers
let value = if input.peek(syn::LitStr) {
let lit: syn::LitStr = input.parse()?;
lit.value()
} else {
let ident: syn::Ident = input.parse()?;
ident.to_string()
};

target = match value.as_str() {
"r55" => InterfaceCompilationTarget::R55,
"evm" => InterfaceCompilationTarget::EVM,
invalid => {
return Err(syn::Error::new(
key.span(),
format!("unsupported compilation target: {}", invalid),
))
}
};
}
invalid => {
return Err(syn::Error::new(
key.span(),
format!("unknown attribute: {}", invalid),
))
}
let rename_style = if !input.is_empty() {
let value = if input.peek(LitStr) {
input.parse::<LitStr>()?.value()
} else {
input.parse::<Ident>()?.to_string()
};

match value.as_str() {
"camelCase" => Some(InterfaceNamingStyle::CamelCase),
invalid => return Err(syn::Error::new(
input.span(),
format!("unsupported style: {}. Only 'camelCase' is supported", invalid)
))
}
}
} else {
None
};

Ok(InterfaceArgs {
target,
rename: rename_style,
})
Ok(InterfaceArgs { rename: rename_style })
}
}

Expand All @@ -164,109 +112,152 @@ pub fn generate_interface<T>(
methods: &[&T],
interface_name: &Ident,
interface_style: Option<InterfaceNamingStyle>,
interface_target: InterfaceCompilationTarget,
) -> quote::__private::TokenStream
where
for<'a> MethodInfo<'a>: From<&'a T>,
{
let methods: Vec<MethodInfo> = methods.iter().map(|&m| MethodInfo::from(m)).collect();
let (mut_methods, immut_methods): (Vec<MethodInfo>, Vec<MethodInfo>) =
methods.into_iter().partition(|m| m.is_mutable());

// Generate implementation
let method_impls = methods.iter().map(|method| {
let name = method.name;
let return_type = method.return_type;
let method_selector = match interface_target {
InterfaceCompilationTarget::R55 => u32::from_be_bytes(
generate_selector_r55(method, interface_style).expect("Unable to generate r55 fn selector"),
),
InterfaceCompilationTarget::EVM => u32::from_be_bytes(
generate_selector_evm(method, interface_style).expect("Unable to generate evm fn selector"),
),
};
// Generate implementations
let mut_method_impls = mut_methods
.iter()
.map(|method| generate_method_impl(method, interface_style, true));
let immut_method_impls = immut_methods
.iter()
.map(|method| generate_method_impl(method, interface_style, false));

let (arg_names, arg_types) = get_arg_props_skip_first(method);
quote! {
use core::marker::PhantomData;
pub struct #interface_name<C: CallCtx> {
address: Address,
_ctx: PhantomData<C>
}

let calldata = if arg_names.is_empty() {
quote! {
let mut complete_calldata = Vec::with_capacity(4);
complete_calldata.extend_from_slice(&[
#method_selector.to_be_bytes()[0],
#method_selector.to_be_bytes()[1],
#method_selector.to_be_bytes()[2],
#method_selector.to_be_bytes()[3],
]);
}
} else {
quote! {
let mut args_calldata = (#(#arg_names),*).abi_encode();
let mut complete_calldata = Vec::with_capacity(4 + args_calldata.len());
complete_calldata.extend_from_slice(&[
#method_selector.to_be_bytes()[0],
#method_selector.to_be_bytes()[1],
#method_selector.to_be_bytes()[2],
#method_selector.to_be_bytes()[3],
]);
complete_calldata.append(&mut args_calldata);
impl InitInterface for #interface_name<ReadOnly> {
fn new(address: Address) -> InterfaceBuilder<Self> {
InterfaceBuilder {
address,
_phantom: PhantomData
}
}
};

let return_ty = match return_type {
ReturnType::Default => quote! { () },
ReturnType::Type(_, ty) => quote! { #ty },
};

quote! {
pub fn #name(&self, #(#arg_names: #arg_types),*) -> Option<#return_ty> {
use alloy_sol_types::SolValue;
use alloc::vec::Vec;
}

#calldata
// Implement conversion between interface types
impl<C: CallCtx> IntoInterface<#interface_name<C>> for #interface_name<ReadOnly> {
fn into_interface(self) -> #interface_name<C> {
#interface_name {
address: self.address,
_ctx: PhantomData
}
}
}

let result = eth_riscv_runtime::call_contract(
self.address,
0_u64,
&complete_calldata,
None
)?;
impl<C: CallCtx> FromBuilder for #interface_name<C> {
type Context = C;

<#return_ty>::abi_decode(&result, true).ok()
fn from_builder(builder: InterfaceBuilder<Self>) -> Self {
Self {
address: builder.address,
_ctx: PhantomData
}
}
}
});

quote! {
pub struct #interface_name {
address: Address,
impl<C: StaticCtx> #interface_name<C> {
#(#immut_method_impls)*
}

impl #interface_name {
pub fn new(address: Address) -> Self {
Self { address }
}

#(#method_impls)*
impl<C: MutableCtx> #interface_name<C> {
#(#mut_method_impls)*
}
}
}

// Helper function to generate fn selector for r55 contracts
pub fn generate_selector_r55(method: &MethodInfo, style: Option<InterfaceNamingStyle>) -> Option<[u8; 4]> {
let name = match style {
None => method.name.to_string(),
Some(style) => match style {
InterfaceNamingStyle::CamelCase => to_camel_case(method.name.to_string()),
fn generate_method_impl(
method: &MethodInfo,
interface_style: Option<InterfaceNamingStyle>,
is_mutable: bool,
) -> TokenStream {
let name = method.name;
let return_type = method.return_type;
let method_selector = u32::from_be_bytes(
generate_fn_selector(method, interface_style).expect("Unable to generate fn selector")
);

let (arg_names, arg_types) = get_arg_props_skip_first(method);

let calldata = if arg_names.is_empty() {
quote! {
let mut complete_calldata = Vec::with_capacity(4);
complete_calldata.extend_from_slice(&[
#method_selector.to_be_bytes()[0],
#method_selector.to_be_bytes()[1],
#method_selector.to_be_bytes()[2],
#method_selector.to_be_bytes()[3],
]);
}
} else {
quote! {
let mut args_calldata = (#(#arg_names),*).abi_encode();
let mut complete_calldata = Vec::with_capacity(4 + args_calldata.len());
complete_calldata.extend_from_slice(&[
#method_selector.to_be_bytes()[0],
#method_selector.to_be_bytes()[1],
#method_selector.to_be_bytes()[2],
#method_selector.to_be_bytes()[3],
]);
complete_calldata.append(&mut args_calldata);
}
};
keccak256(name)[..4].try_into().ok()

let (call_fn, self_param) = if is_mutable {
(
quote! { eth_riscv_runtime::call_contract },
quote! { &mut self },
)
} else {
(
quote! { eth_riscv_runtime::staticcall_contract },
quote! { &self},
)
};

let return_ty = match return_type {
ReturnType::Default => quote! { () },
ReturnType::Type(_, ty) => quote! { #ty },
};

quote! {
pub fn #name(#self_param, #(#arg_names: #arg_types),*) -> Option<#return_ty> {
use alloy_sol_types::SolValue;
use alloc::vec::Vec;

#calldata

let result = #call_fn(
self.address,
0_u64,
&complete_calldata,
None
)?;

<#return_ty>::abi_decode(&result, true).ok()
}
}
}

// Helper function to generate fn selector for evm contracts
pub fn generate_selector_evm(method: &MethodInfo, style: Option<InterfaceNamingStyle>) -> Option<[u8; 4]> {
// Helper function to generate fn selector
pub fn generate_fn_selector(
method: &MethodInfo,
style: Option<InterfaceNamingStyle>,
) -> Option<[u8; 4]> {
let name = match style {
None => method.name.to_string(),
Some(style) => match style {
InterfaceNamingStyle::CamelCase => to_camel_case(method.name.to_string()),
}
},
};

let (_, arg_types) = get_arg_props_skip_first(method);
Expand All @@ -282,7 +273,8 @@ pub fn generate_selector_evm(method: &MethodInfo, style: Option<InterfaceNamingS
.join(",");

let selector = format!("{}({})", name, args_str);
keccak256(selector)[..4].try_into().ok()
let selector_bytes = keccak256(selector.as_bytes())[..4].try_into().ok()?;
Some(selector_bytes)
}

// Helper function to convert rust types to their solidity equivalent
Expand Down Expand Up @@ -407,7 +399,7 @@ fn to_camel_case(s: String) -> String {
capitalize_next = true;
}
}

result
}

Expand Down
Loading

0 comments on commit d288031

Please sign in to comment.