Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1/x] feature(frontend): generate synthetic functions with Miden ABI transformation (stdlib, tx_kernel); #412

Merged
merged 2 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 80 additions & 52 deletions frontend-wasm2/src/miden_abi/transform.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use midenc_dialect_hir::InstBuilder;
use midenc_hir::diagnostics::{DiagnosticsHandler, SourceSpan};
use midenc_hir2::{FunctionIdent, ValueRef};
use midenc_hir2::{dialects::builtin::FunctionRef, FunctionIdent, Immediate, Type, ValueRef};

use super::{stdlib, tx_kernel};
use crate::module::function_builder_ext::FunctionBuilderExt;
Expand Down Expand Up @@ -58,83 +58,111 @@ fn get_transform_strategy(module_id: &str, function_id: &str) -> TransformStrate
panic!("No transform strategy found for function '{function_id}' in module '{module_id}'");
}

/// Transform a function call based on the transformation strategy
/// Transform a Miden ABI function call based on the transformation strategy
///
/// `import_func` - import function that we're transforming a call to (think of a MASM function)
/// `args` - arguments to the generated synthetic function
/// Returns results that will be returned from the synthetic function
pub fn transform_miden_abi_call(
func_id: FunctionIdent,
import_func_ref: FunctionRef,
import_func_id: FunctionIdent,
args: &[ValueRef],
builder: &mut FunctionBuilderExt,
span: SourceSpan,
diagnostics: &DiagnosticsHandler,
) -> Vec<ValueRef> {
use TransformStrategy::*;
match get_transform_strategy(func_id.module.as_str(), func_id.function.as_str()) {
ListReturn => list_return(func_id, args, builder, span, diagnostics),
ReturnViaPointer => return_via_pointer(func_id, args, builder, span, diagnostics),
NoTransform => no_transform(func_id, args, builder, span, diagnostics),
match get_transform_strategy(import_func_id.module.as_str(), import_func_id.function.as_str()) {
ListReturn => list_return(import_func_ref, args, builder),
ReturnViaPointer => return_via_pointer(import_func_ref, args, builder),
NoTransform => no_transform(import_func_ref, args, builder),
}
}

/// No transformation needed
#[inline(always)]
pub fn no_transform(
func_id: FunctionIdent,
import_func_ref: FunctionRef,
args: &[ValueRef],
builder: &mut FunctionBuilderExt,
span: SourceSpan,
_diagnostics: &DiagnosticsHandler,
) -> Vec<ValueRef> {
todo!()
// let call = builder.ins().exec(func_id, args, span);
// let results = builder.inst_results(call);
// results.to_vec()
let span = import_func_ref.borrow().name().span;
let signature = import_func_ref.borrow().signature().clone();
let exec = builder
.ins()
.exec(import_func_ref, signature, args.to_vec(), span)
.expect("failed to build an exec op in no_transform strategy");

let borrow = exec.borrow();
let results_storage = borrow.as_ref().results();
let results: Vec<ValueRef> =
results_storage.iter().map(|op_res| op_res.borrow().as_value_ref()).collect();
results
}

/// The Miden ABI function returns a length and a pointer and we only want the length
pub fn list_return(
func_id: FunctionIdent,
import_func_ref: FunctionRef,
args: &[ValueRef],
builder: &mut FunctionBuilderExt,
span: SourceSpan,
_diagnostics: &DiagnosticsHandler,
) -> Vec<ValueRef> {
todo!()
// let call = builder.ins().exec(func_id, args, span);
// let results = builder.inst_results(call);
// assert_eq!(results.len(), 2, "List return strategy expects 2 results: length and pointer");
// // Return the first result (length) only
// results[0..1].to_vec()
let span = import_func_ref.borrow().name().span;
let signature = import_func_ref.borrow().signature().clone();
let exec = builder
.ins()
.exec(import_func_ref, signature, args.to_vec(), span)
.expect("failed to build an exec op in list_return strategy");

let borrow = exec.borrow();
let results_storage = borrow.as_ref().results();
let results: Vec<ValueRef> =
results_storage.iter().map(|op_res| op_res.borrow().as_value_ref()).collect();

assert_eq!(results.len(), 2, "List return strategy expects 2 results: length and pointer");
// Return the first result (length) only
results[0..1].to_vec()
}

/// The Miden ABI function returns felts on the stack and we want to return via a pointer argument
pub fn return_via_pointer(
func_id: FunctionIdent,
import_func_ref: FunctionRef,
args: &[ValueRef],
builder: &mut FunctionBuilderExt,
span: SourceSpan,
_diagnostics: &DiagnosticsHandler,
) -> Vec<ValueRef> {
todo!()
// // Omit the last argument (pointer)
// let args_wo_pointer = &args[0..args.len() - 1];
// let call = builder.ins().exec(func_id, args_wo_pointer, span);
// let results = builder.inst_results(call).to_vec();
// let ptr_arg = *args.last().unwrap();
// let ptr_arg_ty = ptr_arg.borrow().ty().clone().clone();
// assert_eq!(ptr_arg_ty, I32);
// let ptr_u32 = builder.ins().bitcast(ptr_arg, U32, span);
// let result_ty =
// midenc_hir::StructType::new(results.iter().map(|v| (*v).borrow().ty().clone().clone()));
// for (idx, value) in results.iter().enumerate() {
// let value_ty = (*value).borrow().ty().clone().clone();
// let eff_ptr = if idx == 0 {
// // We're assuming here that the base pointer is of the correct alignment
// ptr_u32
// } else {
// let imm = Immediate::U32(result_ty.get(idx).offset);
// builder.ins().add_imm_checked(ptr_u32, imm, span)
// };
// let addr = builder.ins().inttoptr(eff_ptr, Ptr(value_ty.into()), span);
// builder.ins().store(addr, *value, span);
// }
// Vec::new()
let span = import_func_ref.borrow().name().span;
// Omit the last argument (pointer)
let args_wo_pointer = &args[0..args.len() - 1];
let signature = import_func_ref.borrow().signature().clone();
let exec = builder
.ins()
.exec(import_func_ref, signature, args_wo_pointer.to_vec(), span)
.expect("failed to build an exec op in return_via_pointer strategy");

let borrow = exec.borrow();
let results_storage = borrow.as_ref().results();
let results: Vec<ValueRef> =
results_storage.iter().map(|op_res| op_res.borrow().as_value_ref()).collect();

let ptr_arg = *args.last().expect("empty args");
let ptr_arg_ty = ptr_arg.borrow().ty().clone();
assert_eq!(ptr_arg_ty, Type::I32);
let ptr_u32 = builder.ins().bitcast(ptr_arg, Type::U32, span).expect("failed bitcast to U32");

let result_ty =
midenc_hir2::StructType::new(results.iter().map(|v| (*v).borrow().ty().clone()));
for (idx, value) in results.iter().enumerate() {
let value_ty = (*value).borrow().ty().clone().clone();
let eff_ptr = if idx == 0 {
// We're assuming here that the base pointer is of the correct alignment
ptr_u32
} else {
let imm = Immediate::U32(result_ty.get(idx).offset);
let imm_val = builder.ins().imm(imm, span);
builder.ins().add(ptr_u32, imm_val, span).expect("failed add")
};
let addr = builder
.ins()
.inttoptr(eff_ptr, Type::Ptr(value_ty.into()), span)
.expect("failed inttoptr");
builder.ins().store(addr, *value, span).expect("failed store");
}
Vec::new()
}
173 changes: 136 additions & 37 deletions frontend-wasm2/src/module/module_translation_state.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
use std::{cell::RefCell, rc::Rc};

use midenc_dialect_hir::InstBuilder;
use midenc_hir::diagnostics::{DiagnosticsHandler, Severity};
use midenc_hir2::{
dialects::builtin::{ComponentBuilder, Function, FunctionRef, ModuleBuilder},
CallConv, FunctionIdent, FxHashMap, Ident, Signature, Symbol, SymbolName, SymbolNameComponent,
SymbolPath, SymbolRef, SymbolTable, UnsafeIntrusiveEntityRef, Visibility,
AbiParam, CallConv, FunctionIdent, FunctionType, FxHashMap, Ident, Op, Signature, Symbol,
SymbolName, SymbolNameComponent, SymbolPath, SymbolRef, SymbolTable, UnsafeIntrusiveEntityRef,
ValueRef, Visibility,
};

use super::{instance::ModuleArgument, ir_func_type, EntityIndex, FuncIndex, Module, ModuleTypes};
use super::{
function_builder_ext::{FunctionBuilderContext, FunctionBuilderExt},
instance::ModuleArgument,
ir_func_type, EntityIndex, FuncIndex, Module, ModuleTypes,
};
use crate::{
error::WasmResult,
intrinsics::{
intrinsics_conversion_result, is_miden_intrinsics_module, IntrinsicsConversionResult,
},
miden_abi::{is_miden_abi_module, miden_abi_function_type, recover_imported_masm_function_id},
miden_abi::{
is_miden_abi_module, miden_abi_function_type, recover_imported_masm_function_id,
transform::transform_miden_abi_call,
},
translation_utils::sig_from_func_type,
};

Expand Down Expand Up @@ -41,6 +52,8 @@ impl<'a> ModuleTranslationState<'a> {
module_args: Vec<ModuleArgument>,
diagnostics: &DiagnosticsHandler,
) -> Self {
// TODO: extract into `fn process_module_imports` after component translation is
// implemented
let mut function_import_subst = FxHashMap::default();
if module.imports.len() == module_args.len() {
for (import, arg) in module.imports.iter().zip(module_args) {
Expand All @@ -67,50 +80,44 @@ impl<'a> ModuleTranslationState<'a> {
for (index, func_type) in &module.functions {
let wasm_func_type = mod_types[func_type.signature].clone();
let ir_func_type = ir_func_type(&wasm_func_type, diagnostics).unwrap();
let func_name = module.func_name(index);
let func_id = FunctionIdent {
module: Ident::from(module.name().as_str()),
function: Ident::from(func_name.as_str()),
};
let sig = sig_from_func_type(&ir_func_type, CallConv::SystemV, Visibility::Public);
if let Some(subst) = function_import_subst.get(&index) {
// functions.insert(index, (*subst, sig));
todo!("define the import in some symbol table");
} else if module.is_imported_function(index) {
assert!((index.as_u32() as usize) < module.num_imported_funcs);
let import = &module.imports[index.as_u32() as usize];
let func_id =
let import_func_id =
recover_imported_masm_function_id(import.module.as_str(), &import.field);
let defined_function = if is_miden_intrinsics_module(func_id.module.as_symbol())
&& intrinsics_conversion_result(&func_id).is_operation()
{
CallableFunction {
wasm_id: func_id,
function_ref: None,
signature: sig,
}
} else {
let import_module_ref = if let Some(found_module_ref) =
component_builder.find_module(func_id.module.as_symbol())
{
found_module_ref
let callable_function =
if is_miden_intrinsics_module(import_func_id.module.as_symbol()) {
if intrinsics_conversion_result(&import_func_id).is_operation() {
CallableFunction {
wasm_id: import_func_id,
function_ref: None,
signature: sig,
}
} else {
define_func_for_intrinsic(component_builder, sig, import_func_id)
}
} else if is_miden_abi_module(import_func_id.module.as_symbol()) {
define_func_for_miden_abi_trans(
component_builder,
module_builder,
func_id,
sig,
import_func_id,
)
} else {
component_builder
.define_module(func_id.module)
.expect("failed to create a module for imports")
todo!("no intrinsics and no abi transformation import");
};
let mut import_module_builder = ModuleBuilder::new(import_module_ref);
let import_func_ref = import_module_builder
.define_function(func_id.function, sig.clone())
.expect("failed to create an import function");
CallableFunction {
wasm_id: func_id,
function_ref: Some(import_func_ref),
signature: sig,
}
};
functions.insert(index, defined_function);
functions.insert(index, callable_function);
} else {
let func_name = module.func_name(index);
let func_id = FunctionIdent {
module: Ident::from(module.name().as_str()),
function: Ident::from(func_name.as_str()),
};
let func_ref = module_builder
.define_function(func_id.function, sig.clone())
.expect("adding new function failed");
Expand Down Expand Up @@ -139,3 +146,95 @@ impl<'a> ModuleTranslationState<'a> {
Ok(defined_func)
}
}

fn define_func_for_miden_abi_trans(
component_builder: &mut ComponentBuilder,
module_builder: &mut ModuleBuilder,
synth_func_id: FunctionIdent,
synth_func_sig: Signature,
import_func_id: FunctionIdent,
) -> CallableFunction {
let import_ft = miden_abi_function_type(
import_func_id.module.as_symbol(),
import_func_id.function.as_symbol(),
);
let import_sig = Signature::new(
import_ft.params.into_iter().map(AbiParam::new),
import_ft.results.into_iter().map(AbiParam::new),
);
let mut func_ref = module_builder
.define_function(synth_func_id.function, synth_func_sig.clone())
.expect("failed to create an import function");
let mut func = func_ref.borrow_mut();
let span = func.name().span;
let context = func.as_operation().context_rc().clone();
let func = func.as_mut().downcast_mut::<Function>().unwrap();
let mut func_builder =
FunctionBuilderExt::new(func, Rc::new(RefCell::new(FunctionBuilderContext::new(context))));
let entry_block = func_builder.current_block();
func_builder.seal_block(entry_block); // Declare all predecessors known.
let args: Vec<ValueRef> = entry_block
.borrow()
.arguments()
.iter()
.copied()
.map(|ba| ba as ValueRef)
.collect();

let import_module_ref = if let Some(found_module_ref) =
component_builder.find_module(import_func_id.module.as_symbol())
{
found_module_ref
} else {
component_builder
.define_module(import_func_id.module)
.expect("failed to create a module for imports")
};
let mut import_module_builder = ModuleBuilder::new(import_module_ref);
let import_func_ref = import_module_builder
.define_function(import_func_id.function, import_sig.clone())
.expect("failed to create an import function");
let results = transform_miden_abi_call(
import_func_ref,
import_func_id,
args.as_slice(),
&mut func_builder,
);

let exit_block = func_builder.create_block();
func_builder.append_block_params_for_function_returns(exit_block);
func_builder.ins().br(exit_block, results, span);
func_builder.seal_block(exit_block);
func_builder.switch_to_block(exit_block);
func_builder.ins().ret(None, span).expect("failed ret");

CallableFunction {
wasm_id: synth_func_id,
function_ref: Some(func_ref),
signature: synth_func_sig,
}
}

fn define_func_for_intrinsic(
component_builder: &mut ComponentBuilder,
sig: Signature,
func_id: FunctionIdent,
) -> CallableFunction {
let import_module_ref =
if let Some(found_module_ref) = component_builder.find_module(func_id.module.as_symbol()) {
found_module_ref
} else {
component_builder
.define_module(func_id.module)
.expect("failed to create a module for imports")
};
let mut import_module_builder = ModuleBuilder::new(import_module_ref);
let import_func_ref = import_module_builder
.define_function(func_id.function, sig.clone())
.expect("failed to create an import function");
CallableFunction {
wasm_id: func_id,
function_ref: Some(import_func_ref),
signature: sig,
}
}
Loading