diff --git a/src/access_list.rs b/src/access_list.rs index 7f39d56..9945cce 100644 --- a/src/access_list.rs +++ b/src/access_list.rs @@ -1,7 +1,7 @@ use alloc::collections::BTreeSet; use alloy_primitives::{ map::{HashMap, HashSet}, - Address, B256, + Address, TxKind, B256, }; use alloy_rpc_types_eth::{AccessList, AccessListItem}; use revm::{ @@ -20,18 +20,19 @@ pub struct AccessListInspector { access_list: HashMap>, } +impl From for AccessListInspector { + fn from(access_list: AccessList) -> Self { + Self::new(access_list) + } +} + impl AccessListInspector { /// Creates a new inspector instance /// /// The `access_list` is the provided access list from the call request - pub fn new( - access_list: AccessList, - from: Address, - to: Address, - precompiles: impl IntoIterator, - ) -> Self { + pub fn new(access_list: AccessList) -> Self { Self { - excluded: [from, to].into_iter().chain(precompiles).collect(), + excluded: Default::default(), access_list: access_list .0 .into_iter() @@ -59,12 +60,67 @@ impl AccessListInspector { }); AccessList(items.collect()) } + + /// Collects addresses which should be excluded from the access list. Must be called before the + /// top-level call. + /// + /// Those include caller, callee and precompiles. + fn collect_excluded_addresses(&mut self, context: &EvmContext) { + let from = context.env.tx.caller; + let to = if let TxKind::Call(to) = context.env.tx.transact_to { + to + } else { + // We need to exclude the created address if this is a CREATE frame. + // + // This assumes that caller has already been loaded but nonce was not increased yet. + let nonce = context.journaled_state.account(from).info.nonce; + from.create(nonce) + }; + let precompiles = context.precompiles.addresses().copied(); + self.excluded = [from, to].into_iter().chain(precompiles).collect(); + } } impl Inspector for AccessListInspector where DB: Database, { + fn call( + &mut self, + context: &mut EvmContext, + _inputs: &mut revm::interpreter::CallInputs, + ) -> Option { + // At the top-level frame, fill the excluded addresses + if context.journaled_state.depth() == 0 { + self.collect_excluded_addresses(context) + } + None + } + + fn create( + &mut self, + context: &mut EvmContext, + _inputs: &mut revm::interpreter::CreateInputs, + ) -> Option { + // At the top-level frame, fill the excluded addresses + if context.journaled_state.depth() == 0 { + self.collect_excluded_addresses(context) + } + None + } + + fn eofcreate( + &mut self, + context: &mut EvmContext, + _inputs: &mut revm::interpreter::EOFCreateInputs, + ) -> Option { + // At the top-level frame, fill the excluded addresses + if context.journaled_state.depth() == 0 { + self.collect_excluded_addresses(context) + } + None + } + fn step(&mut self, interp: &mut Interpreter, _context: &mut EvmContext) { match interp.current_opcode() { opcode::SLOAD | opcode::SSTORE => {