From ded87ff439587ddd99cfdd0015ccd5708dac6f9f Mon Sep 17 00:00:00 2001 From: evilg Date: Tue, 1 Oct 2024 12:44:58 -0400 Subject: [PATCH] Don't visit added imported functions during application iteration. (#194) * Don't visit added imported functions during application iteration. * Make sure all funcs are visited * Add test case for bug * fmt --- src/ir/module/mod.rs | 14 ++++ src/iterator/component_iterator.rs | 16 +--- src/iterator/module_iterator.rs | 33 +++------ src/subiterator/component_subiterator.rs | 11 +-- src/subiterator/module_subiterator.rs | 93 ++++++++++-------------- tests/func_builder.rs | 38 +++++++++- 6 files changed, 104 insertions(+), 101 deletions(-) diff --git a/src/ir/module/mod.rs b/src/ir/module/mod.rs index 65c8032..315d279 100644 --- a/src/ir/module/mod.rs +++ b/src/ir/module/mod.rs @@ -488,6 +488,20 @@ impl<'a> Module<'a> { }) } + /// Creates Vec of (Function, Number of Instructions) + pub fn get_func_metadata(&self) -> Vec<(FunctionID, usize)> { + let mut metadata = vec![]; + for func in self.functions.iter() { + match &func.kind { + FuncKind::Import(_) => {} + FuncKind::Local(LocalFunction { func_id, body, .. }) => { + metadata.push((*func_id, body.num_instructions)); + } + } + } + metadata + } + /// Emit the module into a wasm binary file. pub fn emit_wasm(&mut self, file_name: &str) -> Result<(), std::io::Error> { let module = self.encode_internal(); diff --git a/src/iterator/component_iterator.rs b/src/iterator/component_iterator.rs index 0f749c6..8c4e636 100644 --- a/src/iterator/component_iterator.rs +++ b/src/iterator/component_iterator.rs @@ -2,9 +2,8 @@ use crate::ir::component::Component; use crate::ir::id::{FunctionID, GlobalID, LocalID, ModuleID}; -use crate::ir::module::module_functions::{FuncKind, LocalFunction}; +use crate::ir::module::module_functions::FuncKind; use crate::ir::module::module_globals::Global; -use crate::ir::module::Iter; use crate::ir::types::{DataType, FuncInstrMode, InstrumentationMode, Location}; use crate::iterator::iterator_trait::{IteratingInstrumenter, Iterator}; use crate::module_builder::AddLocal; @@ -22,7 +21,7 @@ pub struct ComponentIterator<'a, 'b> { comp_iterator: ComponentSubIterator, } -fn print_metadata(metadata: &HashMap>) { +fn print_metadata(metadata: &HashMap>) { for c in metadata.keys() { println!("Module: {:?}", c); for (m, i) in metadata.get(c).unwrap().iter() { @@ -41,16 +40,7 @@ impl<'a, 'b> ComponentIterator<'a, 'b> { // Creates Module -> Function -> Number of Instructions let mut metadata = HashMap::new(); for (mod_idx, m) in comp.modules.iter().enumerate() { - let mut mod_metadata = HashMap::new(); - for func in m.functions.iter() { - match &func.kind { - FuncKind::Import(_) => {} - FuncKind::Local(LocalFunction { func_id, body, .. }) => { - mod_metadata.insert(*func_id, body.num_instructions); - } - } - } - metadata.insert(ModuleID(mod_idx as u32), mod_metadata); + metadata.insert(ModuleID(mod_idx as u32), m.get_func_metadata()); } print_metadata(&metadata); let num_modules = comp.num_modules; diff --git a/src/iterator/module_iterator.rs b/src/iterator/module_iterator.rs index a2a85e5..0ea0bd9 100644 --- a/src/iterator/module_iterator.rs +++ b/src/iterator/module_iterator.rs @@ -1,15 +1,14 @@ //! Iterator to traverse a Module use crate::ir::id::{FunctionID, GlobalID, LocalID}; -use crate::ir::module::module_functions::{FuncKind, LocalFunction}; +use crate::ir::module::module_functions::FuncKind; use crate::ir::module::module_globals::Global; -use crate::ir::module::{Iter, Module}; +use crate::ir::module::Module; use crate::ir::types::{DataType, FuncInstrMode, InstrumentationMode, Location}; use crate::iterator::iterator_trait::{IteratingInstrumenter, Iterator}; use crate::module_builder::AddLocal; use crate::opcode::{Inject, InjectAt, Instrumenter, MacroOpcode, Opcode}; use crate::subiterator::module_subiterator::ModuleSubIterator; -use std::collections::HashMap; use wasmparser::Operator; /// Iterator for a Module. @@ -25,27 +24,15 @@ pub struct ModuleIterator<'a, 'b> { impl<'a, 'b> ModuleIterator<'a, 'b> { /// Creates a new ModuleIterator pub fn new(module: &'a mut Module<'b>, skip_funcs: &Vec) -> Self { - // Creates Function -> Number of Instructions - let mut metadata = HashMap::new(); - for func in module.functions.iter() { - match &func.kind { - FuncKind::Import(_) => {} - FuncKind::Local(LocalFunction { func_id, body, .. }) => { - metadata.insert(*func_id, body.num_instructions); - } - } - } - let num_funcs = module.num_local_functions; + let metadata = module.get_func_metadata(); ModuleIterator { module, - mod_iterator: ModuleSubIterator::new(num_funcs, metadata, skip_funcs.to_owned()), + mod_iterator: ModuleSubIterator::new(metadata, skip_funcs.to_owned()), } } pub fn curr_op_owned(&self) -> Option> { - if self.mod_iterator.end() { - None - } else if let ( + if let ( Location::Module { func_idx, instr_idx, @@ -54,7 +41,7 @@ impl<'a, 'b> ModuleIterator<'a, 'b> { .., ) = self.mod_iterator.curr_loc() { - match &self.module.functions.get(func_idx as FunctionID).kind { + match &self.module.functions.get(func_idx).kind { FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"), FuncKind::Local(l) => Some(l.body.instructions[instr_idx].op.clone()), } @@ -247,7 +234,7 @@ impl<'a> Instrumenter<'a> for ModuleIterator<'_, 'a> { .. } = loc { - match self.module.functions.get_mut(func_idx as FunctionID).kind { + match self.module.functions.get_mut(func_idx).kind { FuncKind::Import(_) => panic!("Cannot instrument an imported function"), FuncKind::Local(ref mut l) => { l.body.instructions[instr_idx].instr_flag.alternate = Some(vec![]) @@ -343,9 +330,7 @@ impl<'a> Iterator for ModuleIterator<'_, 'a> { /// Returns the current instruction fn curr_op(&self) -> Option<&Operator<'a>> { - if self.mod_iterator.end() { - None - } else if let ( + if let ( Location::Module { func_idx, instr_idx, @@ -354,7 +339,7 @@ impl<'a> Iterator for ModuleIterator<'_, 'a> { .., ) = self.mod_iterator.curr_loc() { - match &self.module.functions.get(func_idx as FunctionID).kind { + match &self.module.functions.get(func_idx).kind { FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"), FuncKind::Local(l) => Some(&l.body.instructions[instr_idx].op), } diff --git a/src/subiterator/component_subiterator.rs b/src/subiterator/component_subiterator.rs index e874b6f..74f3503 100644 --- a/src/subiterator/component_subiterator.rs +++ b/src/subiterator/component_subiterator.rs @@ -13,8 +13,8 @@ pub struct ComponentSubIterator { num_mods: usize, /// The module iterator used to keep track of the location in the module. pub(crate) mod_iterator: ModuleSubIterator, - /// Metadata that maps Module Index -> Function Index -> Instruction Index - metadata: HashMap>, + /// Metadata that maps Module Index -> Vec<(Function Index, Instruction Index)> + metadata: HashMap>, /// Map of Module -> Functions to skip in that module. Provide an empty HashMap if no functions are to be skipped. skip_funcs: HashMap>, } @@ -24,7 +24,7 @@ impl ComponentSubIterator { pub fn new( curr_mod: ModuleID, num_mods: usize, - metadata: HashMap>, + metadata: HashMap>, skip_funcs: HashMap>, ) -> Self { // Get current skip func @@ -34,7 +34,6 @@ impl ComponentSubIterator { num_mods, metadata: metadata.clone(), mod_iterator: ModuleSubIterator::new( - metadata.get(&ModuleID(0)).unwrap().keys().len() as u32, (*metadata.get(&curr_mod).unwrap()).clone(), match skip_funcs.contains_key(&curr_mod) { true => skip_funcs.get(&curr_mod).unwrap().clone(), @@ -56,11 +55,9 @@ impl ComponentSubIterator { fn next_module(&mut self) -> bool { *self.curr_mod += 1; if *self.curr_mod < self.num_mods as u32 { - let num_funcs = self.metadata.get(&self.curr_mod).unwrap().keys().len() as u32; let met = self.metadata.get(&self.curr_mod).unwrap().clone(); // If we're defining a new module, we have to reset function self.mod_iterator = ModuleSubIterator::new( - num_funcs, met, match self.skip_funcs.contains_key(&self.curr_mod) { true => self.skip_funcs.get(&self.curr_mod).unwrap().clone(), @@ -80,7 +77,7 @@ impl ComponentSubIterator { /// Gets the index of the current function in the current module pub fn curr_func_idx(&self) -> FunctionID { - self.mod_iterator.curr_func + self.mod_iterator.get_curr_func().0 } /// Gets the index of the current instruction in the current function diff --git a/src/subiterator/module_subiterator.rs b/src/subiterator/module_subiterator.rs index 21465f5..68d9337 100644 --- a/src/subiterator/module_subiterator.rs +++ b/src/subiterator/module_subiterator.rs @@ -3,18 +3,13 @@ use crate::ir::id::FunctionID; use crate::ir::types::Location; use crate::subiterator::function_subiterator::FuncSubIterator; -use std::collections::HashMap; /// Sub-iterator for a Module. Keeps track of current location in a Module. pub struct ModuleSubIterator { /// The current function the SubIterator is at - pub(crate) curr_func: FunctionID, - /// The number of functions that have been visited thus far - visited_funcs: u32, - /// Number of functions in this module - num_funcs: u32, - /// Metadata that maps Function Index -> Instruction Index - metadata: HashMap, + pub(crate) curr_idx: usize, + /// Metadata containing a functions index and number_of_instructions + metadata: Vec<(FunctionID, usize)>, /// The function iterator used to keep track of the location in the function. pub(crate) func_iterator: FuncSubIterator, /// Functions to skip. Provide an empty vector if no functions are to be skipped. @@ -23,34 +18,23 @@ pub struct ModuleSubIterator { impl ModuleSubIterator { /// Creates a new ModuleSubIterator - pub fn new( - num_funcs: u32, - metadata: HashMap, - skip_funcs: Vec, - ) -> Self { + pub fn new(metadata: Vec<(FunctionID, usize)>, skip_funcs: Vec) -> Self { + let curr_idx = 0; + + let (_curr_fid, curr_num_instrs) = metadata[curr_idx]; let mut mod_it = ModuleSubIterator { - curr_func: *metadata.keys().min().unwrap(), - visited_funcs: 0, - num_funcs, - metadata: metadata.clone(), - func_iterator: FuncSubIterator::new( - *metadata.get(metadata.keys().min().unwrap()).unwrap(), - ), + curr_idx, + metadata, + func_iterator: FuncSubIterator::new(curr_num_instrs), skip_funcs, }; - // In case 0 is in skip func - while mod_it - .skip_funcs - .contains(&(mod_it.curr_func as FunctionID)) - { - mod_it.next_function(); - } + mod_it.handle_skips(); + mod_it } - /// Checks if the SubIterator has finished traversing all the functions - pub fn end(&self) -> bool { - self.visited_funcs == self.num_funcs + pub fn get_curr_func(&self) -> (FunctionID, usize) { + self.metadata[self.curr_idx] } /// Returns the current Location in the Module as a Location @@ -59,7 +43,7 @@ impl ModuleSubIterator { let curr_instr = self.func_iterator.curr_instr; ( Location::Module { - func_idx: self.curr_func, + func_idx: self.get_curr_func().0, instr_idx: curr_instr, }, self.func_iterator.is_end(curr_instr), @@ -67,49 +51,48 @@ impl ModuleSubIterator { } /// Resets the ModuleSubIterator when it is a Child SubIterator of a ComponentSubIterator - pub(crate) fn reset_from_comp_iterator(&mut self, metadata: HashMap) { - *self.curr_func = 0; + pub(crate) fn reset_from_comp_iterator(&mut self, metadata: Vec<(FunctionID, usize)>) { self.metadata = metadata; - self.func_iterator.reset( - *self - .metadata - .get(self.metadata.keys().min().unwrap()) - .unwrap(), - ); + self.reset(); } /// Resets the ModuleSubIterator when it is not a Child SubIterator pub fn reset(&mut self) { - *self.curr_func = 0; - self.func_iterator - .reset(*self.metadata.get(&FunctionID(0)).unwrap()); + self.curr_idx = 0; + self.handle_skips(); + self.func_iterator.reset(self.get_curr_func().1); } - /// Checks if there are functions left to visit - pub fn has_next_function(&self) -> bool { - *self.curr_func + 1 < self.num_funcs + fn handle_skips(&mut self) { + let mut curr_fid = self.get_curr_func().0; + while self.has_next_function() && self.skip_funcs.contains(&curr_fid) { + self.curr_idx += 1; + curr_fid = self.get_curr_func().0; + } } /// Goes to the next function in the module fn next_function(&mut self) -> bool { - *self.curr_func += 1; - self.visited_funcs += 1; + if !self.has_next_function() { + return false; + } + self.curr_idx += 1; // skip over configured funcs - while self.visited_funcs < self.num_funcs - && self.skip_funcs.contains(&(self.curr_func as FunctionID)) - { - *self.curr_func += 1; - self.visited_funcs += 1; - } - if self.visited_funcs < self.num_funcs { - self.func_iterator = FuncSubIterator::new(*self.metadata.get(&self.curr_func).unwrap()); + self.handle_skips(); + if self.curr_idx < self.metadata.len() { + self.func_iterator = FuncSubIterator::new(self.get_curr_func().1); true } else { false } } + /// Checks if there are functions left to visit + pub fn has_next_function(&self) -> bool { + self.curr_idx + 1 < self.metadata.len() + } + /// Checks if there are functions left to visit pub(crate) fn has_next(&self) -> bool { self.func_iterator.has_next() || self.has_next_function() diff --git a/tests/func_builder.rs b/tests/func_builder.rs index 47cd1ea..5935dc5 100644 --- a/tests/func_builder.rs +++ b/tests/func_builder.rs @@ -1,6 +1,10 @@ -use orca_wasm::ir::id::FunctionID; +use log::trace; +use orca_wasm::ir::function::FunctionBuilder; +use orca_wasm::ir::id::{FunctionID, TypeID}; +use orca_wasm::iterator::iterator_trait::Iterator; +use orca_wasm::iterator::module_iterator::ModuleIterator; use orca_wasm::opcode::Instrumenter; -use orca_wasm::Opcode; +use orca_wasm::{DataType, Opcode}; use orca_wasm::{Location, Module}; use std::process::Command; @@ -61,3 +65,33 @@ fn run_start_orca_default() { let out = wasmprinter::print_bytes(result.clone()).expect("couldn't translate Wasm to wat"); println!("{}", out); } +#[test] +// test start function instrumentation with FunctionModifier +fn add_import_and_local_fn_then_iterate() { + let file_name = "tests/test_inputs/handwritten/modules/_start.wat"; + let wasm = wat::parse_file(file_name).expect("couldn't convert the input wat to Wasm"); + let mut module = Module::parse(&wasm, false).expect("Unable to parse"); + // add an imported function AND THEN a new local function + module.add_import_func("new".to_string(), "import".to_string(), TypeID(0)); + assert_eq!(module.num_import_func(), 1); + + let params = vec![]; + let results = vec![DataType::I32]; + + let mut new_func = FunctionBuilder::new(¶ms, &results); + new_func.i32_const(1); + new_func.finish_module(&mut module); + + // now iterate over module + let mut mod_it = ModuleIterator::new(&mut module, &vec![]); + loop { + let _op = mod_it.curr_op(); + if mod_it.next().is_none() { + break; + }; + } + + let result = module.encode(); + let out = wasmprinter::print_bytes(result.clone()).expect("couldn't translate Wasm to wat"); + println!("{}", out); +}