Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't visit added imported functions during application iteration. #194

Merged
merged 4 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/ir/module/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,20 @@ impl<'a> Module<'a> {
})
}

/// Collects a `(function ID, number of instructions)` pair for every local function.
///
/// Imported functions are skipped, so the returned vector only describes functions
/// of kind `FuncKind::Local`. Entries follow the iteration order of `self.functions`.
pub fn get_func_metadata(&self) -> Vec<(FunctionID, usize)> {
    self.functions
        .iter()
        .filter_map(|func| match &func.kind {
            // Imports contribute no entry.
            FuncKind::Import(_) => None,
            FuncKind::Local(LocalFunction { func_id, body, .. }) => {
                Some((*func_id, body.num_instructions))
            }
        })
        .collect()
}

/// Emit the module into a wasm binary file.
pub fn emit_wasm(&mut self, file_name: &str) -> Result<(), std::io::Error> {
let module = self.encode_internal();
Expand Down
16 changes: 3 additions & 13 deletions src/iterator/component_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

use crate::ir::component::Component;
use crate::ir::id::{FunctionID, GlobalID, LocalID, ModuleID};
use crate::ir::module::module_functions::{FuncKind, LocalFunction};
use crate::ir::module::module_functions::FuncKind;
use crate::ir::module::module_globals::Global;
use crate::ir::module::Iter;
use crate::ir::types::{DataType, FuncInstrMode, InstrumentationMode, Location};
use crate::iterator::iterator_trait::{IteratingInstrumenter, Iterator};
use crate::module_builder::AddLocal;
Expand All @@ -22,7 +21,7 @@ pub struct ComponentIterator<'a, 'b> {
comp_iterator: ComponentSubIterator,
}

fn print_metadata(metadata: &HashMap<ModuleID, HashMap<FunctionID, usize>>) {
fn print_metadata(metadata: &HashMap<ModuleID, Vec<(FunctionID, usize)>>) {
for c in metadata.keys() {
println!("Module: {:?}", c);
for (m, i) in metadata.get(c).unwrap().iter() {
Expand All @@ -41,16 +40,7 @@ impl<'a, 'b> ComponentIterator<'a, 'b> {
// Creates Module -> Function -> Number of Instructions
let mut metadata = HashMap::new();
for (mod_idx, m) in comp.modules.iter().enumerate() {
let mut mod_metadata = HashMap::new();
for func in m.functions.iter() {
match &func.kind {
FuncKind::Import(_) => {}
FuncKind::Local(LocalFunction { func_id, body, .. }) => {
mod_metadata.insert(*func_id, body.num_instructions);
}
}
}
metadata.insert(ModuleID(mod_idx as u32), mod_metadata);
metadata.insert(ModuleID(mod_idx as u32), m.get_func_metadata());
}
print_metadata(&metadata);
let num_modules = comp.num_modules;
Expand Down
33 changes: 9 additions & 24 deletions src/iterator/module_iterator.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
//! Iterator to traverse a Module

use crate::ir::id::{FunctionID, GlobalID, LocalID};
use crate::ir::module::module_functions::{FuncKind, LocalFunction};
use crate::ir::module::module_functions::FuncKind;
use crate::ir::module::module_globals::Global;
use crate::ir::module::{Iter, Module};
use crate::ir::module::Module;
use crate::ir::types::{DataType, FuncInstrMode, InstrumentationMode, Location};
use crate::iterator::iterator_trait::{IteratingInstrumenter, Iterator};
use crate::module_builder::AddLocal;
use crate::opcode::{Inject, InjectAt, Instrumenter, MacroOpcode, Opcode};
use crate::subiterator::module_subiterator::ModuleSubIterator;
use std::collections::HashMap;
use wasmparser::Operator;

/// Iterator for a Module.
Expand All @@ -25,27 +24,15 @@ pub struct ModuleIterator<'a, 'b> {
impl<'a, 'b> ModuleIterator<'a, 'b> {
/// Creates a new ModuleIterator
pub fn new(module: &'a mut Module<'b>, skip_funcs: &Vec<FunctionID>) -> Self {
// Creates Function -> Number of Instructions
let mut metadata = HashMap::new();
for func in module.functions.iter() {
match &func.kind {
FuncKind::Import(_) => {}
FuncKind::Local(LocalFunction { func_id, body, .. }) => {
metadata.insert(*func_id, body.num_instructions);
}
}
}
let num_funcs = module.num_local_functions;
let metadata = module.get_func_metadata();
ModuleIterator {
module,
mod_iterator: ModuleSubIterator::new(num_funcs, metadata, skip_funcs.to_owned()),
mod_iterator: ModuleSubIterator::new(metadata, skip_funcs.to_owned()),
}
}

pub fn curr_op_owned(&self) -> Option<Operator<'b>> {
if self.mod_iterator.end() {
None
} else if let (
if let (
Location::Module {
func_idx,
instr_idx,
Expand All @@ -54,7 +41,7 @@ impl<'a, 'b> ModuleIterator<'a, 'b> {
..,
) = self.mod_iterator.curr_loc()
{
match &self.module.functions.get(func_idx as FunctionID).kind {
match &self.module.functions.get(func_idx).kind {
FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"),
FuncKind::Local(l) => Some(l.body.instructions[instr_idx].op.clone()),
}
Expand Down Expand Up @@ -247,7 +234,7 @@ impl<'a> Instrumenter<'a> for ModuleIterator<'_, 'a> {
..
} = loc
{
match self.module.functions.get_mut(func_idx as FunctionID).kind {
match self.module.functions.get_mut(func_idx).kind {
FuncKind::Import(_) => panic!("Cannot instrument an imported function"),
FuncKind::Local(ref mut l) => {
l.body.instructions[instr_idx].instr_flag.alternate = Some(vec![])
Expand Down Expand Up @@ -343,9 +330,7 @@ impl<'a> Iterator for ModuleIterator<'_, 'a> {

/// Returns the current instruction
fn curr_op(&self) -> Option<&Operator<'a>> {
if self.mod_iterator.end() {
None
} else if let (
if let (
Location::Module {
func_idx,
instr_idx,
Expand All @@ -354,7 +339,7 @@ impl<'a> Iterator for ModuleIterator<'_, 'a> {
..,
) = self.mod_iterator.curr_loc()
{
match &self.module.functions.get(func_idx as FunctionID).kind {
match &self.module.functions.get(func_idx).kind {
FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"),
FuncKind::Local(l) => Some(&l.body.instructions[instr_idx].op),
}
Expand Down
11 changes: 4 additions & 7 deletions src/subiterator/component_subiterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ pub struct ComponentSubIterator {
num_mods: usize,
/// The module iterator used to keep track of the location in the module.
pub(crate) mod_iterator: ModuleSubIterator,
/// Metadata that maps Module Index -> Function Index -> Instruction Index
metadata: HashMap<ModuleID, HashMap<FunctionID, usize>>,
/// Metadata that maps Module Index -> Vec<(Function Index, Number of Instructions)>
metadata: HashMap<ModuleID, Vec<(FunctionID, usize)>>,
/// Map of Module -> Functions to skip in that module. Provide an empty HashMap if no functions are to be skipped.
skip_funcs: HashMap<ModuleID, Vec<FunctionID>>,
}
Expand All @@ -24,7 +24,7 @@ impl ComponentSubIterator {
pub fn new(
curr_mod: ModuleID,
num_mods: usize,
metadata: HashMap<ModuleID, HashMap<FunctionID, usize>>,
metadata: HashMap<ModuleID, Vec<(FunctionID, usize)>>,
skip_funcs: HashMap<ModuleID, Vec<FunctionID>>,
) -> Self {
// Get current skip func
Expand All @@ -34,7 +34,6 @@ impl ComponentSubIterator {
num_mods,
metadata: metadata.clone(),
mod_iterator: ModuleSubIterator::new(
metadata.get(&ModuleID(0)).unwrap().keys().len() as u32,
(*metadata.get(&curr_mod).unwrap()).clone(),
match skip_funcs.contains_key(&curr_mod) {
true => skip_funcs.get(&curr_mod).unwrap().clone(),
Expand All @@ -56,11 +55,9 @@ impl ComponentSubIterator {
fn next_module(&mut self) -> bool {
*self.curr_mod += 1;
if *self.curr_mod < self.num_mods as u32 {
let num_funcs = self.metadata.get(&self.curr_mod).unwrap().keys().len() as u32;
let met = self.metadata.get(&self.curr_mod).unwrap().clone();
// If we're defining a new module, we have to reset function
self.mod_iterator = ModuleSubIterator::new(
num_funcs,
met,
match self.skip_funcs.contains_key(&self.curr_mod) {
true => self.skip_funcs.get(&self.curr_mod).unwrap().clone(),
Expand All @@ -80,7 +77,7 @@ impl ComponentSubIterator {

/// Gets the index of the current function in the current module
pub fn curr_func_idx(&self) -> FunctionID {
self.mod_iterator.curr_func
self.mod_iterator.get_curr_func().0
}

/// Gets the index of the current instruction in the current function
Expand Down
93 changes: 38 additions & 55 deletions src/subiterator/module_subiterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,13 @@
use crate::ir::id::FunctionID;
use crate::ir::types::Location;
use crate::subiterator::function_subiterator::FuncSubIterator;
use std::collections::HashMap;

/// Sub-iterator for a Module. Keeps track of current location in a Module.
pub struct ModuleSubIterator {
/// The current function the SubIterator is at
pub(crate) curr_func: FunctionID,
/// The number of functions that have been visited thus far
visited_funcs: u32,
/// Number of functions in this module
num_funcs: u32,
/// Metadata that maps Function Index -> Instruction Index
metadata: HashMap<FunctionID, usize>,
pub(crate) curr_idx: usize,
/// Metadata containing each function's index and its number of instructions
metadata: Vec<(FunctionID, usize)>,
/// The function iterator used to keep track of the location in the function.
pub(crate) func_iterator: FuncSubIterator,
/// Functions to skip. Provide an empty vector if no functions are to be skipped.
Expand All @@ -23,34 +18,23 @@ pub struct ModuleSubIterator {

impl ModuleSubIterator {
/// Creates a new ModuleSubIterator
pub fn new(
num_funcs: u32,
metadata: HashMap<FunctionID, usize>,
skip_funcs: Vec<FunctionID>,
) -> Self {
pub fn new(metadata: Vec<(FunctionID, usize)>, skip_funcs: Vec<FunctionID>) -> Self {
let curr_idx = 0;

let (_curr_fid, curr_num_instrs) = metadata[curr_idx];
let mut mod_it = ModuleSubIterator {
curr_func: *metadata.keys().min().unwrap(),
visited_funcs: 0,
num_funcs,
metadata: metadata.clone(),
func_iterator: FuncSubIterator::new(
*metadata.get(metadata.keys().min().unwrap()).unwrap(),
),
curr_idx,
metadata,
func_iterator: FuncSubIterator::new(curr_num_instrs),
skip_funcs,
};
// In case 0 is in skip func
while mod_it
.skip_funcs
.contains(&(mod_it.curr_func as FunctionID))
{
mod_it.next_function();
}
mod_it.handle_skips();

mod_it
}

/// Checks if the SubIterator has finished traversing all the functions
pub fn end(&self) -> bool {
self.visited_funcs == self.num_funcs
pub fn get_curr_func(&self) -> (FunctionID, usize) {
self.metadata[self.curr_idx]
}

/// Returns the current Location in the Module as a Location
Expand All @@ -59,57 +43,56 @@ impl ModuleSubIterator {
let curr_instr = self.func_iterator.curr_instr;
(
Location::Module {
func_idx: self.curr_func,
func_idx: self.get_curr_func().0,
instr_idx: curr_instr,
},
self.func_iterator.is_end(curr_instr),
)
}

/// Resets the ModuleSubIterator when it is a Child SubIterator of a ComponentSubIterator
pub(crate) fn reset_from_comp_iterator(&mut self, metadata: HashMap<FunctionID, usize>) {
*self.curr_func = 0;
pub(crate) fn reset_from_comp_iterator(&mut self, metadata: Vec<(FunctionID, usize)>) {
self.metadata = metadata;
self.func_iterator.reset(
*self
.metadata
.get(self.metadata.keys().min().unwrap())
.unwrap(),
);
self.reset();
}

/// Resets the ModuleSubIterator when it is not a Child SubIterator
pub fn reset(&mut self) {
*self.curr_func = 0;
self.func_iterator
.reset(*self.metadata.get(&FunctionID(0)).unwrap());
self.curr_idx = 0;
self.handle_skips();
self.func_iterator.reset(self.get_curr_func().1);
}

/// Checks if there are functions left to visit
pub fn has_next_function(&self) -> bool {
*self.curr_func + 1 < self.num_funcs
/// Advances `curr_idx` past entries whose function ID is listed in `skip_funcs`.
///
/// Stops at the first entry that is not skipped, or once no further entry follows
/// the current one.
///
/// NOTE(review): because the loop is guarded by `has_next_function()`, the final
/// metadata entry is never stepped over, even if its ID is in `skip_funcs` —
/// confirm whether that is intended.
fn handle_skips(&mut self) {
    // Read the current ID up front so an out-of-range index surfaces here.
    let mut fid = self.get_curr_func().0;
    loop {
        let must_skip = self.has_next_function() && self.skip_funcs.contains(&fid);
        if !must_skip {
            break;
        }
        self.curr_idx += 1;
        fid = self.get_curr_func().0;
    }
}

/// Goes to the next function in the module
fn next_function(&mut self) -> bool {
*self.curr_func += 1;
self.visited_funcs += 1;
if !self.has_next_function() {
return false;
}
self.curr_idx += 1;

// skip over configured funcs
while self.visited_funcs < self.num_funcs
&& self.skip_funcs.contains(&(self.curr_func as FunctionID))
{
*self.curr_func += 1;
self.visited_funcs += 1;
}
if self.visited_funcs < self.num_funcs {
self.func_iterator = FuncSubIterator::new(*self.metadata.get(&self.curr_func).unwrap());
self.handle_skips();
if self.curr_idx < self.metadata.len() {
self.func_iterator = FuncSubIterator::new(self.get_curr_func().1);
true
} else {
false
}
}

/// Returns `true` when at least one more metadata entry follows the current one.
pub fn has_next_function(&self) -> bool {
    let next_idx = self.curr_idx + 1;
    next_idx < self.metadata.len()
}

/// Checks if there are instructions left in the current function or functions left to visit
pub(crate) fn has_next(&self) -> bool {
self.func_iterator.has_next() || self.has_next_function()
Expand Down
38 changes: 36 additions & 2 deletions tests/func_builder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use orca_wasm::ir::id::FunctionID;
use log::trace;
use orca_wasm::ir::function::FunctionBuilder;
use orca_wasm::ir::id::{FunctionID, TypeID};
use orca_wasm::iterator::iterator_trait::Iterator;
use orca_wasm::iterator::module_iterator::ModuleIterator;
use orca_wasm::opcode::Instrumenter;
use orca_wasm::Opcode;
use orca_wasm::{DataType, Opcode};
use orca_wasm::{Location, Module};
use std::process::Command;

Expand Down Expand Up @@ -61,3 +65,33 @@ fn run_start_orca_default() {
let out = wasmprinter::print_bytes(result.clone()).expect("couldn't translate Wasm to wat");
println!("{}", out);
}
#[test]
// Adds an imported function followed by a new local function, then walks the
// whole module with a `ModuleIterator` and encodes the result.
fn add_import_and_local_fn_then_iterate() {
    let wat_path = "tests/test_inputs/handwritten/modules/_start.wat";
    let bytes = wat::parse_file(wat_path).expect("couldn't convert the input wat to Wasm");
    let mut module = Module::parse(&bytes, false).expect("Unable to parse");

    // Add the imported function first, and only then the new local function.
    module.add_import_func("new".to_string(), "import".to_string(), TypeID(0));
    assert_eq!(module.num_import_func(), 1);

    let no_params = vec![];
    let i32_result = vec![DataType::I32];
    let mut builder = FunctionBuilder::new(&no_params, &i32_result);
    builder.i32_const(1);
    builder.finish_module(&mut module);

    // Step through the module, reading the current op before every advance.
    let mut mod_it = ModuleIterator::new(&mut module, &vec![]);
    let _op = mod_it.curr_op();
    while mod_it.next().is_some() {
        let _op = mod_it.curr_op();
    }

    let encoded = module.encode();
    let wat_text =
        wasmprinter::print_bytes(encoded.clone()).expect("couldn't translate Wasm to wat");
    println!("{}", wat_text);
}