Skip to content

Commit

Permalink
Merge pull request #412 from 0xPolygonMiden/greenhat/i363-miden-abi-t…
Browse files Browse the repository at this point in the history
…ransform

[1/x] feature(frontend): generate synthetic functions with Miden ABI transformation (stdlib, tx_kernel);
  • Loading branch information
bitwalker authored Feb 24, 2025
2 parents 9b64767 + 11d6689 commit b328b09
Show file tree
Hide file tree
Showing 9 changed files with 1,079 additions and 859 deletions.
132 changes: 80 additions & 52 deletions frontend-wasm2/src/miden_abi/transform.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use midenc_dialect_hir::InstBuilder;
use midenc_hir::diagnostics::{DiagnosticsHandler, SourceSpan};
use midenc_hir2::{FunctionIdent, ValueRef};
use midenc_hir2::{dialects::builtin::FunctionRef, FunctionIdent, Immediate, Type, ValueRef};

use super::{stdlib, tx_kernel};
use crate::module::function_builder_ext::FunctionBuilderExt;
Expand Down Expand Up @@ -58,83 +58,111 @@ fn get_transform_strategy(module_id: &str, function_id: &str) -> TransformStrate
panic!("No transform strategy found for function '{function_id}' in module '{module_id}'");
}

/// Transform a function call based on the transformation strategy
/// Transform a Miden ABI function call based on the transformation strategy
///
/// `import_func` - import function that we're transforming a call to (think of a MASM function)
/// `args` - arguments to the generated synthetic function
/// Returns results that will be returned from the synthetic function
pub fn transform_miden_abi_call(
func_id: FunctionIdent,
import_func_ref: FunctionRef,
import_func_id: FunctionIdent,
args: &[ValueRef],
builder: &mut FunctionBuilderExt,
span: SourceSpan,
diagnostics: &DiagnosticsHandler,
) -> Vec<ValueRef> {
use TransformStrategy::*;
match get_transform_strategy(func_id.module.as_str(), func_id.function.as_str()) {
ListReturn => list_return(func_id, args, builder, span, diagnostics),
ReturnViaPointer => return_via_pointer(func_id, args, builder, span, diagnostics),
NoTransform => no_transform(func_id, args, builder, span, diagnostics),
match get_transform_strategy(import_func_id.module.as_str(), import_func_id.function.as_str()) {
ListReturn => list_return(import_func_ref, args, builder),
ReturnViaPointer => return_via_pointer(import_func_ref, args, builder),
NoTransform => no_transform(import_func_ref, args, builder),
}
}

/// No transformation needed
#[inline(always)]
pub fn no_transform(
func_id: FunctionIdent,
import_func_ref: FunctionRef,
args: &[ValueRef],
builder: &mut FunctionBuilderExt,
span: SourceSpan,
_diagnostics: &DiagnosticsHandler,
) -> Vec<ValueRef> {
todo!()
// let call = builder.ins().exec(func_id, args, span);
// let results = builder.inst_results(call);
// results.to_vec()
let span = import_func_ref.borrow().name().span;
let signature = import_func_ref.borrow().signature().clone();
let exec = builder
.ins()
.exec(import_func_ref, signature, args.to_vec(), span)
.expect("failed to build an exec op in no_transform strategy");

let borrow = exec.borrow();
let results_storage = borrow.as_ref().results();
let results: Vec<ValueRef> =
results_storage.iter().map(|op_res| op_res.borrow().as_value_ref()).collect();
results
}

/// The Miden ABI function returns a length and a pointer and we only want the length
pub fn list_return(
func_id: FunctionIdent,
import_func_ref: FunctionRef,
args: &[ValueRef],
builder: &mut FunctionBuilderExt,
span: SourceSpan,
_diagnostics: &DiagnosticsHandler,
) -> Vec<ValueRef> {
todo!()
// let call = builder.ins().exec(func_id, args, span);
// let results = builder.inst_results(call);
// assert_eq!(results.len(), 2, "List return strategy expects 2 results: length and pointer");
// // Return the first result (length) only
// results[0..1].to_vec()
let span = import_func_ref.borrow().name().span;
let signature = import_func_ref.borrow().signature().clone();
let exec = builder
.ins()
.exec(import_func_ref, signature, args.to_vec(), span)
.expect("failed to build an exec op in list_return strategy");

let borrow = exec.borrow();
let results_storage = borrow.as_ref().results();
let results: Vec<ValueRef> =
results_storage.iter().map(|op_res| op_res.borrow().as_value_ref()).collect();

assert_eq!(results.len(), 2, "List return strategy expects 2 results: length and pointer");
// Return the first result (length) only
results[0..1].to_vec()
}

/// The Miden ABI function returns felts on the stack and we want to return via a pointer argument
pub fn return_via_pointer(
func_id: FunctionIdent,
import_func_ref: FunctionRef,
args: &[ValueRef],
builder: &mut FunctionBuilderExt,
span: SourceSpan,
_diagnostics: &DiagnosticsHandler,
) -> Vec<ValueRef> {
todo!()
// // Omit the last argument (pointer)
// let args_wo_pointer = &args[0..args.len() - 1];
// let call = builder.ins().exec(func_id, args_wo_pointer, span);
// let results = builder.inst_results(call).to_vec();
// let ptr_arg = *args.last().unwrap();
// let ptr_arg_ty = ptr_arg.borrow().ty().clone().clone();
// assert_eq!(ptr_arg_ty, I32);
// let ptr_u32 = builder.ins().bitcast(ptr_arg, U32, span);
// let result_ty =
// midenc_hir::StructType::new(results.iter().map(|v| (*v).borrow().ty().clone().clone()));
// for (idx, value) in results.iter().enumerate() {
// let value_ty = (*value).borrow().ty().clone().clone();
// let eff_ptr = if idx == 0 {
// // We're assuming here that the base pointer is of the correct alignment
// ptr_u32
// } else {
// let imm = Immediate::U32(result_ty.get(idx).offset);
// builder.ins().add_imm_checked(ptr_u32, imm, span)
// };
// let addr = builder.ins().inttoptr(eff_ptr, Ptr(value_ty.into()), span);
// builder.ins().store(addr, *value, span);
// }
// Vec::new()
let span = import_func_ref.borrow().name().span;
// Omit the last argument (pointer)
let args_wo_pointer = &args[0..args.len() - 1];
let signature = import_func_ref.borrow().signature().clone();
let exec = builder
.ins()
.exec(import_func_ref, signature, args_wo_pointer.to_vec(), span)
.expect("failed to build an exec op in return_via_pointer strategy");

let borrow = exec.borrow();
let results_storage = borrow.as_ref().results();
let results: Vec<ValueRef> =
results_storage.iter().map(|op_res| op_res.borrow().as_value_ref()).collect();

let ptr_arg = *args.last().expect("empty args");
let ptr_arg_ty = ptr_arg.borrow().ty().clone();
assert_eq!(ptr_arg_ty, Type::I32);
let ptr_u32 = builder.ins().bitcast(ptr_arg, Type::U32, span).expect("failed bitcast to U32");

let result_ty =
midenc_hir2::StructType::new(results.iter().map(|v| (*v).borrow().ty().clone()));
for (idx, value) in results.iter().enumerate() {
let value_ty = (*value).borrow().ty().clone().clone();
let eff_ptr = if idx == 0 {
// We're assuming here that the base pointer is of the correct alignment
ptr_u32
} else {
let imm = Immediate::U32(result_ty.get(idx).offset);
let imm_val = builder.ins().imm(imm, span);
builder.ins().add(ptr_u32, imm_val, span).expect("failed add")
};
let addr = builder
.ins()
.inttoptr(eff_ptr, Type::Ptr(value_ty.into()), span)
.expect("failed inttoptr");
builder.ins().store(addr, *value, span).expect("failed store");
}
Vec::new()
}
173 changes: 136 additions & 37 deletions frontend-wasm2/src/module/module_translation_state.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
use std::{cell::RefCell, rc::Rc};

use midenc_dialect_hir::InstBuilder;
use midenc_hir::diagnostics::{DiagnosticsHandler, Severity};
use midenc_hir2::{
dialects::builtin::{ComponentBuilder, Function, FunctionRef, ModuleBuilder},
CallConv, FunctionIdent, FxHashMap, Ident, Signature, Symbol, SymbolName, SymbolNameComponent,
SymbolPath, SymbolRef, SymbolTable, UnsafeIntrusiveEntityRef, Visibility,
AbiParam, CallConv, FunctionIdent, FunctionType, FxHashMap, Ident, Op, Signature, Symbol,
SymbolName, SymbolNameComponent, SymbolPath, SymbolRef, SymbolTable, UnsafeIntrusiveEntityRef,
ValueRef, Visibility,
};

use super::{instance::ModuleArgument, ir_func_type, EntityIndex, FuncIndex, Module, ModuleTypes};
use super::{
function_builder_ext::{FunctionBuilderContext, FunctionBuilderExt},
instance::ModuleArgument,
ir_func_type, EntityIndex, FuncIndex, Module, ModuleTypes,
};
use crate::{
error::WasmResult,
intrinsics::{
intrinsics_conversion_result, is_miden_intrinsics_module, IntrinsicsConversionResult,
},
miden_abi::{is_miden_abi_module, miden_abi_function_type, recover_imported_masm_function_id},
miden_abi::{
is_miden_abi_module, miden_abi_function_type, recover_imported_masm_function_id,
transform::transform_miden_abi_call,
},
translation_utils::sig_from_func_type,
};

Expand Down Expand Up @@ -41,6 +52,8 @@ impl<'a> ModuleTranslationState<'a> {
module_args: Vec<ModuleArgument>,
diagnostics: &DiagnosticsHandler,
) -> Self {
// TODO: extract into `fn process_module_imports` after component translation is
// implemented
let mut function_import_subst = FxHashMap::default();
if module.imports.len() == module_args.len() {
for (import, arg) in module.imports.iter().zip(module_args) {
Expand All @@ -67,50 +80,44 @@ impl<'a> ModuleTranslationState<'a> {
for (index, func_type) in &module.functions {
let wasm_func_type = mod_types[func_type.signature].clone();
let ir_func_type = ir_func_type(&wasm_func_type, diagnostics).unwrap();
let func_name = module.func_name(index);
let func_id = FunctionIdent {
module: Ident::from(module.name().as_str()),
function: Ident::from(func_name.as_str()),
};
let sig = sig_from_func_type(&ir_func_type, CallConv::SystemV, Visibility::Public);
if let Some(subst) = function_import_subst.get(&index) {
// functions.insert(index, (*subst, sig));
todo!("define the import in some symbol table");
} else if module.is_imported_function(index) {
assert!((index.as_u32() as usize) < module.num_imported_funcs);
let import = &module.imports[index.as_u32() as usize];
let func_id =
let import_func_id =
recover_imported_masm_function_id(import.module.as_str(), &import.field);
let defined_function = if is_miden_intrinsics_module(func_id.module.as_symbol())
&& intrinsics_conversion_result(&func_id).is_operation()
{
CallableFunction {
wasm_id: func_id,
function_ref: None,
signature: sig,
}
} else {
let import_module_ref = if let Some(found_module_ref) =
component_builder.find_module(func_id.module.as_symbol())
{
found_module_ref
let callable_function =
if is_miden_intrinsics_module(import_func_id.module.as_symbol()) {
if intrinsics_conversion_result(&import_func_id).is_operation() {
CallableFunction {
wasm_id: import_func_id,
function_ref: None,
signature: sig,
}
} else {
define_func_for_intrinsic(component_builder, sig, import_func_id)
}
} else if is_miden_abi_module(import_func_id.module.as_symbol()) {
define_func_for_miden_abi_trans(
component_builder,
module_builder,
func_id,
sig,
import_func_id,
)
} else {
component_builder
.define_module(func_id.module)
.expect("failed to create a module for imports")
todo!("no intrinsics and no abi transformation import");
};
let mut import_module_builder = ModuleBuilder::new(import_module_ref);
let import_func_ref = import_module_builder
.define_function(func_id.function, sig.clone())
.expect("failed to create an import function");
CallableFunction {
wasm_id: func_id,
function_ref: Some(import_func_ref),
signature: sig,
}
};
functions.insert(index, defined_function);
functions.insert(index, callable_function);
} else {
let func_name = module.func_name(index);
let func_id = FunctionIdent {
module: Ident::from(module.name().as_str()),
function: Ident::from(func_name.as_str()),
};
let func_ref = module_builder
.define_function(func_id.function, sig.clone())
.expect("adding new function failed");
Expand Down Expand Up @@ -139,3 +146,95 @@ impl<'a> ModuleTranslationState<'a> {
Ok(defined_func)
}
}

fn define_func_for_miden_abi_trans(
component_builder: &mut ComponentBuilder,
module_builder: &mut ModuleBuilder,
synth_func_id: FunctionIdent,
synth_func_sig: Signature,
import_func_id: FunctionIdent,
) -> CallableFunction {
let import_ft = miden_abi_function_type(
import_func_id.module.as_symbol(),
import_func_id.function.as_symbol(),
);
let import_sig = Signature::new(
import_ft.params.into_iter().map(AbiParam::new),
import_ft.results.into_iter().map(AbiParam::new),
);
let mut func_ref = module_builder
.define_function(synth_func_id.function, synth_func_sig.clone())
.expect("failed to create an import function");
let mut func = func_ref.borrow_mut();
let span = func.name().span;
let context = func.as_operation().context_rc().clone();
let func = func.as_mut().downcast_mut::<Function>().unwrap();
let mut func_builder =
FunctionBuilderExt::new(func, Rc::new(RefCell::new(FunctionBuilderContext::new(context))));
let entry_block = func_builder.current_block();
func_builder.seal_block(entry_block); // Declare all predecessors known.
let args: Vec<ValueRef> = entry_block
.borrow()
.arguments()
.iter()
.copied()
.map(|ba| ba as ValueRef)
.collect();

let import_module_ref = if let Some(found_module_ref) =
component_builder.find_module(import_func_id.module.as_symbol())
{
found_module_ref
} else {
component_builder
.define_module(import_func_id.module)
.expect("failed to create a module for imports")
};
let mut import_module_builder = ModuleBuilder::new(import_module_ref);
let import_func_ref = import_module_builder
.define_function(import_func_id.function, import_sig.clone())
.expect("failed to create an import function");
let results = transform_miden_abi_call(
import_func_ref,
import_func_id,
args.as_slice(),
&mut func_builder,
);

let exit_block = func_builder.create_block();
func_builder.append_block_params_for_function_returns(exit_block);
func_builder.ins().br(exit_block, results, span);
func_builder.seal_block(exit_block);
func_builder.switch_to_block(exit_block);
func_builder.ins().ret(None, span).expect("failed ret");

CallableFunction {
wasm_id: synth_func_id,
function_ref: Some(func_ref),
signature: synth_func_sig,
}
}

fn define_func_for_intrinsic(
component_builder: &mut ComponentBuilder,
sig: Signature,
func_id: FunctionIdent,
) -> CallableFunction {
let import_module_ref =
if let Some(found_module_ref) = component_builder.find_module(func_id.module.as_symbol()) {
found_module_ref
} else {
component_builder
.define_module(func_id.module)
.expect("failed to create a module for imports")
};
let mut import_module_builder = ModuleBuilder::new(import_module_ref);
let import_func_ref = import_module_builder
.define_function(func_id.function, sig.clone())
.expect("failed to create an import function");
CallableFunction {
wasm_id: func_id,
function_ref: Some(import_func_ref),
signature: sig,
}
}
Loading

0 comments on commit b328b09

Please sign in to comment.