Skip to content

Commit

Permalink
Don't visit added imported functions during application iteration. (#194
Browse files Browse the repository at this point in the history
)

* Don't visit added imported functions during application iteration.

* Make sure all funcs are visited

* Add test case for bug

* fmt
  • Loading branch information
ejrgilbert authored Oct 1, 2024
1 parent 61d5c79 commit ded87ff
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 101 deletions.
14 changes: 14 additions & 0 deletions src/ir/module/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,20 @@ impl<'a> Module<'a> {
})
}

/// Collects `(function id, instruction count)` for every local function.
///
/// Imported functions are left out, so the returned vector only describes
/// functions that carry a body. Entries appear in the order in which
/// `self.functions` yields them.
pub fn get_func_metadata(&self) -> Vec<(FunctionID, usize)> {
    self.functions
        .iter()
        .filter_map(|func| match &func.kind {
            // Imports contribute no entry.
            FuncKind::Import(_) => None,
            FuncKind::Local(LocalFunction { func_id, body, .. }) => {
                Some((*func_id, body.num_instructions))
            }
        })
        .collect()
}

/// Emit the module into a wasm binary file.
pub fn emit_wasm(&mut self, file_name: &str) -> Result<(), std::io::Error> {
let module = self.encode_internal();
Expand Down
16 changes: 3 additions & 13 deletions src/iterator/component_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

use crate::ir::component::Component;
use crate::ir::id::{FunctionID, GlobalID, LocalID, ModuleID};
use crate::ir::module::module_functions::{FuncKind, LocalFunction};
use crate::ir::module::module_functions::FuncKind;
use crate::ir::module::module_globals::Global;
use crate::ir::module::Iter;
use crate::ir::types::{DataType, FuncInstrMode, InstrumentationMode, Location};
use crate::iterator::iterator_trait::{IteratingInstrumenter, Iterator};
use crate::module_builder::AddLocal;
Expand All @@ -22,7 +21,7 @@ pub struct ComponentIterator<'a, 'b> {
comp_iterator: ComponentSubIterator,
}

fn print_metadata(metadata: &HashMap<ModuleID, HashMap<FunctionID, usize>>) {
fn print_metadata(metadata: &HashMap<ModuleID, Vec<(FunctionID, usize)>>) {
for c in metadata.keys() {
println!("Module: {:?}", c);
for (m, i) in metadata.get(c).unwrap().iter() {
Expand All @@ -41,16 +40,7 @@ impl<'a, 'b> ComponentIterator<'a, 'b> {
// Creates Module -> Function -> Number of Instructions
let mut metadata = HashMap::new();
for (mod_idx, m) in comp.modules.iter().enumerate() {
let mut mod_metadata = HashMap::new();
for func in m.functions.iter() {
match &func.kind {
FuncKind::Import(_) => {}
FuncKind::Local(LocalFunction { func_id, body, .. }) => {
mod_metadata.insert(*func_id, body.num_instructions);
}
}
}
metadata.insert(ModuleID(mod_idx as u32), mod_metadata);
metadata.insert(ModuleID(mod_idx as u32), m.get_func_metadata());
}
print_metadata(&metadata);
let num_modules = comp.num_modules;
Expand Down
33 changes: 9 additions & 24 deletions src/iterator/module_iterator.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
//! Iterator to traverse a Module

use crate::ir::id::{FunctionID, GlobalID, LocalID};
use crate::ir::module::module_functions::{FuncKind, LocalFunction};
use crate::ir::module::module_functions::FuncKind;
use crate::ir::module::module_globals::Global;
use crate::ir::module::{Iter, Module};
use crate::ir::module::Module;
use crate::ir::types::{DataType, FuncInstrMode, InstrumentationMode, Location};
use crate::iterator::iterator_trait::{IteratingInstrumenter, Iterator};
use crate::module_builder::AddLocal;
use crate::opcode::{Inject, InjectAt, Instrumenter, MacroOpcode, Opcode};
use crate::subiterator::module_subiterator::ModuleSubIterator;
use std::collections::HashMap;
use wasmparser::Operator;

/// Iterator for a Module.
Expand All @@ -25,27 +24,15 @@ pub struct ModuleIterator<'a, 'b> {
impl<'a, 'b> ModuleIterator<'a, 'b> {
/// Creates a new ModuleIterator
pub fn new(module: &'a mut Module<'b>, skip_funcs: &Vec<FunctionID>) -> Self {
// Creates Function -> Number of Instructions
let mut metadata = HashMap::new();
for func in module.functions.iter() {
match &func.kind {
FuncKind::Import(_) => {}
FuncKind::Local(LocalFunction { func_id, body, .. }) => {
metadata.insert(*func_id, body.num_instructions);
}
}
}
let num_funcs = module.num_local_functions;
let metadata = module.get_func_metadata();
ModuleIterator {
module,
mod_iterator: ModuleSubIterator::new(num_funcs, metadata, skip_funcs.to_owned()),
mod_iterator: ModuleSubIterator::new(metadata, skip_funcs.to_owned()),
}
}

pub fn curr_op_owned(&self) -> Option<Operator<'b>> {
if self.mod_iterator.end() {
None
} else if let (
if let (
Location::Module {
func_idx,
instr_idx,
Expand All @@ -54,7 +41,7 @@ impl<'a, 'b> ModuleIterator<'a, 'b> {
..,
) = self.mod_iterator.curr_loc()
{
match &self.module.functions.get(func_idx as FunctionID).kind {
match &self.module.functions.get(func_idx).kind {
FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"),
FuncKind::Local(l) => Some(l.body.instructions[instr_idx].op.clone()),
}
Expand Down Expand Up @@ -247,7 +234,7 @@ impl<'a> Instrumenter<'a> for ModuleIterator<'_, 'a> {
..
} = loc
{
match self.module.functions.get_mut(func_idx as FunctionID).kind {
match self.module.functions.get_mut(func_idx).kind {
FuncKind::Import(_) => panic!("Cannot instrument an imported function"),
FuncKind::Local(ref mut l) => {
l.body.instructions[instr_idx].instr_flag.alternate = Some(vec![])
Expand Down Expand Up @@ -343,9 +330,7 @@ impl<'a> Iterator for ModuleIterator<'_, 'a> {

/// Returns the current instruction
fn curr_op(&self) -> Option<&Operator<'a>> {
if self.mod_iterator.end() {
None
} else if let (
if let (
Location::Module {
func_idx,
instr_idx,
Expand All @@ -354,7 +339,7 @@ impl<'a> Iterator for ModuleIterator<'_, 'a> {
..,
) = self.mod_iterator.curr_loc()
{
match &self.module.functions.get(func_idx as FunctionID).kind {
match &self.module.functions.get(func_idx).kind {
FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"),
FuncKind::Local(l) => Some(&l.body.instructions[instr_idx].op),
}
Expand Down
11 changes: 4 additions & 7 deletions src/subiterator/component_subiterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ pub struct ComponentSubIterator {
num_mods: usize,
/// The module iterator used to keep track of the location in the module.
pub(crate) mod_iterator: ModuleSubIterator,
/// Metadata that maps Module Index -> Function Index -> Instruction Index
metadata: HashMap<ModuleID, HashMap<FunctionID, usize>>,
/// Metadata that maps Module Index -> Vec<(Function Index, Number of Instructions)>
metadata: HashMap<ModuleID, Vec<(FunctionID, usize)>>,
/// Map of Module -> Functions to skip in that module. Provide an empty HashMap if no functions are to be skipped.
skip_funcs: HashMap<ModuleID, Vec<FunctionID>>,
}
Expand All @@ -24,7 +24,7 @@ impl ComponentSubIterator {
pub fn new(
curr_mod: ModuleID,
num_mods: usize,
metadata: HashMap<ModuleID, HashMap<FunctionID, usize>>,
metadata: HashMap<ModuleID, Vec<(FunctionID, usize)>>,
skip_funcs: HashMap<ModuleID, Vec<FunctionID>>,
) -> Self {
// Get current skip func
Expand All @@ -34,7 +34,6 @@ impl ComponentSubIterator {
num_mods,
metadata: metadata.clone(),
mod_iterator: ModuleSubIterator::new(
metadata.get(&ModuleID(0)).unwrap().keys().len() as u32,
(*metadata.get(&curr_mod).unwrap()).clone(),
match skip_funcs.contains_key(&curr_mod) {
true => skip_funcs.get(&curr_mod).unwrap().clone(),
Expand All @@ -56,11 +55,9 @@ impl ComponentSubIterator {
fn next_module(&mut self) -> bool {
*self.curr_mod += 1;
if *self.curr_mod < self.num_mods as u32 {
let num_funcs = self.metadata.get(&self.curr_mod).unwrap().keys().len() as u32;
let met = self.metadata.get(&self.curr_mod).unwrap().clone();
// If we're defining a new module, we have to reset function
self.mod_iterator = ModuleSubIterator::new(
num_funcs,
met,
match self.skip_funcs.contains_key(&self.curr_mod) {
true => self.skip_funcs.get(&self.curr_mod).unwrap().clone(),
Expand All @@ -80,7 +77,7 @@ impl ComponentSubIterator {

/// Gets the index of the current function in the current module
pub fn curr_func_idx(&self) -> FunctionID {
self.mod_iterator.curr_func
self.mod_iterator.get_curr_func().0
}

/// Gets the index of the current instruction in the current function
Expand Down
93 changes: 38 additions & 55 deletions src/subiterator/module_subiterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,13 @@
use crate::ir::id::FunctionID;
use crate::ir::types::Location;
use crate::subiterator::function_subiterator::FuncSubIterator;
use std::collections::HashMap;

/// Sub-iterator for a Module. Keeps track of current location in a Module.
pub struct ModuleSubIterator {
/// The current function the SubIterator is at
pub(crate) curr_func: FunctionID,
/// The number of functions that have been visited thus far
visited_funcs: u32,
/// Number of functions in this module
num_funcs: u32,
/// Metadata that maps Function Index -> Instruction Index
metadata: HashMap<FunctionID, usize>,
pub(crate) curr_idx: usize,
/// Metadata containing each function's index and its number of instructions
metadata: Vec<(FunctionID, usize)>,
/// The function iterator used to keep track of the location in the function.
pub(crate) func_iterator: FuncSubIterator,
/// Functions to skip. Provide an empty vector if no functions are to be skipped.
Expand All @@ -23,34 +18,23 @@ pub struct ModuleSubIterator {

impl ModuleSubIterator {
/// Creates a new ModuleSubIterator
pub fn new(
num_funcs: u32,
metadata: HashMap<FunctionID, usize>,
skip_funcs: Vec<FunctionID>,
) -> Self {
pub fn new(metadata: Vec<(FunctionID, usize)>, skip_funcs: Vec<FunctionID>) -> Self {
let curr_idx = 0;

let (_curr_fid, curr_num_instrs) = metadata[curr_idx];
let mut mod_it = ModuleSubIterator {
curr_func: *metadata.keys().min().unwrap(),
visited_funcs: 0,
num_funcs,
metadata: metadata.clone(),
func_iterator: FuncSubIterator::new(
*metadata.get(metadata.keys().min().unwrap()).unwrap(),
),
curr_idx,
metadata,
func_iterator: FuncSubIterator::new(curr_num_instrs),
skip_funcs,
};
// In case 0 is in skip func
while mod_it
.skip_funcs
.contains(&(mod_it.curr_func as FunctionID))
{
mod_it.next_function();
}
mod_it.handle_skips();

mod_it
}

/// Checks if the SubIterator has finished traversing all the functions
pub fn end(&self) -> bool {
self.visited_funcs == self.num_funcs
/// Returns the metadata entry at `curr_idx` as `(function id, instruction count)`.
///
/// # Panics
/// Panics if `curr_idx` is not a valid index into `metadata`.
pub fn get_curr_func(&self) -> (FunctionID, usize) {
    let (func_id, num_instrs) = self.metadata[self.curr_idx];
    (func_id, num_instrs)
}

/// Returns the current Location in the Module as a Location
Expand All @@ -59,57 +43,56 @@ impl ModuleSubIterator {
let curr_instr = self.func_iterator.curr_instr;
(
Location::Module {
func_idx: self.curr_func,
func_idx: self.get_curr_func().0,
instr_idx: curr_instr,
},
self.func_iterator.is_end(curr_instr),
)
}

/// Resets the ModuleSubIterator when it is a Child SubIterator of a ComponentSubIterator
pub(crate) fn reset_from_comp_iterator(&mut self, metadata: HashMap<FunctionID, usize>) {
*self.curr_func = 0;
pub(crate) fn reset_from_comp_iterator(&mut self, metadata: Vec<(FunctionID, usize)>) {
self.metadata = metadata;
self.func_iterator.reset(
*self
.metadata
.get(self.metadata.keys().min().unwrap())
.unwrap(),
);
self.reset();
}

/// Resets the ModuleSubIterator when it is not a Child SubIterator
pub fn reset(&mut self) {
*self.curr_func = 0;
self.func_iterator
.reset(*self.metadata.get(&FunctionID(0)).unwrap());
self.curr_idx = 0;
self.handle_skips();
self.func_iterator.reset(self.get_curr_func().1);
}

/// Checks if there are functions left to visit
pub fn has_next_function(&self) -> bool {
*self.curr_func + 1 < self.num_funcs
/// Advances `curr_idx` past entries whose function id is listed in `skip_funcs`.
///
/// Stops at the first entry that is not skipped, or at the last entry of
/// `metadata`, whichever comes first.
// NOTE(review): the loop only advances while `has_next_function()` holds, so the
// final entry of `metadata` is never stepped over even when its id is in
// `skip_funcs` -- confirm whether a skipped last function should still be visited.
fn handle_skips(&mut self) {
// Id of the entry currently pointed at by `curr_idx`.
let mut curr_fid = self.get_curr_func().0;
while self.has_next_function() && self.skip_funcs.contains(&curr_fid) {
self.curr_idx += 1;
// Re-read the id at the new position for the next membership check.
curr_fid = self.get_curr_func().0;
}
}

/// Goes to the next function in the module
fn next_function(&mut self) -> bool {
*self.curr_func += 1;
self.visited_funcs += 1;
if !self.has_next_function() {
return false;
}
self.curr_idx += 1;

// skip over configured funcs
while self.visited_funcs < self.num_funcs
&& self.skip_funcs.contains(&(self.curr_func as FunctionID))
{
*self.curr_func += 1;
self.visited_funcs += 1;
}
if self.visited_funcs < self.num_funcs {
self.func_iterator = FuncSubIterator::new(*self.metadata.get(&self.curr_func).unwrap());
self.handle_skips();
if self.curr_idx < self.metadata.len() {
self.func_iterator = FuncSubIterator::new(self.get_curr_func().1);
true
} else {
false
}
}

/// Returns `true` when `metadata` holds at least one entry after `curr_idx`.
///
/// This is a pure index comparison; `skip_funcs` is not consulted here.
pub fn has_next_function(&self) -> bool {
    let next_idx = self.curr_idx + 1;
    next_idx < self.metadata.len()
}

/// Checks if there are instructions left in the current function or functions left to visit
pub(crate) fn has_next(&self) -> bool {
self.func_iterator.has_next() || self.has_next_function()
Expand Down
38 changes: 36 additions & 2 deletions tests/func_builder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use orca_wasm::ir::id::FunctionID;
use log::trace;
use orca_wasm::ir::function::FunctionBuilder;
use orca_wasm::ir::id::{FunctionID, TypeID};
use orca_wasm::iterator::iterator_trait::Iterator;
use orca_wasm::iterator::module_iterator::ModuleIterator;
use orca_wasm::opcode::Instrumenter;
use orca_wasm::Opcode;
use orca_wasm::{DataType, Opcode};
use orca_wasm::{Location, Module};
use std::process::Command;

Expand Down Expand Up @@ -61,3 +65,33 @@ fn run_start_orca_default() {
let out = wasmprinter::print_bytes(result.clone()).expect("couldn't translate Wasm to wat");
println!("{}", out);
}
#[test]
// Regression test: add an imported function and then a new local function to a
// parsed module, step a ModuleIterator until `next()` returns `None`, and make
// sure the module can still be encoded and printed.
fn add_import_and_local_fn_then_iterate() {
    let wat_path = "tests/test_inputs/handwritten/modules/_start.wat";
    let wasm_bytes = wat::parse_file(wat_path).expect("couldn't convert the input wat to Wasm");
    let mut module = Module::parse(&wasm_bytes, false).expect("Unable to parse");

    // The import goes in first; the local function is only added afterwards.
    module.add_import_func("new".to_string(), "import".to_string(), TypeID(0));
    assert_eq!(module.num_import_func(), 1);

    let params = Vec::new();
    let results = vec![DataType::I32];
    let mut builder = FunctionBuilder::new(&params, &results);
    builder.i32_const(1);
    builder.finish_module(&mut module);

    // Touch the current op at every position the iterator stops at.
    let no_skips = Vec::new();
    let mut mod_it = ModuleIterator::new(&mut module, &no_skips);
    let _ = mod_it.curr_op();
    while mod_it.next().is_some() {
        let _ = mod_it.curr_op();
    }

    let encoded = module.encode();
    let wat_text =
        wasmprinter::print_bytes(encoded.clone()).expect("couldn't translate Wasm to wat");
    println!("{}", wat_text);
}

0 comments on commit ded87ff

Please sign in to comment.