From 63591a46640ec43199c91e6283d87a3c94074ace Mon Sep 17 00:00:00 2001 From: Henry Gressmann Date: Mon, 4 Dec 2023 13:54:29 +0100 Subject: [PATCH] feat: add typed_func_handle, check for _start function Signed-off-by: Henry Gressmann --- crates/cli/bin.rs | 10 +- crates/cli/wat.rs | 2 +- crates/tinywasm/src/func.rs | 133 +++++++++++++++++++++++ crates/tinywasm/src/instance.rs | 101 ++++++----------- crates/tinywasm/src/lib.rs | 3 + crates/tinywasm/src/runtime/mod.rs | 16 +-- crates/tinywasm/src/runtime/stack/mod.rs | 11 +- crates/types/src/lib.rs | 87 ++++++++++++++- 8 files changed, 277 insertions(+), 86 deletions(-) create mode 100644 crates/tinywasm/src/func.rs diff --git a/crates/cli/bin.rs b/crates/cli/bin.rs index f8b0e1e..3615e4a 100644 --- a/crates/cli/bin.rs +++ b/crates/cli/bin.rs @@ -3,7 +3,7 @@ use std::str::FromStr; use argh::FromArgs; use color_eyre::eyre::Result; use log::info; -use tinywasm::{self, Module, WasmValue}; +use tinywasm::{self, Module}; mod util; #[cfg(feature = "wat")] @@ -86,7 +86,7 @@ fn main() -> Result<()> { #[cfg(not(feature = "wat"))] true => { return Err(color_eyre::eyre::eyre!( - "wat support is not enabled, please enable it with --features wat" + "wat support is not enabled in this build" )) } false => tinywasm::Module::parse_file(path)?, @@ -103,9 +103,9 @@ fn run(module: Module) -> Result<()> { let mut store = tinywasm::Store::default(); let instance = module.instantiate(&mut store)?; - let func = instance.get_func(&store, "add")?; - let params = vec![WasmValue::I32(2), WasmValue::I32(2)]; - let res = func.call(&mut store, params)?; + + let func = instance.get_typed_func::<(i32, i32), (i32,)>(&store, "add")?; + let (res,) = func.call(&mut store, (2, 2))?; info!("{res:?}"); Ok(()) diff --git a/crates/cli/wat.rs b/crates/cli/wat.rs index 22d931d..fd00a8e 100644 --- a/crates/cli/wat.rs +++ b/crates/cli/wat.rs @@ -3,7 +3,7 @@ use wast::{ Wat, }; -pub fn wat2wasm<'a>(wat: &str) -> Vec { +pub fn wat2wasm(wat: &str) -> Vec { let buf = ParseBuffer::new(wat).expect("failed to create parse buffer"); let mut module = parser::parse::(&buf).expect("failed to parse wat"); module.encode().expect("failed to encode wat") diff --git a/crates/tinywasm/src/func.rs b/crates/tinywasm/src/func.rs new file mode 100644 index 0000000..f8f8804 --- /dev/null +++ b/crates/tinywasm/src/func.rs @@ -0,0 +1,133 @@ +use alloc::{format, string::String, string::ToString, vec, vec::Vec}; +use tinywasm_types::{FuncAddr, FuncType, ValType, WasmValue}; + +use crate::{runtime::Stack, Error, ModuleInstance, Result, Store}; + +#[derive(Debug)] +pub struct FuncHandle { + pub(crate) _module: ModuleInstance, + pub(crate) addr: FuncAddr, + pub(crate) ty: FuncType, + pub name: Option, +} +impl FuncHandle { + /// Call a function + pub fn call(&self, store: &mut Store, params: &[WasmValue]) -> Result> { + let func = store + .data + .funcs + .get(self.addr as usize) + .ok_or(Error::Other(format!("function {} not found", self.addr)))?; + + let func_ty = &self.ty; + + // check that params match func_ty params + for (ty, param) in func_ty.params.iter().zip(params) { + if ty != ¶m.val_type() { + return Err(Error::Other(format!( + "param type mismatch: expected {:?}, got {:?}", + ty, param + ))); + } + } + + let mut local_types: Vec = Vec::new(); + local_types.extend(func_ty.params.iter()); + local_types.extend(func.locals().iter()); + + // let runtime = &mut store.runtime; + + let mut stack = Stack::default(); + stack.locals.extend(params.iter().cloned()); + + let instrs = func.instructions().iter(); + store.runtime.exec(&mut stack, instrs)?; + + let res = func_ty + .results + .iter() + .map(|_| stack.value_stack.pop()) + .collect::>>() + .ok_or(Error::Other( + "function did not return the correct number of values".into(), + ))?; + + Ok(res) + } +} + +pub struct TypedFuncHandle { + pub func: FuncHandle, + pub(crate) marker: core::marker::PhantomData<(P, R)>, +} + +pub trait IntoWasmValueTuple { + fn into_wasm_value_tuple(self) -> Vec; +} + +pub trait FromWasmValueTuple { + fn from_wasm_value_tuple(values: Vec) -> Result + where + Self: Sized; +} + +impl TypedFuncHandle { + pub fn call(&self, store: &mut Store, params: P) -> Result { + // Convert params into Vec + let wasm_values = params.into_wasm_value_tuple(); + + // Call the underlying WASM function + let result = self.func.call(store, &wasm_values)?; + + // Convert the Vec back to R + R::from_wasm_value_tuple(result) + } +} +macro_rules! impl_into_wasm_value_tuple { + ($($T:ident),*) => { + impl<$($T),*> IntoWasmValueTuple for ($($T,)*) + where + $($T: Into),* + { + #[allow(non_snake_case)] + fn into_wasm_value_tuple(self) -> Vec { + let ($($T,)*) = self; + vec![$($T.into(),)*] + } + } + } +} + +impl_into_wasm_value_tuple!(T1); +impl_into_wasm_value_tuple!(T1, T2); +impl_into_wasm_value_tuple!(T1, T2, T3); +impl_into_wasm_value_tuple!(T1, T2, T3, T4); +impl_into_wasm_value_tuple!(T1, T2, T3, T4, T5); + +macro_rules! impl_from_wasm_value_tuple { + ($($T:ident),*) => { + impl<$($T),*> FromWasmValueTuple for ($($T,)*) + where + $($T: TryFrom),* + { + fn from_wasm_value_tuple(values: Vec) -> Result { + let mut iter = values.into_iter(); + Ok(( + $( + $T::try_from( + iter.next() + .ok_or(Error::Other("Not enough values in WasmValue vector".to_string()))? + ) + .map_err(|_| Error::Other("Could not convert WasmValue to expected type".to_string()))?, + )* + )) + } + } + } +} + +impl_from_wasm_value_tuple!(T1); +impl_from_wasm_value_tuple!(T1, T2); +impl_from_wasm_value_tuple!(T1, T2, T3); +impl_from_wasm_value_tuple!(T1, T2, T3, T4); +impl_from_wasm_value_tuple!(T1, T2, T3, T4, T5); diff --git a/crates/tinywasm/src/instance.rs b/crates/tinywasm/src/instance.rs index 8b445f1..c90cec6 100644 --- a/crates/tinywasm/src/instance.rs +++ b/crates/tinywasm/src/instance.rs @@ -1,14 +1,10 @@ -use alloc::{ - boxed::Box, - format, - string::{String, ToString}, - sync::Arc, - vec, - vec::Vec, -}; -use tinywasm_types::{Export, FuncAddr, FuncType, ModuleInstanceAddr, ValType, WasmValue}; +use alloc::{boxed::Box, string::ToString, sync::Arc, vec::Vec}; +use tinywasm_types::{Export, FuncAddr, FuncType, ModuleInstanceAddr}; -use crate::{runtime::Stack, Error, ExportInstance, Result, Store}; +use crate::{ + func::{FromWasmValueTuple, IntoWasmValueTuple}, + ExportInstance, FuncHandle, Result, Store, TypedFuncHandle, +}; /// A WebAssembly Module Instance. /// Addrs are indices into the store's data structures. @@ -63,10 +59,35 @@ impl ModuleInstance { }) } + /// Get a typed exported function by name + pub fn get_typed_func(&self, store: &Store, name: &str) -> Result> + where + P: IntoWasmValueTuple, + R: FromWasmValueTuple, + { + let func = self.get_func(store, name)?; + Ok(TypedFuncHandle { + func, + marker: core::marker::PhantomData, + }) + } + /// Get the start function of the module + /// Returns None if the module has no start function + /// If no start function is specified, also checks for a _start function in the exports + /// (which is not part of the spec, but used by llvm) + /// https://webassembly.github.io/spec/core/syntax/modules.html#start-function pub fn get_start_func(&mut self, store: &Store) -> Result> { - let Some(func_index) = self.0.func_start else { - return Ok(None); + let func_index = match self.0.func_start { + Some(func_index) => func_index, + None => { + // alternatively, check for a _start function in the exports + let Ok(start) = self.0.exports.func("_start") else { + return Ok(None); + }; + + start.index + } }; let func_addr = self.0.func_addrs[func_index as usize]; @@ -89,61 +110,7 @@ impl ModuleInstance { return Ok(None); }; - let _ = func.call(store, vec![]); + let _ = func.call(store, &[])?; Ok(Some(())) } } - -#[derive(Debug)] -pub struct FuncHandle { - _module: ModuleInstance, - addr: FuncAddr, - ty: FuncType, - pub name: Option, -} - -impl FuncHandle { - /// Call a function - pub fn call(&self, store: &mut Store, params: Vec) -> Result> { - let func = store - .data - .funcs - .get(self.addr as usize) - .ok_or(Error::Other(format!("function {} not found", self.addr)))?; - - let func_ty = &self.ty; - - // check that params match func_ty params - for (ty, param) in func_ty.params.iter().zip(params.clone()) { - if ty != ¶m.val_type() { - return Err(Error::Other(format!( - "param type mismatch: expected {:?}, got {:?}", - ty, param - ))); - } - } - - let mut local_types: Vec = Vec::new(); - local_types.extend(func_ty.params.iter()); - local_types.extend(func.locals().iter()); - - // let runtime = &mut store.runtime; - - let mut stack = Stack::default(); - stack.locals.extend(params); - - let instrs = func.instructions().iter(); - store.runtime.exec(&mut stack, instrs)?; - - let res = func_ty - .results - .iter() - .map(|_| stack.value_stack.pop()) - .collect::>>() - .ok_or(Error::Other( - "function did not return the correct number of values".into(), - ))?; - - Ok(res) - } -} diff --git a/crates/tinywasm/src/lib.rs b/crates/tinywasm/src/lib.rs index 9e976e4..03e5cac 100644 --- a/crates/tinywasm/src/lib.rs +++ b/crates/tinywasm/src/lib.rs @@ -20,6 +20,9 @@ pub use instance::ModuleInstance; pub mod export; pub use export::ExportInstance; +pub mod func; +pub use func::{FuncHandle, TypedFuncHandle}; + pub use tinywasm_parser as parser; pub use tinywasm_types::*; pub mod runtime; diff --git a/crates/tinywasm/src/runtime/mod.rs b/crates/tinywasm/src/runtime/mod.rs index 78a98bd..b8aa5aa 100644 --- a/crates/tinywasm/src/runtime/mod.rs +++ b/crates/tinywasm/src/runtime/mod.rs @@ -19,31 +19,33 @@ impl Runtime { instrs: core::slice::Iter, ) -> Result<()> { let locals = &mut stack.locals; + let value_stack = &mut stack.value_stack; + for instr in instrs { use tinywasm_types::Instruction::*; match instr { LocalGet(local_index) => { let val = &locals[*local_index as usize]; debug!("local: {:#?}", val); - stack.value_stack.push(val.clone()); + value_stack.push(val.clone()); } I64Add => { - let a = stack.value_stack.pop().unwrap(); - let b = stack.value_stack.pop().unwrap(); + let a = value_stack.pop().unwrap(); + let b = value_stack.pop().unwrap(); let (WasmValue::I64(a), WasmValue::I64(b)) = (a, b) else { panic!("Invalid type"); }; let c = WasmValue::I64(a + b); - stack.value_stack.push(c); + value_stack.push(c); } I32Add => { - let a = stack.value_stack.pop().unwrap(); - let b = stack.value_stack.pop().unwrap(); + let a = value_stack.pop().unwrap(); + let b = value_stack.pop().unwrap(); let (WasmValue::I32(a), WasmValue::I32(b)) = (a, b) else { panic!("Invalid type"); }; let c = WasmValue::I32(a + b); - stack.value_stack.push(c); + value_stack.push(c); } End => { return Ok(()); diff --git a/crates/tinywasm/src/runtime/stack/mod.rs b/crates/tinywasm/src/runtime/stack/mod.rs index 1149429..18f68b7 100644 --- a/crates/tinywasm/src/runtime/stack/mod.rs +++ b/crates/tinywasm/src/runtime/stack/mod.rs @@ -20,9 +20,10 @@ pub struct Stack { // TODO: Split into Vec and Vec for better memory usage? pub value_stack: Vec, // keeping this typed for now to make it easier to debug pub value_stack_top: usize, - // /// The call stack - // pub call_stack: Vec, - // pub call_stack_top: usize, + + /// The call stack + pub call_stack: Vec, + pub call_stack_top: usize, } impl Default for Stack { @@ -31,8 +32,8 @@ impl Default for Stack { locals: Vec::new(), value_stack: Vec::with_capacity(STACK_SIZE), value_stack_top: 0, - // call_stack: Vec::with_capacity(CALL_STACK_SIZE), - // call_stack_top: 0, + call_stack: Vec::with_capacity(CALL_STACK_SIZE), + call_stack_top: 0, } } } diff --git a/crates/types/src/lib.rs b/crates/types/src/lib.rs index a354203..9f37759 100644 --- a/crates/types/src/lib.rs +++ b/crates/types/src/lib.rs @@ -35,6 +35,91 @@ pub enum WasmValue { V128(i128), } +impl From for WasmValue { + fn from(i: i32) -> Self { + Self::I32(i) + } +} + +impl From for WasmValue { + fn from(i: i64) -> Self { + Self::I64(i) + } +} + +impl From for WasmValue { + fn from(i: f32) -> Self { + Self::F32(i) + } +} + +impl From for WasmValue { + fn from(i: f64) -> Self { + Self::F64(i) + } +} + +impl From for WasmValue { + fn from(i: i128) -> Self { + Self::V128(i) + } +} + +impl TryFrom for i32 { + type Error = (); + + fn try_from(value: WasmValue) -> Result { + match value { + WasmValue::I32(i) => Ok(i), + _ => Err(()), + } + } +} + +impl TryFrom for i64 { + type Error = (); + + fn try_from(value: WasmValue) -> Result { + match value { + WasmValue::I64(i) => Ok(i), + _ => Err(()), + } + } +} + +impl TryFrom for f32 { + type Error = (); + + fn try_from(value: WasmValue) -> Result { + match value { + WasmValue::F32(i) => Ok(i), + _ => Err(()), + } + } +} + +impl TryFrom for f64 { + type Error = (); + + fn try_from(value: WasmValue) -> Result { + match value { + WasmValue::F64(i) => Ok(i), + _ => Err(()), + } + } +} + +impl TryFrom for i128 { + type Error = (); + + fn try_from(value: WasmValue) -> Result { + match value { + WasmValue::V128(i) => Ok(i), + _ => Err(()), + } + } +} + impl Debug for WasmValue { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -85,7 +170,6 @@ pub enum ExternalKind { /// These are indexes into the respective stores. /// See https://webassembly.github.io/spec/core/exec/runtime.html#addresses pub type Addr = u32; -pub type ModuleInstanceAddr = Addr; pub type FuncAddr = Addr; pub type TableAddr = Addr; pub type MemAddr = Addr; @@ -97,6 +181,7 @@ pub type ExternAddr = Addr; pub type TypeAddr = Addr; pub type LocalAddr = Addr; pub type LabelAddr = Addr; +pub type ModuleInstanceAddr = Addr; /// A WebAssembly Export Instance. /// https://webassembly.github.io/spec/core/exec/runtime.html#export-instances