Skip to content

Commit

Permalink
feat: add typed_func_handle, check for _start function
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Gressmann <mail@henrygressmann.de>
  • Loading branch information
explodingcamera committed Dec 4, 2023
1 parent 8e065ba commit 63591a4
Show file tree
Hide file tree
Showing 8 changed files with 277 additions and 86 deletions.
10 changes: 5 additions & 5 deletions crates/cli/bin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::str::FromStr;
use argh::FromArgs;
use color_eyre::eyre::Result;
use log::info;
use tinywasm::{self, Module, WasmValue};
use tinywasm::{self, Module};
mod util;

#[cfg(feature = "wat")]
Expand Down Expand Up @@ -86,7 +86,7 @@ fn main() -> Result<()> {
#[cfg(not(feature = "wat"))]
true => {
return Err(color_eyre::eyre::eyre!(
"wat support is not enabled, please enable it with --features wat"
"wat support is not enabled in this build"
))
}
false => tinywasm::Module::parse_file(path)?,
Expand All @@ -103,9 +103,9 @@ fn run(module: Module) -> Result<()> {
let mut store = tinywasm::Store::default();

let instance = module.instantiate(&mut store)?;
let func = instance.get_func(&store, "add")?;
let params = vec![WasmValue::I32(2), WasmValue::I32(2)];
let res = func.call(&mut store, params)?;

let func = instance.get_typed_func::<(i32, i32), (i32,)>(&store, "add")?;
let (res,) = func.call(&mut store, (2, 2))?;
info!("{res:?}");

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion crates/cli/wat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use wast::{
Wat,
};

pub fn wat2wasm<'a>(wat: &str) -> Vec<u8> {
pub fn wat2wasm(wat: &str) -> Vec<u8> {
let buf = ParseBuffer::new(wat).expect("failed to create parse buffer");
let mut module = parser::parse::<Wat>(&buf).expect("failed to parse wat");
module.encode().expect("failed to encode wat")
Expand Down
133 changes: 133 additions & 0 deletions crates/tinywasm/src/func.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
use alloc::{format, string::String, string::ToString, vec, vec::Vec};
use tinywasm_types::{FuncAddr, FuncType, ValType, WasmValue};

use crate::{runtime::Stack, Error, ModuleInstance, Result, Store};

#[derive(Debug)]
pub struct FuncHandle {
pub(crate) _module: ModuleInstance,
pub(crate) addr: FuncAddr,
pub(crate) ty: FuncType,
pub name: Option<String>,
}
impl FuncHandle {
/// Call a function
pub fn call(&self, store: &mut Store, params: &[WasmValue]) -> Result<Vec<WasmValue>> {
let func = store
.data
.funcs
.get(self.addr as usize)
.ok_or(Error::Other(format!("function {} not found", self.addr)))?;

let func_ty = &self.ty;

// check that params match func_ty params
for (ty, param) in func_ty.params.iter().zip(params) {
if ty != &param.val_type() {
return Err(Error::Other(format!(
"param type mismatch: expected {:?}, got {:?}",
ty, param
)));
}
}

let mut local_types: Vec<ValType> = Vec::new();
local_types.extend(func_ty.params.iter());
local_types.extend(func.locals().iter());

// let runtime = &mut store.runtime;

let mut stack = Stack::default();
stack.locals.extend(params.iter().cloned());

let instrs = func.instructions().iter();
store.runtime.exec(&mut stack, instrs)?;

let res = func_ty
.results
.iter()
.map(|_| stack.value_stack.pop())
.collect::<Option<Vec<_>>>()
.ok_or(Error::Other(
"function did not return the correct number of values".into(),
))?;

Ok(res)
}
}

pub struct TypedFuncHandle<P, R> {
pub func: FuncHandle,
pub(crate) marker: core::marker::PhantomData<(P, R)>,
}

pub trait IntoWasmValueTuple {
fn into_wasm_value_tuple(self) -> Vec<WasmValue>;
}

pub trait FromWasmValueTuple {
fn from_wasm_value_tuple(values: Vec<WasmValue>) -> Result<Self>
where
Self: Sized;
}

impl<P: IntoWasmValueTuple, R: FromWasmValueTuple> TypedFuncHandle<P, R> {
pub fn call(&self, store: &mut Store, params: P) -> Result<R> {
// Convert params into Vec<WasmValue>
let wasm_values = params.into_wasm_value_tuple();

// Call the underlying WASM function
let result = self.func.call(store, &wasm_values)?;

// Convert the Vec<WasmValue> back to R
R::from_wasm_value_tuple(result)
}
}
macro_rules! impl_into_wasm_value_tuple {
($($T:ident),*) => {
impl<$($T),*> IntoWasmValueTuple for ($($T,)*)
where
$($T: Into<WasmValue>),*
{
#[allow(non_snake_case)]
fn into_wasm_value_tuple(self) -> Vec<WasmValue> {
let ($($T,)*) = self;
vec![$($T.into(),)*]
}
}
}
}

impl_into_wasm_value_tuple!(T1);
impl_into_wasm_value_tuple!(T1, T2);
impl_into_wasm_value_tuple!(T1, T2, T3);
impl_into_wasm_value_tuple!(T1, T2, T3, T4);
impl_into_wasm_value_tuple!(T1, T2, T3, T4, T5);

macro_rules! impl_from_wasm_value_tuple {
($($T:ident),*) => {
impl<$($T),*> FromWasmValueTuple for ($($T,)*)
where
$($T: TryFrom<WasmValue, Error = ()>),*
{
fn from_wasm_value_tuple(values: Vec<WasmValue>) -> Result<Self> {
let mut iter = values.into_iter();
Ok((
$(
$T::try_from(
iter.next()
.ok_or(Error::Other("Not enough values in WasmValue vector".to_string()))?
)
.map_err(|_| Error::Other("Could not convert WasmValue to expected type".to_string()))?,
)*
))
}
}
}
}

impl_from_wasm_value_tuple!(T1);
impl_from_wasm_value_tuple!(T1, T2);
impl_from_wasm_value_tuple!(T1, T2, T3);
impl_from_wasm_value_tuple!(T1, T2, T3, T4);
impl_from_wasm_value_tuple!(T1, T2, T3, T4, T5);
101 changes: 34 additions & 67 deletions crates/tinywasm/src/instance.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use alloc::{
boxed::Box,
format,
string::{String, ToString},
sync::Arc,
vec,
vec::Vec,
};
use tinywasm_types::{Export, FuncAddr, FuncType, ModuleInstanceAddr, ValType, WasmValue};
use alloc::{boxed::Box, string::ToString, sync::Arc, vec::Vec};
use tinywasm_types::{Export, FuncAddr, FuncType, ModuleInstanceAddr};

use crate::{runtime::Stack, Error, ExportInstance, Result, Store};
use crate::{
func::{FromWasmValueTuple, IntoWasmValueTuple},
ExportInstance, FuncHandle, Result, Store, TypedFuncHandle,
};

/// A WebAssembly Module Instance.
/// Addrs are indices into the store's data structures.
Expand Down Expand Up @@ -63,10 +59,35 @@ impl ModuleInstance {
})
}

/// Get a typed exported function by name
pub fn get_typed_func<P, R>(&self, store: &Store, name: &str) -> Result<TypedFuncHandle<P, R>>
where
P: IntoWasmValueTuple,
R: FromWasmValueTuple,
{
let func = self.get_func(store, name)?;
Ok(TypedFuncHandle {
func,
marker: core::marker::PhantomData,
})
}

/// Get the start function of the module
/// Returns None if the module has no start function
/// If no start function is specified, also checks for a _start function in the exports
/// (which is not part of the spec, but used by llvm)
/// https://webassembly.github.io/spec/core/syntax/modules.html#start-function
pub fn get_start_func(&mut self, store: &Store) -> Result<Option<FuncHandle>> {
let Some(func_index) = self.0.func_start else {
return Ok(None);
let func_index = match self.0.func_start {
Some(func_index) => func_index,
None => {
// alternatively, check for a _start function in the exports
let Ok(start) = self.0.exports.func("_start") else {
return Ok(None);
};

start.index
}
};

let func_addr = self.0.func_addrs[func_index as usize];
Expand All @@ -89,61 +110,7 @@ impl ModuleInstance {
return Ok(None);
};

let _ = func.call(store, vec![]);
let _ = func.call(store, &[])?;
Ok(Some(()))
}
}

#[derive(Debug)]
pub struct FuncHandle {
_module: ModuleInstance,
addr: FuncAddr,
ty: FuncType,
pub name: Option<String>,
}

impl FuncHandle {
/// Call a function
pub fn call(&self, store: &mut Store, params: Vec<WasmValue>) -> Result<Vec<WasmValue>> {
let func = store
.data
.funcs
.get(self.addr as usize)
.ok_or(Error::Other(format!("function {} not found", self.addr)))?;

let func_ty = &self.ty;

// check that params match func_ty params
for (ty, param) in func_ty.params.iter().zip(params.clone()) {
if ty != &param.val_type() {
return Err(Error::Other(format!(
"param type mismatch: expected {:?}, got {:?}",
ty, param
)));
}
}

let mut local_types: Vec<ValType> = Vec::new();
local_types.extend(func_ty.params.iter());
local_types.extend(func.locals().iter());

// let runtime = &mut store.runtime;

let mut stack = Stack::default();
stack.locals.extend(params);

let instrs = func.instructions().iter();
store.runtime.exec(&mut stack, instrs)?;

let res = func_ty
.results
.iter()
.map(|_| stack.value_stack.pop())
.collect::<Option<Vec<_>>>()
.ok_or(Error::Other(
"function did not return the correct number of values".into(),
))?;

Ok(res)
}
}
3 changes: 3 additions & 0 deletions crates/tinywasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ pub use instance::ModuleInstance;
pub mod export;
pub use export::ExportInstance;

pub mod func;
pub use func::{FuncHandle, TypedFuncHandle};

pub use tinywasm_parser as parser;
pub use tinywasm_types::*;
pub mod runtime;
Expand Down
16 changes: 9 additions & 7 deletions crates/tinywasm/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,33 @@ impl Runtime {
instrs: core::slice::Iter<Instruction>,
) -> Result<()> {
let locals = &mut stack.locals;
let value_stack = &mut stack.value_stack;

for instr in instrs {
use tinywasm_types::Instruction::*;
match instr {
LocalGet(local_index) => {
let val = &locals[*local_index as usize];
debug!("local: {:#?}", val);
stack.value_stack.push(val.clone());
value_stack.push(val.clone());
}
I64Add => {
let a = stack.value_stack.pop().unwrap();
let b = stack.value_stack.pop().unwrap();
let a = value_stack.pop().unwrap();
let b = value_stack.pop().unwrap();
let (WasmValue::I64(a), WasmValue::I64(b)) = (a, b) else {
panic!("Invalid type");
};
let c = WasmValue::I64(a + b);
stack.value_stack.push(c);
value_stack.push(c);
}
I32Add => {
let a = stack.value_stack.pop().unwrap();
let b = stack.value_stack.pop().unwrap();
let a = value_stack.pop().unwrap();
let b = value_stack.pop().unwrap();
let (WasmValue::I32(a), WasmValue::I32(b)) = (a, b) else {
panic!("Invalid type");
};
let c = WasmValue::I32(a + b);
stack.value_stack.push(c);
value_stack.push(c);
}
End => {
return Ok(());
Expand Down
11 changes: 6 additions & 5 deletions crates/tinywasm/src/runtime/stack/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ pub struct Stack {
// TODO: Split into Vec<u8> and Vec<ValType> for better memory usage?
pub value_stack: Vec<WasmValue>, // keeping this typed for now to make it easier to debug
pub value_stack_top: usize,
// /// The call stack
// pub call_stack: Vec<CallFrame>,
// pub call_stack_top: usize,

/// The call stack
pub call_stack: Vec<CallFrame>,
pub call_stack_top: usize,
}

impl Default for Stack {
Expand All @@ -31,8 +32,8 @@ impl Default for Stack {
locals: Vec::new(),
value_stack: Vec::with_capacity(STACK_SIZE),
value_stack_top: 0,
// call_stack: Vec::with_capacity(CALL_STACK_SIZE),
// call_stack_top: 0,
call_stack: Vec::with_capacity(CALL_STACK_SIZE),
call_stack_top: 0,
}
}
}
Loading

0 comments on commit 63591a4

Please sign in to comment.