From 256a0346261fce951d34e537b727cf649ef8ca0d Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 16:19:09 +1100 Subject: [PATCH 01/22] Initialize string cache in context --- melior/src/context.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/melior/src/context.rs b/melior/src/context.rs index ec14451448..43ed62eae6 100644 --- a/melior/src/context.rs +++ b/melior/src/context.rs @@ -4,6 +4,7 @@ use crate::{ logical_result::LogicalResult, string_ref::StringRef, }; +use dashmap::DashMap; use mlir_sys::{ mlirContextAppendDialectRegistry, mlirContextAttachDiagnosticHandler, mlirContextCreate, mlirContextDestroy, mlirContextDetachDiagnosticHandler, mlirContextEnableMultithreading, @@ -12,7 +13,12 @@ use mlir_sys::{ mlirContextIsRegisteredOperation, mlirContextLoadAllAvailableDialects, mlirContextSetAllowUnregisteredDialects, MlirContext, MlirDiagnostic, MlirLogicalResult, }; -use std::{ffi::c_void, marker::PhantomData, mem::transmute, ops::Deref}; +use std::{ + ffi::{c_void, CString}, + marker::PhantomData, + mem::transmute, + ops::Deref, +}; /// A context of IR, dialects, and passes. /// @@ -21,6 +27,7 @@ use std::{ffi::c_void, marker::PhantomData, mem::transmute, ops::Deref}; #[derive(Debug)] pub struct Context { raw: MlirContext, + string_cache: DashMap, } impl Context { @@ -28,6 +35,7 @@ impl Context { pub fn new() -> Self { Self { raw: unsafe { mlirContextCreate() }, + string_cache: Default::default(), } } From acfd796b4d194d1c4c0b74e05c1b3594929492e3 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 16:36:05 +1100 Subject: [PATCH 02/22] Fix --- melior/src/context.rs | 15 +++++++- melior/src/dialect/llvm.rs | 4 +-- melior/src/dialect/ods.rs | 4 +-- melior/src/execution_engine.rs | 18 +++++----- melior/src/pass/manager.rs | 12 +++---- melior/src/pass/operation_manager.rs | 14 ++++---- melior/src/string_ref.rs | 53 +++++++++++++++------------- 7 files changed, 70 insertions(+), 50 deletions(-) diff --git a/melior/src/context.rs b/melior/src/context.rs index 43ed62eae6..ff8a89a5c2 100644 --- a/melior/src/context.rs +++ b/melior/src/context.rs @@ -27,6 +27,8 @@ use std::{ #[derive(Debug)] pub struct Context { raw: MlirContext, + // We need to pass null-terminated strings to functions in the MLIR API although + // Rust's strings are not. string_cache: DashMap, } @@ -86,7 +88,9 @@ impl Context { /// Returns `true` if a given operation is registered in a context. pub fn is_registered_operation(&self, name: &str) -> bool { - unsafe { mlirContextIsRegisteredOperation(self.raw, StringRef::from(name).to_raw()) } + unsafe { + mlirContextIsRegisteredOperation(self.raw, StringRef::from_str(self, name).to_raw()) + } } /// Converts a context into a raw object. @@ -124,6 +128,15 @@ impl Context { pub fn detach_diagnostic_handler(&self, id: DiagnosticHandlerId) { unsafe { mlirContextDetachDiagnosticHandler(self.to_raw(), id.to_raw()) } } + + pub(crate) fn create_c_string(&self, string: &str) -> &CString { + let entry = self + .string_cache + .entry(CString::new(string).unwrap()) + .or_insert_with(Default::default); + + entry.key() + } } impl Drop for Context { diff --git a/melior/src/dialect/llvm.rs b/melior/src/dialect/llvm.rs index b86eed7158..cbd8ed04ce 100644 --- a/melior/src/dialect/llvm.rs +++ b/melior/src/dialect/llvm.rs @@ -359,10 +359,10 @@ mod tests { pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager - .nested_under("func.func") + .nested_under(&context, "func.func") .add_pass(pass::conversion::create_arith_to_llvm()); pass_manager - .nested_under("func.func") + .nested_under(&context, "func.func") .add_pass(pass::conversion::create_index_to_llvm()); pass_manager.add_pass(pass::conversion::create_scf_to_control_flow()); pass_manager.add_pass(pass::conversion::create_control_flow_to_llvm()); diff --git a/melior/src/dialect/ods.rs b/melior/src/dialect/ods.rs index 20eae779c4..db19b1e685 100644 --- a/melior/src/dialect/ods.rs +++ b/melior/src/dialect/ods.rs @@ -129,10 +129,10 @@ mod tests { pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager - .nested_under("func.func") + .nested_under(context, "func.func") .add_pass(pass::conversion::create_arith_to_llvm()); pass_manager - .nested_under("func.func") + .nested_under(context, "func.func") .add_pass(pass::conversion::create_index_to_llvm()); pass_manager.add_pass(pass::conversion::create_scf_to_control_flow()); pass_manager.add_pass(pass::conversion::create_control_flow_to_llvm()); diff --git a/melior/src/execution_engine.rs b/melior/src/execution_engine.rs index d37c42e89f..118d3b5051 100644 --- a/melior/src/execution_engine.rs +++ b/melior/src/execution_engine.rs @@ -1,4 +1,4 @@ -use crate::{ir::Module, logical_result::LogicalResult, string_ref::StringRef, Error}; +use crate::{ir::Module, logical_result::LogicalResult, string_ref::StringRef, Context, Error}; use mlir_sys::{ mlirExecutionEngineCreate, mlirExecutionEngineDestroy, mlirExecutionEngineDumpToObjectFile, mlirExecutionEngineInvokePacked, mlirExecutionEngineRegisterSymbol, MlirExecutionEngine, @@ -11,8 +11,9 @@ pub struct ExecutionEngine { impl ExecutionEngine { /// Creates an execution engine. - pub fn new( - module: &Module, + pub fn new<'c>( + context: &'c Context, + module: &Module<'c>, optimization_level: usize, shared_library_paths: &[&str], enable_object_dump: bool, @@ -25,7 +26,7 @@ impl ExecutionEngine { shared_library_paths.len() as i32, shared_library_paths .iter() - .map(|&string| StringRef::from(string).to_raw()) + .map(|&string| StringRef::from_str(context, string).to_raw()) .collect::>() .as_ptr(), enable_object_dump, @@ -105,12 +106,12 @@ mod tests { pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager - .nested_under("func.func") + .nested_under(&context, "func.func") .add_pass(pass::conversion::create_arith_to_llvm()); assert_eq!(pass_manager.run(&mut module), Ok(())); - let engine = ExecutionEngine::new(&module, 2, &[], false); + let engine = ExecutionEngine::new(&context, &module, 2, &[], false); let mut argument = 42; let mut result = -1; @@ -153,11 +154,12 @@ mod tests { pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager - .nested_under("func.func") + .nested_under(&context, "func.func") .add_pass(pass::conversion::create_arith_to_llvm()); assert_eq!(pass_manager.run(&mut module), Ok(())); - ExecutionEngine::new(&module, 2, &[], true).dump_to_object_file("/tmp/melior/test.o"); + ExecutionEngine::new(&context, &module, 2, &[], true) + .dump_to_object_file("/tmp/melior/test.o"); } } diff --git a/melior/src/pass/manager.rs b/melior/src/pass/manager.rs index fafbb409f1..57f8d91084 100644 --- a/melior/src/pass/manager.rs +++ b/melior/src/pass/manager.rs @@ -28,11 +28,11 @@ impl<'c> PassManager<'c> { /// Gets an operation pass manager for nested operations corresponding to a /// given name. - pub fn nested_under(&self, name: &str) -> OperationPassManager { + pub fn nested_under(&self, context: &'c Context, name: &str) -> OperationPassManager { unsafe { OperationPassManager::from_raw(mlirPassManagerGetNestedUnder( self.raw, - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), )) } } @@ -178,15 +178,15 @@ mod tests { let manager = PassManager::new(&context); manager - .nested_under("func.func") + .nested_under(&context, "func.func") .add_pass(pass::transform::create_print_op_stats()); assert_eq!(manager.run(&mut module), Ok(())); let manager = PassManager::new(&context); manager - .nested_under("builtin.module") - .nested_under("func.func") + .nested_under(&context, "builtin.module") + .nested_under(&context, "func.func") .add_pass(pass::transform::create_print_op_stats()); assert_eq!(manager.run(&mut module), Ok(())); @@ -196,7 +196,7 @@ mod tests { fn print_pass_pipeline() { let context = create_test_context(); let manager = PassManager::new(&context); - let function_manager = manager.nested_under("func.func"); + let function_manager = manager.nested_under(&context, "func.func"); function_manager.add_pass(pass::transform::create_print_op_stats()); diff --git a/melior/src/pass/operation_manager.rs b/melior/src/pass/operation_manager.rs index b27b0240db..e4ec956bdb 100644 --- a/melior/src/pass/operation_manager.rs +++ b/melior/src/pass/operation_manager.rs @@ -1,5 +1,5 @@ use super::PassManager; -use crate::{pass::Pass, string_ref::StringRef}; +use crate::{pass::Pass, string_ref::StringRef, Context}; use mlir_sys::{ mlirOpPassManagerAddOwnedPass, mlirOpPassManagerGetNestedUnder, mlirPrintPassPipeline, MlirOpPassManager, MlirStringRef, @@ -12,19 +12,19 @@ use std::{ /// An operation pass manager. #[derive(Clone, Copy, Debug)] -pub struct OperationPassManager<'a> { +pub struct OperationPassManager<'c, 'a> { raw: MlirOpPassManager, - _parent: PhantomData<&'a PassManager<'a>>, + _parent: PhantomData<&'a PassManager<'c>>, } -impl<'a> OperationPassManager<'a> { +impl<'c, 'a> OperationPassManager<'c, 'a> { /// Gets an operation pass manager for nested operations corresponding to a /// given name. - pub fn nested_under(&self, name: &str) -> Self { + pub fn nested_under(&self, context: &'c Context, name: &str) -> Self { unsafe { Self::from_raw(mlirOpPassManagerGetNestedUnder( self.raw, - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), )) } } @@ -52,7 +52,7 @@ impl<'a> OperationPassManager<'a> { } } -impl<'a> Display for OperationPassManager<'a> { +impl<'c, 'a> Display for OperationPassManager<'c, 'a> { fn fmt(&self, formatter: &mut Formatter) -> fmt::Result { let mut data = (formatter, Ok(())); diff --git a/melior/src/string_ref.rs b/melior/src/string_ref.rs index 9ee6859fde..89a37bdb95 100644 --- a/melior/src/string_ref.rs +++ b/melior/src/string_ref.rs @@ -1,31 +1,31 @@ -use dashmap::DashMap; +use crate::Context; use mlir_sys::{mlirStringRefCreateFromCString, mlirStringRefEqual, MlirStringRef}; -use once_cell::sync::Lazy; use std::{ - ffi::CString, marker::PhantomData, slice, str::{self, Utf8Error}, }; -// We need to pass null-terminated strings to functions in the MLIR API although -// Rust's strings are not. -static STRING_CACHE: Lazy> = Lazy::new(Default::default); - /// A string reference. // https://mlir.llvm.org/docs/CAPI/#stringref // // TODO The documentation says string refs do not have to be null-terminated. // But it looks like some functions do not handle strings not null-terminated? #[derive(Clone, Copy, Debug)] -pub struct StringRef<'a> { +pub struct StringRef<'c> { raw: MlirStringRef, - _parent: PhantomData<&'a ()>, + _parent: PhantomData<&'c Context>, } -impl<'a> StringRef<'a> { +impl<'c> StringRef<'c> { + pub fn from_str(context: &'c Context, string: &str) -> Self { + let string = context.create_c_string(string).as_ptr(); + + unsafe { Self::from_raw(mlirStringRefCreateFromCString(string)) } + } + /// Converts a string reference into a `str`. - pub fn as_str(&self) -> Result<&'a str, Utf8Error> { + pub fn as_str(&self) -> Result<&'c str, Utf8Error> { unsafe { let bytes = slice::from_raw_parts(self.raw.data as *mut u8, self.raw.length); @@ -63,32 +63,37 @@ impl<'a> PartialEq for StringRef<'a> { impl<'a> Eq for StringRef<'a> {} -impl From<&str> for StringRef<'static> { - fn from(string: &str) -> Self { - let entry = STRING_CACHE - .entry(CString::new(string).unwrap()) - .or_insert_with(Default::default); - - unsafe { Self::from_raw(mlirStringRefCreateFromCString(entry.key().as_ptr())) } - } -} - #[cfg(test)] mod tests { use super::*; #[test] fn equal() { - assert_eq!(StringRef::from("foo"), StringRef::from("foo")); + let context = Context::new(); + + assert_eq!( + StringRef::from_str(&context, "foo"), + StringRef::from_str(&context, "foo") + ); } #[test] fn equal_str() { - assert_eq!(StringRef::from("foo").as_str().unwrap(), "foo"); + let context = Context::new(); + + assert_eq!( + StringRef::from_str(&context, "foo").as_str().unwrap(), + "foo" + ); } #[test] fn not_equal() { - assert_ne!(StringRef::from("foo"), StringRef::from("bar")); + let context = Context::new(); + + assert_ne!( + StringRef::from_str(&context, "foo"), + StringRef::from_str(&context, "bar") + ); } } From e1af493ff2b011ef1ef34e8deacc9f7461b550ef Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 16:52:40 +1100 Subject: [PATCH 03/22] Fix --- macro/src/operation.rs | 18 +++-- macro/src/type.rs | 1 - melior/src/dialect/index.rs | 4 +- melior/src/dialect/llvm.rs | 84 ++++++++++++++-------- melior/src/dialect/memref.rs | 36 ++++++---- melior/src/dialect/scf.rs | 25 ++++--- melior/src/execution_engine.rs | 26 +++++-- melior/src/ir/attribute.rs | 2 +- melior/src/ir/attribute/flat_symbol_ref.rs | 2 +- melior/src/ir/attribute/string.rs | 2 +- melior/src/ir/identifier.rs | 4 +- melior/src/ir/location.rs | 4 +- melior/src/ir/module.rs | 2 +- melior/src/ir/operation.rs | 25 ++++--- melior/src/ir/operation/builder.rs | 7 +- melior/src/ir/type.rs | 2 +- melior/src/pass/external.rs | 11 +-- melior/src/utility.rs | 8 ++- 18 files changed, 170 insertions(+), 93 deletions(-) diff --git a/macro/src/operation.rs b/macro/src/operation.rs index 7c43d621a5..283c7eff37 100644 --- a/macro/src/operation.rs +++ b/macro/src/operation.rs @@ -13,23 +13,25 @@ pub fn generate_binary(dialect: &Ident, names: &[Ident]) -> Result( + context: &'c Context, lhs: crate::ir::Value<'c, '_>, rhs: crate::ir::Value<'c, '_>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - binary_operator(#operation_name, lhs, rhs, location) + binary_operator(context, #operation_name, lhs, rhs, location) } })); } stream.extend(TokenStream::from(quote! { fn binary_operator<'c>( + context: &'c Context, name: &str, lhs: crate::ir::Value<'c, '_>, rhs: crate::ir::Value<'c, '_>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - crate::ir::operation::OperationBuilder::new(name, location) + crate::ir::operation::OperationBuilder::new(context, name, location) .add_operands(&[lhs, rhs]) .enable_result_type_inference() .build() @@ -49,21 +51,23 @@ pub fn generate_unary(dialect: &Ident, names: &[Ident]) -> Result( + context: &'c Context, value: crate::ir::Value<'c, '_>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - unary_operator(#operation_name, value, location) + unary_operator(context, #operation_name, value, location) } })); } stream.extend(TokenStream::from(quote! { fn unary_operator<'c>( + context: &'c Context, name: &str, value: crate::ir::Value<'c, '_>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - crate::ir::operation::OperationBuilder::new(name, location) + crate::ir::operation::OperationBuilder::new(context, name, location) .add_operands(&[value]) .enable_result_type_inference() .build() @@ -86,23 +90,25 @@ pub fn generate_typed_unary( stream.extend(TokenStream::from(quote! { #[doc = #document] pub fn #name<'c>( + context: &'c Context, value: crate::ir::Value<'c, '_>, r#type: crate::ir::Type<'c>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - typed_unary_operator(#operation_name, value, r#type, location) + typed_unary_operator(context, #operation_name, value, r#type, location) } })); } stream.extend(TokenStream::from(quote! { fn typed_unary_operator<'c>( + context: &'c Context, name: &str, value: crate::ir::Value<'c, '_>, r#type: crate::ir::Type<'c>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - crate::ir::operation::OperationBuilder::new(name, location) + crate::ir::operation::OperationBuilder::new(context, name, location) .add_operands(&[value]) .add_results(&[r#type]) .build() diff --git a/macro/src/type.rs b/macro/src/type.rs index 62d01ade7e..d427f33a2c 100644 --- a/macro/src/type.rs +++ b/macro/src/type.rs @@ -1,6 +1,5 @@ use crate::utility::map_name; use convert_case::{Case, Casing}; - use proc_macro::TokenStream; use proc_macro2::Ident; use quote::quote; diff --git a/melior/src/dialect/index.rs b/melior/src/dialect/index.rs index bd123f0c5f..49b66271ab 100644 --- a/melior/src/dialect/index.rs +++ b/melior/src/dialect/index.rs @@ -17,7 +17,7 @@ pub fn constant<'c>( value: IntegerAttribute<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("index.constant", location) + OperationBuilder::new(context, "index.constant", location) .add_attributes(&[(Identifier::new(context, "value"), value.into())]) .enable_result_type_inference() .build() @@ -31,7 +31,7 @@ pub fn cmp<'c>( rhs: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("index.cmp", location) + OperationBuilder::new(context, "index.cmp", location) .add_attributes(&[( Identifier::new(context, "pred"), Attribute::parse( diff --git a/melior/src/dialect/llvm.rs b/melior/src/dialect/llvm.rs index cbd8ed04ce..ab25bf6059 100644 --- a/melior/src/dialect/llvm.rs +++ b/melior/src/dialect/llvm.rs @@ -30,7 +30,7 @@ pub fn extract_value<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.extractvalue", location) + OperationBuilder::new(context, "llvm.extractvalue", location) .add_attributes(&[(Identifier::new(context, "position"), position.into())]) .add_operands(&[container]) .add_results(&[result_type]) @@ -46,7 +46,7 @@ pub fn get_element_ptr<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.getelementptr", location) + OperationBuilder::new(context, "llvm.getelementptr", location) .add_attributes(&[ ( Identifier::new(context, "rawConstantIndices"), @@ -71,7 +71,7 @@ pub fn get_element_ptr_dynamic<'c, const N: usize>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.getelementptr", location) + OperationBuilder::new(context, "llvm.getelementptr", location) .add_attributes(&[ ( Identifier::new(context, "rawConstantIndices"), @@ -96,7 +96,7 @@ pub fn insert_value<'c>( value: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.insertvalue", location) + OperationBuilder::new(context, "llvm.insertvalue", location) .add_attributes(&[(Identifier::new(context, "position"), position.into())]) .add_operands(&[container, value]) .enable_result_type_inference() @@ -104,36 +104,53 @@ pub fn insert_value<'c>( } /// Creates a `llvm.mlir.undef` operation. -pub fn undef<'c>(result_type: Type<'c>, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("llvm.mlir.undef", location) +pub fn undef<'c>( + context: &'c Context, + result_type: Type<'c>, + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "llvm.mlir.undef", location) .add_results(&[result_type]) .build() } /// Creates a `llvm.mlir.poison` operation. -pub fn poison<'c>(result_type: Type<'c>, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("llvm.mlir.poison", location) +pub fn poison<'c>( + context: &'c Context, + result_type: Type<'c>, + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "llvm.mlir.poison", location) .add_results(&[result_type]) .build() } /// Creates a `llvm.mlir.null` operation. A null pointer. -pub fn nullptr<'c>(ptr_type: Type<'c>, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("llvm.mlir.null", location) +pub fn nullptr<'c>( + context: &'c Context, + ptr_type: Type<'c>, + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "llvm.mlir.null", location) .add_results(&[ptr_type]) .build() } /// Creates a `llvm.unreachable` operation. -pub fn unreachable(location: Location) -> Operation { - OperationBuilder::new("llvm.unreachable", location).build() +pub fn unreachable(context: &'c Context, location: Location) -> Operation { + OperationBuilder::new(context, "llvm.unreachable", location).build() } /// Creates a `llvm.bitcast` operation. -pub fn bitcast<'c>(arg: Value<'c, '_>, res: Type<'c>, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("llvm.bitcast", location) - .add_operands(&[arg]) - .add_results(&[res]) +pub fn bitcast<'c>( + context: &'c Context, + argument: Value<'c, '_>, + result: Type<'c>, + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "llvm.bitcast", location) + .add_operands(&[argument]) + .add_results(&[result]) .build() } @@ -145,7 +162,7 @@ pub fn alloca<'c>( location: Location<'c>, extra_options: AllocaOptions<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.alloca", location) + OperationBuilder::new(context, "llvm.alloca", location) .add_operands(&[array_size]) .add_attributes(&extra_options.into_attributes(context)) .add_results(&[ptr_type]) @@ -160,7 +177,7 @@ pub fn store<'c>( location: Location<'c>, extra_options: LoadStoreOptions<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.store", location) + OperationBuilder::new(context, "llvm.store", location) .add_operands(&[value, addr]) .add_attributes(&extra_options.into_attributes(context)) .build() @@ -174,7 +191,7 @@ pub fn load<'c>( location: Location<'c>, extra_options: LoadStoreOptions<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.load", location) + OperationBuilder::new(context, "llvm.load", location) .add_operands(&[addr]) .add_attributes(&extra_options.into_attributes(context)) .add_results(&[r#type]) @@ -190,7 +207,7 @@ pub fn func<'c>( attributes: &[(Identifier<'c>, Attribute<'c>)], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.func", location) + OperationBuilder::new(context, "llvm.func", location) .add_attributes(&[ (Identifier::new(context, "sym_name"), name.into()), (Identifier::new(context, "function_type"), r#type.into()), @@ -201,8 +218,12 @@ pub fn func<'c>( } // Creates a `llvm.return` operation. -pub fn r#return<'c>(value: Option>, location: Location<'c>) -> Operation<'c> { - let mut builder = OperationBuilder::new("llvm.return", location); +pub fn r#return<'c>( + context: &'c Context, + value: Option>, + location: Location<'c>, +) -> Operation<'c> { + let mut builder = OperationBuilder::new(context, "llvm.return", location); if let Some(value) = value { builder = builder.add_operands(&[value]); @@ -219,7 +240,7 @@ pub fn call_intrinsic<'c>( results: &[Type<'c>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.call_intrinsic", location) + OperationBuilder::new(context, "llvm.call_intrinsic", location) .add_operands(args) .add_attributes(&[(Identifier::new(context, "intrin"), intrin.into())]) .add_results(results) @@ -234,7 +255,7 @@ pub fn intr_ctlz<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.intr.ctlz", location) + OperationBuilder::new(context, "llvm.intr.ctlz", location) .add_attributes(&[( Identifier::new(context, "is_zero_poison"), IntegerAttribute::new(is_zero_poison.into(), IntegerType::new(context, 1).into()) @@ -253,7 +274,7 @@ pub fn intr_cttz<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.intr.cttz", location) + OperationBuilder::new(context, "llvm.intr.cttz", location) .add_attributes(&[( Identifier::new(context, "is_zero_poison"), IntegerAttribute::new(is_zero_poison.into(), IntegerType::new(context, 1).into()) @@ -270,7 +291,7 @@ pub fn intr_ctpop<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.intr.ctpop", location) + OperationBuilder::new(context, "llvm.intr.ctpop", location) .add_operands(&[value]) .add_results(&[result_type]) .build() @@ -278,11 +299,12 @@ pub fn intr_ctpop<'c>( /// Creates a `llvm.intr.bswap` operation. pub fn intr_bswap<'c>( + context: &'c Context, value: Value<'c, '_>, result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.intr.bswap", location) + OperationBuilder::new(context, "llvm.intr.bswap", location) .add_operands(&[value]) .add_results(&[result_type]) .build() @@ -290,11 +312,12 @@ pub fn intr_bswap<'c>( /// Creates a `llvm.intr.bitreverse` operation. pub fn intr_bitreverse<'c>( + context: &'c Context, value: Value<'c, '_>, result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.intr.bitreverse", location) + OperationBuilder::new(context, "llvm.intr.bitreverse", location) .add_operands(&[value]) .add_results(&[result_type]) .build() @@ -308,7 +331,7 @@ pub fn intr_abs<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.intr.abs", location) + OperationBuilder::new(context, "llvm.intr.abs", location) .add_attributes(&[( Identifier::new(context, "is_int_min_poison"), IntegerAttribute::new( @@ -324,11 +347,12 @@ pub fn intr_abs<'c>( /// Creates a `llvm.zext` operation. pub fn zext<'c>( + context: &'c Context, value: Value<'c, '_>, result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.zext", location) + OperationBuilder::new(context, "llvm.zext", location) .add_operands(&[value]) .add_results(&[result_type]) .build() diff --git a/melior/src/dialect/memref.rs b/melior/src/dialect/memref.rs index 7f6351ec4f..7c1f83b380 100644 --- a/melior/src/dialect/memref.rs +++ b/melior/src/dialect/memref.rs @@ -62,7 +62,7 @@ fn allocate<'c>( alignment: Option>, location: Location<'c>, ) -> Operation<'c> { - let mut builder = OperationBuilder::new(name, location); + let mut builder = OperationBuilder::new(context, name, location); builder = builder.add_attributes(&[( Identifier::new(context, "operand_segment_sizes"), @@ -81,30 +81,36 @@ fn allocate<'c>( /// Create a `memref.cast` operation. pub fn cast<'c>( + context: &'c Context, value: Value<'c, '_>, r#type: MemRefType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("memref.cast", location) + OperationBuilder::new(context, "memref.cast", location) .add_operands(&[value]) .add_results(&[r#type.into()]) .build() } /// Create a `memref.dealloc` operation. -pub fn dealloc<'c>(value: Value<'c, '_>, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("memref.dealloc", location) +pub fn dealloc<'c>( + context: &'c Context, + value: Value<'c, '_>, + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "memref.dealloc", location) .add_operands(&[value]) .build() } /// Create a `memref.dim` operation. pub fn dim<'c>( + context: &'c Context, value: Value<'c, '_>, index: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("memref.dim", location) + OperationBuilder::new(context, "memref.dim", location) .add_operands(&[value, index]) .enable_result_type_inference() .build() @@ -117,7 +123,7 @@ pub fn get_global<'c>( r#type: MemRefType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("memref.get_global", location) + OperationBuilder::new(context, "memref.get_global", location) .add_attributes(&[( Identifier::new(context, "name"), FlatSymbolRefAttribute::new(context, name).into(), @@ -138,7 +144,7 @@ pub fn global<'c>( alignment: Option>, location: Location<'c>, ) -> Operation<'c> { - let mut builder = OperationBuilder::new("memref.global", location).add_attributes(&[ + let mut builder = OperationBuilder::new(context, "memref.global", location).add_attributes(&[ ( Identifier::new(context, "sym_name"), StringAttribute::new(context, name).into(), @@ -177,11 +183,12 @@ pub fn global<'c>( /// Create a `memref.load` operation. pub fn load<'c>( + context: &'c Context, memref: Value<'c, '_>, indices: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("memref.load", location) + OperationBuilder::new(context, "memref.load", location) .add_operands(&[memref]) .add_operands(indices) .enable_result_type_inference() @@ -189,8 +196,12 @@ pub fn load<'c>( } /// Create a `memref.rank` operation. -pub fn rank<'c>(value: Value<'c, '_>, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("memref.rank", location) +pub fn rank<'c>( + context: &'c Context, + value: Value<'c, '_>, + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "memref.rank", location) .add_operands(&[value]) .enable_result_type_inference() .build() @@ -198,12 +209,13 @@ pub fn rank<'c>(value: Value<'c, '_>, location: Location<'c>) -> Operation<'c> { /// Create a `memref.store` operation. pub fn store<'c>( + context: &'c Context, value: Value<'c, '_>, memref: Value<'c, '_>, indices: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("memref.store", location) + OperationBuilder::new(context, "memref.store", location) .add_operands(&[value, memref]) .add_operands(indices) .build() @@ -218,7 +230,7 @@ pub fn realloc<'c>( alignment: Option>, location: Location<'c>, ) -> Operation<'c> { - let mut builder = OperationBuilder::new("memref.realloc", location) + let mut builder = OperationBuilder::new(context, "memref.realloc", location) .add_operands(&[value]) .add_results(&[r#type.into()]); diff --git a/melior/src/dialect/scf.rs b/melior/src/dialect/scf.rs index f83bb4e859..dcb7a2a049 100644 --- a/melior/src/dialect/scf.rs +++ b/melior/src/dialect/scf.rs @@ -10,11 +10,12 @@ use crate::{ /// Creates a `scf.condition` operation. pub fn condition<'c>( + context: &'c Context, condition: Value<'c, '_>, values: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("scf.condition", location) + OperationBuilder::new(context, "scf.condition", location) .add_operands(&[condition]) .add_operands(values) .build() @@ -22,11 +23,12 @@ pub fn condition<'c>( /// Creates a `scf.execute_region` operation. pub fn execute_region<'c>( + context: &'c Context, result_types: &[Type<'c>], region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("scf.execute_region", location) + OperationBuilder::new(context, "scf.execute_region", location) .add_results(result_types) .add_regions(vec![region]) .build() @@ -34,13 +36,14 @@ pub fn execute_region<'c>( /// Creates a `scf.for` operation. pub fn r#for<'c>( + context: &'c Context, start: Value<'c, '_>, end: Value<'c, '_>, step: Value<'c, '_>, region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("scf.for", location) + OperationBuilder::new(context, "scf.for", location) .add_operands(&[start, end, step]) .add_regions(vec![region]) .build() @@ -48,13 +51,14 @@ pub fn r#for<'c>( /// Creates a `scf.if` operation. pub fn r#if<'c>( + context: &'c Context, condition: Value<'c, '_>, result_types: &[Type<'c>], then_region: Region<'c>, else_region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("scf.if", location) + OperationBuilder::new(context, "scf.if", location) .add_operands(&[condition]) .add_results(result_types) .add_regions(vec![then_region, else_region]) @@ -70,7 +74,7 @@ pub fn index_switch<'c>( regions: Vec>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("scf.index_switch", location) + OperationBuilder::new(context, "scf.index_switch", location) .add_operands(&[condition]) .add_results(result_types) .add_attributes(&[(Identifier::new(context, "cases"), cases.into())]) @@ -80,13 +84,14 @@ pub fn index_switch<'c>( /// Creates a `scf.while` operation. pub fn r#while<'c>( + context: &'c Context, initial_values: &[Value<'c, '_>], result_types: &[Type<'c>], before_region: Region<'c>, after_region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("scf.while", location) + OperationBuilder::new(context, "scf.while", location) .add_operands(initial_values) .add_results(result_types) .add_regions(vec![before_region, after_region]) @@ -94,8 +99,12 @@ pub fn r#while<'c>( } /// Creates a `scf.yield` operation. -pub fn r#yield<'c>(values: &[Value<'c, '_>], location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("scf.yield", location) +pub fn r#yield<'c>( + context: &'c Context, + values: &[Value<'c, '_>], + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "scf.yield", location) .add_operands(values) .build() } diff --git a/melior/src/execution_engine.rs b/melior/src/execution_engine.rs index 118d3b5051..6b8e43cb43 100644 --- a/melior/src/execution_engine.rs +++ b/melior/src/execution_engine.rs @@ -43,10 +43,15 @@ impl ExecutionEngine { /// This function modifies memory locations pointed by the `arguments` /// argument. If those pointers are invalid or misaligned, calling this /// function might result in undefined behavior. - pub unsafe fn invoke_packed(&self, name: &str, arguments: &mut [*mut ()]) -> Result<(), Error> { + pub unsafe fn invoke_packed( + &self, + context: &'c Context, + name: &str, + arguments: &mut [*mut ()], + ) -> Result<(), Error> { let result = LogicalResult::from_raw(mlirExecutionEngineInvokePacked( self.raw, - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), arguments.as_mut_ptr() as _, )); @@ -64,13 +69,22 @@ impl ExecutionEngine { /// This function makes a pointer accessible to the execution engine. If a /// given pointer is invalid or misaligned, calling this function might /// result in undefined behavior. - pub unsafe fn register_symbol(&self, name: &str, ptr: *mut ()) { - mlirExecutionEngineRegisterSymbol(self.raw, StringRef::from(name).to_raw(), ptr as _); + pub unsafe fn register_symbol(&self, context: &'c Context, name: &str, ptr: *mut ()) { + mlirExecutionEngineRegisterSymbol( + self.raw, + StringRef::from_str(context, name).to_raw(), + ptr as _, + ); } /// Dumps a module to an object file. - pub fn dump_to_object_file(&self, path: &str) { - unsafe { mlirExecutionEngineDumpToObjectFile(self.raw, StringRef::from(path).to_raw()) } + pub fn dump_to_object_file(&self, context: &'c Context, path: &str) { + unsafe { + mlirExecutionEngineDumpToObjectFile( + self.raw, + StringRef::from_str(context, path).to_raw(), + ) + } } } diff --git a/melior/src/ir/attribute.rs b/melior/src/ir/attribute.rs index aaa09f8168..2f5766d3b6 100644 --- a/melior/src/ir/attribute.rs +++ b/melior/src/ir/attribute.rs @@ -44,7 +44,7 @@ impl<'c> Attribute<'c> { unsafe { Self::from_option_raw(mlirAttributeParseGet( context.to_raw(), - StringRef::from(source).to_raw(), + StringRef::from_str(context, source).to_raw(), )) } } diff --git a/melior/src/ir/attribute/flat_symbol_ref.rs b/melior/src/ir/attribute/flat_symbol_ref.rs index ccb346d7cf..c98933de04 100644 --- a/melior/src/ir/attribute/flat_symbol_ref.rs +++ b/melior/src/ir/attribute/flat_symbol_ref.rs @@ -14,7 +14,7 @@ impl<'c> FlatSymbolRefAttribute<'c> { unsafe { Self::from_raw(mlirFlatSymbolRefAttrGet( context.to_raw(), - StringRef::from(symbol).to_raw(), + StringRef::from_str(context, symbol).to_raw(), )) } } diff --git a/melior/src/ir/attribute/string.rs b/melior/src/ir/attribute/string.rs index 4c2eec8591..22fa0c542d 100644 --- a/melior/src/ir/attribute/string.rs +++ b/melior/src/ir/attribute/string.rs @@ -14,7 +14,7 @@ impl<'c> StringAttribute<'c> { unsafe { Self::from_raw(mlirStringAttrGet( context.to_raw(), - StringRef::from(string).to_raw(), + StringRef::from_str(context, string).to_raw(), )) } } diff --git a/melior/src/ir/identifier.rs b/melior/src/ir/identifier.rs index 6dd118a8a2..ea84e458f8 100644 --- a/melior/src/ir/identifier.rs +++ b/melior/src/ir/identifier.rs @@ -17,11 +17,11 @@ pub struct Identifier<'c> { impl<'c> Identifier<'c> { /// Creates an identifier. - pub fn new(context: &Context, name: &str) -> Self { + pub fn new(context: &'c Context, name: &str) -> Self { unsafe { Self::from_raw(mlirIdentifierGet( context.to_raw(), - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), )) } } diff --git a/melior/src/ir/location.rs b/melior/src/ir/location.rs index bb9e584fee..3179842f2c 100644 --- a/melior/src/ir/location.rs +++ b/melior/src/ir/location.rs @@ -27,7 +27,7 @@ impl<'c> Location<'c> { unsafe { Self::from_raw(mlirLocationFileLineColGet( context.to_raw(), - StringRef::from(filename).to_raw(), + StringRef::from_str(context, filename).to_raw(), line as u32, column as u32, )) @@ -51,7 +51,7 @@ impl<'c> Location<'c> { unsafe { Self::from_raw(mlirLocationNameGet( context.to_raw(), - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), child.to_raw(), )) } diff --git a/melior/src/ir/module.rs b/melior/src/ir/module.rs index cf0c1d60f7..8a21948a1b 100644 --- a/melior/src/ir/module.rs +++ b/melior/src/ir/module.rs @@ -28,7 +28,7 @@ impl<'c> Module<'c> { unsafe { Self::from_option_raw(mlirModuleCreateParse( context.to_raw(), - StringRef::from(source).to_raw(), + StringRef::from_str(context, source).to_raw(), )) } } diff --git a/melior/src/ir/operation.rs b/melior/src/ir/operation.rs index 3d37f6809e..0f01cd53e3 100644 --- a/melior/src/ir/operation.rs +++ b/melior/src/ir/operation.rs @@ -184,37 +184,42 @@ impl<'c> Operation<'c> { } /// Gets a attribute with the given name. - pub fn attribute(&self, name: &str) -> Result, Error> { + pub fn attribute(&self, context: &'c Context, name: &str) -> Result, Error> { unsafe { Attribute::from_option_raw(mlirOperationGetAttributeByName( self.raw, - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), )) } .ok_or(Error::AttributeNotFound(name.into())) } /// Checks if the operation has a attribute with the given name. - pub fn has_attribute(&self, name: &str) -> bool { - self.attribute(name).is_ok() + pub fn has_attribute(&self, context: &'c Context, name: &str) -> bool { + self.attribute(context, name).is_ok() } /// Sets the attribute with the given name to the given attribute. - pub fn set_attribute(&mut self, name: &str, attribute: &Attribute<'c>) { + pub fn set_attribute(&mut self, context: &'c Context, name: &str, attribute: &Attribute<'c>) { unsafe { mlirOperationSetAttributeByName( self.raw, - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), attribute.to_raw(), ) } } /// Removes the attribute with the given name. - pub fn remove_attribute(&mut self, name: &str) -> Result<(), Error> { - unsafe { mlirOperationRemoveAttributeByName(self.raw, StringRef::from(name).to_raw()) } - .then_some(()) - .ok_or(Error::AttributeNotFound(name.into())) + pub fn remove_attribute(&mut self, context: &'c Context, name: &str) -> Result<(), Error> { + unsafe { + mlirOperationRemoveAttributeByName( + self.raw, + StringRef::from_str(context, name).to_raw(), + ) + } + .then_some(()) + .ok_or(Error::AttributeNotFound(name.into())) } /// Gets the next operation in the same block. diff --git a/melior/src/ir/operation/builder.rs b/melior/src/ir/operation/builder.rs index 954673e8c9..dadcb0e94e 100644 --- a/melior/src/ir/operation/builder.rs +++ b/melior/src/ir/operation/builder.rs @@ -20,10 +20,13 @@ pub struct OperationBuilder<'c> { impl<'c> OperationBuilder<'c> { /// Creates an operation builder. - pub fn new(name: &str, location: Location<'c>) -> Self { + pub fn new(context: &'c Context, name: &str, location: Location<'c>) -> Self { Self { raw: unsafe { - mlirOperationStateGet(StringRef::from(name).to_raw(), location.to_raw()) + mlirOperationStateGet( + StringRef::from_str(context, name).to_raw(), + location.to_raw(), + ) }, _context: Default::default(), } diff --git a/melior/src/ir/type.rs b/melior/src/ir/type.rs index 0a3e882136..479cd0881a 100644 --- a/melior/src/ir/type.rs +++ b/melior/src/ir/type.rs @@ -43,7 +43,7 @@ impl<'c> Type<'c> { unsafe { Self::from_option_raw(mlirTypeParseGet( context.to_raw(), - StringRef::from(source).to_raw(), + StringRef::from_str(context, source).to_raw(), )) } } diff --git a/melior/src/pass/external.rs b/melior/src/pass/external.rs index 0fa3df2eac..0eb2b4f914 100644 --- a/melior/src/pass/external.rs +++ b/melior/src/pass/external.rs @@ -4,7 +4,7 @@ use super::Pass; use crate::{ dialect::DialectHandle, ir::{r#type::TypeId, OperationRef}, - ContextRef, StringRef, + Context, ContextRef, StringRef, }; use mlir_sys::{ mlirCreateExternalPass, mlirExternalPassSignalFailure, MlirContext, MlirExternalPass, @@ -162,6 +162,7 @@ impl<'c, F: FnMut(OperationRef<'c, '_>, ExternalPass<'_>) + Clone> RunExternalPa /// ); /// ``` pub fn create_external<'c, T: RunExternalPass<'c>>( + context: &'c Context, pass: T, pass_id: TypeId, name: &str, @@ -173,10 +174,10 @@ pub fn create_external<'c, T: RunExternalPass<'c>>( unsafe { Pass::from_raw(mlirCreateExternalPass( pass_id.to_raw(), - StringRef::from(name).to_raw(), - StringRef::from(argument).to_raw(), - StringRef::from(description).to_raw(), - StringRef::from(op_name).to_raw(), + StringRef::from_str(context, name).to_raw(), + StringRef::from_str(context, argument).to_raw(), + StringRef::from_str(context, description).to_raw(), + StringRef::from_str(context, op_name).to_raw(), dependent_dialects.len() as isize, dependent_dialects.as_ptr() as _, MlirExternalPassCallbacks { diff --git a/melior/src/utility.rs b/melior/src/utility.rs index cb632b6334..a93ecc8bc4 100644 --- a/melior/src/utility.rs +++ b/melior/src/utility.rs @@ -33,13 +33,17 @@ pub fn register_all_passes() { } /// Parses a pass pipeline. -pub fn parse_pass_pipeline(manager: pass::OperationPassManager, source: &str) -> Result<(), Error> { +pub fn parse_pass_pipeline( + context: &Context, + manager: pass::OperationPassManager, + source: &str, +) -> Result<(), Error> { let mut error_message = None; let result = LogicalResult::from_raw(unsafe { mlirParsePassPipeline( manager.to_raw(), - StringRef::from(source).to_raw(), + StringRef::from_str(context, source).to_raw(), Some(handle_parse_error), &mut error_message as *mut _ as *mut _, ) From 2dd04d9644b1ba0c0d996bad8284f1f239a56707 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 16:55:08 +1100 Subject: [PATCH 04/22] Fix --- melior/src/dialect/cf.rs | 25 +++++++++++++------------ melior/src/dialect/func.rs | 17 +++++++++++------ 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/melior/src/dialect/cf.rs b/melior/src/dialect/cf.rs index ce553edaed..64505d28b2 100644 --- a/melior/src/dialect/cf.rs +++ b/melior/src/dialect/cf.rs @@ -19,7 +19,7 @@ pub fn assert<'c>( message: &str, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("cf.assert", location) + OperationBuilder::new(context, "cf.assert", location) .add_attributes(&[( Identifier::new(context, "msg"), StringAttribute::new(context, message).into(), @@ -30,11 +30,12 @@ pub fn assert<'c>( /// Creates a `cf.br` operation. pub fn br<'c>( + context: &'c Context, successor: &Block<'c>, destination_operands: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("cf.br", location) + OperationBuilder::new(context, "cf.br", location) .add_operands(destination_operands) .add_successors(&[successor]) .build() @@ -50,7 +51,7 @@ pub fn cond_br<'c>( false_successor_operands: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("cf.cond_br", location) + OperationBuilder::new(context, "cf.cond_br", location) .add_attributes(&[( Identifier::new(context, "operand_segment_sizes"), DenseI32ArrayAttribute::new( @@ -89,7 +90,7 @@ pub fn switch<'c>( .chain(case_destinations.iter().copied()) .unzip(); - Ok(OperationBuilder::new("cf.switch", location) + Ok(OperationBuilder::new(context, "cf.switch", location) .add_attributes(&[ ( Identifier::new(context, "case_values"), @@ -183,7 +184,7 @@ mod tests { block.append_operation(assert(&context, operand, "assert message", location)); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -222,9 +223,9 @@ mod tests { .result(0) .unwrap(); - block.append_operation(br(&dest_block, &[operand.into()], location)); + block.append_operation(br(&context, &dest_block, &[operand.into()], location)); - dest_block.append_operation(func::r#return(&[], location)); + dest_block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -289,8 +290,8 @@ mod tests { location, )); - true_block.append_operation(func::r#return(&[], location)); - false_block.append_operation(func::r#return(&[], location)); + true_block.append_operation(func::r#return(&context, &[], location)); + false_block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -348,9 +349,9 @@ mod tests { .unwrap(), ); - default_block.append_operation(func::r#return(&[], location)); - first_block.append_operation(func::r#return(&[], location)); - second_block.append_operation(func::r#return(&[], location)); + default_block.append_operation(func::r#return(&context, &[], location)); + first_block.append_operation(func::r#return(&context, &[], location)); + second_block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); diff --git a/melior/src/dialect/func.rs b/melior/src/dialect/func.rs index deb5256512..42a9a8336e 100644 --- a/melior/src/dialect/func.rs +++ b/melior/src/dialect/func.rs @@ -18,7 +18,7 @@ pub fn call<'c>( result_types: &[Type<'c>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("func.call", location) + OperationBuilder::new(context, "func.call", location) .add_attributes(&[(Identifier::new(context, "callee"), function.into())]) .add_operands(arguments) .add_results(result_types) @@ -27,12 +27,13 @@ pub fn call<'c>( /// Create a `func.call_indirect` operation. pub fn call_indirect<'c>( + context: &'c Context, function: Value<'c, '_>, arguments: &[Value<'c, '_>], result_types: &[Type<'c>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("func.call_indirect", location) + OperationBuilder::new(context, "func.call_indirect", location) .add_operands(&[function]) .add_operands(arguments) .add_results(result_types) @@ -46,7 +47,7 @@ pub fn constant<'c>( r#type: FunctionType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("func.constant", location) + OperationBuilder::new(context, "func.constant", location) .add_attributes(&[(Identifier::new(context, "value"), function.into())]) .add_results(&[r#type.into()]) .build() @@ -61,7 +62,7 @@ pub fn func<'c>( attributes: &[(Identifier<'c>, Attribute<'c>)], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("func.func", location) + OperationBuilder::new(context, "func.func", location) .add_attributes(&[ (Identifier::new(context, "sym_name"), name.into()), (Identifier::new(context, "function_type"), r#type.into()), @@ -72,8 +73,12 @@ pub fn func<'c>( } /// Create a `func.return` operation. -pub fn r#return<'c>(operands: &[Value<'c, '_>], location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("func.return", location) +pub fn r#return<'c>( + context: &'c Context, + operands: &[Value<'c, '_>], + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "func.return", location) .add_operands(operands) .build() } From e4b329ba227617a18d92dfcc5771a5c49134ba9c Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 18:33:47 +1100 Subject: [PATCH 05/22] Fix --- melior/src/context.rs | 10 ++++------ melior/src/dialect/arith.rs | 34 +++++++++++++++++++++++++++++----- melior/src/dialect/llvm.rs | 3 ++- melior/src/execution_engine.rs | 6 +++--- 4 files changed, 38 insertions(+), 15 deletions(-) diff --git a/melior/src/context.rs b/melior/src/context.rs index ff8a89a5c2..42ccb06348 100644 --- a/melior/src/context.rs +++ b/melior/src/context.rs @@ -56,7 +56,7 @@ impl Context { unsafe { Dialect::from_raw(mlirContextGetOrLoadDialect( self.raw, - StringRef::from(name).to_raw(), + StringRef::from_str(&self, name).to_raw(), )) } } @@ -130,12 +130,10 @@ impl Context { } pub(crate) fn create_c_string(&self, string: &str) -> &CString { - let entry = self - .string_cache + self.string_cache .entry(CString::new(string).unwrap()) - .or_insert_with(Default::default); - - entry.key() + .or_insert_with(Default::default) + .key() } } diff --git a/melior/src/dialect/arith.rs b/melior/src/dialect/arith.rs index 38e06a57ec..f3de7ab665 100644 --- a/melior/src/dialect/arith.rs +++ b/melior/src/dialect/arith.rs @@ -16,7 +16,7 @@ pub fn constant<'c>( value: Attribute<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("arith.constant", location) + OperationBuilder::new(context, "arith.constant", location) .add_attributes(&[(Identifier::new(context, "value"), value)]) .enable_result_type_inference() .build() @@ -86,7 +86,7 @@ fn cmp<'c>( rhs: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(name, location) + OperationBuilder::new(context, name, location) .add_attributes(&[( Identifier::new(context, "predicate"), IntegerAttribute::new(predicate, IntegerType::new(context, 64).into()).into(), @@ -98,12 +98,13 @@ fn cmp<'c>( /// Creates an `arith.select` operation. pub fn select<'c>( + context: &'c Context, condition: Value<'c, '_>, true_value: Value<'c, '_>, false_value: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("arith.select", location) + OperationBuilder::new(context, "arith.select", location) .add_operands(&[condition, true_value, false_value]) .add_results(&[true_value.r#type()]) .build() @@ -205,6 +206,7 @@ mod tests { let name = name.as_string_ref().as_str().unwrap(); block.append_operation(func::r#return( + context, &[block.append_operation(operation).result(0).unwrap().into()], location, )); @@ -255,6 +257,7 @@ mod tests { &context, |block| { negf( + &context, block.argument(0).unwrap().into(), Location::unknown(&context), ) @@ -331,6 +334,7 @@ mod tests { &context, |block| { bitcast( + &context, block.argument(0).unwrap().into(), float_type, Location::unknown(&context), @@ -349,6 +353,7 @@ mod tests { &context, |block| { extf( + &context, block.argument(0).unwrap().into(), Type::float64(&context), Location::unknown(&context), @@ -371,6 +376,7 @@ mod tests { &context, |block| { extsi( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -393,6 +399,7 @@ mod tests { &context, |block| { extui( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -415,6 +422,7 @@ mod tests { &context, |block| { fptosi( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -437,6 +445,7 @@ mod tests { &context, |block| { fptoui( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -459,6 +468,7 @@ mod tests { &context, |block| { index_cast( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -481,6 +491,7 @@ mod tests { &context, |block| { index_castui( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -503,6 +514,7 @@ mod tests { &context, |block| { sitofp( + &context, block.argument(0).unwrap().into(), Type::float64(&context), Location::unknown(&context), @@ -525,6 +537,7 @@ mod tests { &context, |block| { trunci( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 32).into(), Location::unknown(&context), @@ -547,6 +560,7 @@ mod tests { &context, |block| { uitofp( + &context, block.argument(0).unwrap().into(), Type::float64(&context), Location::unknown(&context), @@ -576,12 +590,17 @@ mod tests { let block = Block::new(&[(integer_type, location), (integer_type, location)]); let sum = block.append_operation(addi( + &context, block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )); - block.append_operation(func::r#return(&[sum.result(0).unwrap().into()], location)); + block.append_operation(func::r#return( + &context, + &[sum.result(0).unwrap().into()], + location, + )); let region = Region::new(); region.append_block(block); @@ -624,13 +643,18 @@ mod tests { ]); let val = block.append_operation(select( + &context, block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), block.argument(2).unwrap().into(), location, )); - block.append_operation(func::r#return(&[val.result(0).unwrap().into()], location)); + block.append_operation(func::r#return( + &context, + &[val.result(0).unwrap().into()], + location, + )); let region = Region::new(); region.append_block(block); diff --git a/melior/src/dialect/llvm.rs b/melior/src/dialect/llvm.rs index ab25bf6059..4ee4c94e16 100644 --- a/melior/src/dialect/llvm.rs +++ b/melior/src/dialect/llvm.rs @@ -137,7 +137,7 @@ pub fn nullptr<'c>( } /// Creates a `llvm.unreachable` operation. -pub fn unreachable(context: &'c Context, location: Location) -> Operation { +pub fn unreachable<'c>(context: &'c Context, location: Location<'c>) -> Operation<'c> { OperationBuilder::new(context, "llvm.unreachable", location).build() } @@ -287,6 +287,7 @@ pub fn intr_cttz<'c>( /// Creates a `llvm.intr.ctlz` operation. pub fn intr_ctpop<'c>( + context: &'c Context, value: Value<'c, '_>, result_type: Type<'c>, location: Location<'c>, diff --git a/melior/src/execution_engine.rs b/melior/src/execution_engine.rs index 6b8e43cb43..c200a3cafd 100644 --- a/melior/src/execution_engine.rs +++ b/melior/src/execution_engine.rs @@ -45,7 +45,7 @@ impl ExecutionEngine { /// function might result in undefined behavior. pub unsafe fn invoke_packed( &self, - context: &'c Context, + context: &Context, name: &str, arguments: &mut [*mut ()], ) -> Result<(), Error> { @@ -69,7 +69,7 @@ impl ExecutionEngine { /// This function makes a pointer accessible to the execution engine. If a /// given pointer is invalid or misaligned, calling this function might /// result in undefined behavior. - pub unsafe fn register_symbol(&self, context: &'c Context, name: &str, ptr: *mut ()) { + pub unsafe fn register_symbol(&self, context: &Context, name: &str, ptr: *mut ()) { mlirExecutionEngineRegisterSymbol( self.raw, StringRef::from_str(context, name).to_raw(), @@ -78,7 +78,7 @@ impl ExecutionEngine { } /// Dumps a module to an object file. - pub fn dump_to_object_file(&self, context: &'c Context, path: &str) { + pub fn dump_to_object_file(&self, context: &Context, path: &str) { unsafe { mlirExecutionEngineDumpToObjectFile( self.raw, From 34f24aaa899c8a6f21263d6c2c32087806fe4b72 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 18:52:36 +1100 Subject: [PATCH 06/22] Fix --- melior/src/context.rs | 7 ++----- melior/src/string_ref.rs | 8 +++++++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/melior/src/context.rs b/melior/src/context.rs index 42ccb06348..cd700f2f5c 100644 --- a/melior/src/context.rs +++ b/melior/src/context.rs @@ -129,11 +129,8 @@ impl Context { unsafe { mlirContextDetachDiagnosticHandler(self.to_raw(), id.to_raw()) } } - pub(crate) fn create_c_string(&self, string: &str) -> &CString { - self.string_cache - .entry(CString::new(string).unwrap()) - .or_insert_with(Default::default) - .key() + pub(crate) fn string_cache(&self) -> &DashMap { + &self.string_cache } } diff --git a/melior/src/string_ref.rs b/melior/src/string_ref.rs index 89a37bdb95..553bda07f8 100644 --- a/melior/src/string_ref.rs +++ b/melior/src/string_ref.rs @@ -1,6 +1,7 @@ use crate::Context; use mlir_sys::{mlirStringRefCreateFromCString, mlirStringRefEqual, MlirStringRef}; use std::{ + ffi::CString, marker::PhantomData, slice, str::{self, Utf8Error}, @@ -19,7 +20,12 @@ pub struct StringRef<'c> { impl<'c> StringRef<'c> { pub fn from_str(context: &'c Context, string: &str) -> Self { - let string = context.create_c_string(string).as_ptr(); + let string = context + .string_cache() + .entry(CString::new(string).unwrap()) + .or_insert_with(Default::default) + .key() + .as_ptr(); unsafe { Self::from_raw(mlirStringRefCreateFromCString(string)) } } From ee1873b565ff3a949ec7a1e9d577b72a96a20557 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 18:55:37 +1100 Subject: [PATCH 07/22] Fix --- melior/src/lib.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/melior/src/lib.rs b/melior/src/lib.rs index eda303821f..1fb86a35b6 100644 --- a/melior/src/lib.rs +++ b/melior/src/lib.rs @@ -74,12 +74,17 @@ mod tests { let block = Block::new(&[(integer_type, location), (integer_type, location)]); let sum = block.append_operation(arith::addi( + &context, block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )); - block.append_operation(func::r#return(&[sum.result(0).unwrap().into()], location)); + block.append_operation(func::r#return( + &context, + &[sum.result(0).unwrap().into()], + location, + )); let region = Region::new(); region.append_block(block); @@ -124,7 +129,7 @@ mod tests { )); let dim = function_block.append_operation( - OperationBuilder::new("memref.dim", location) + OperationBuilder::new(&context, "memref.dim", location) .add_operands(&[ function_block.argument(0).unwrap().into(), zero.result(0).unwrap().into(), @@ -145,7 +150,7 @@ mod tests { let f32_type = Type::float32(&context); let lhs = loop_block.append_operation( - OperationBuilder::new("memref.load", location) + OperationBuilder::new(&context, "memref.load", location) .add_operands(&[ function_block.argument(0).unwrap().into(), loop_block.argument(0).unwrap().into(), @@ -155,7 +160,7 @@ mod tests { ); let rhs = loop_block.append_operation( - OperationBuilder::new("memref.load", location) + OperationBuilder::new(&context, "memref.load", location) .add_operands(&[ function_block.argument(1).unwrap().into(), loop_block.argument(0).unwrap().into(), @@ -180,10 +185,11 @@ mod tests { .build(), ); - loop_block.append_operation(scf::r#yield(&[], location)); + loop_block.append_operation(scf::r#yield(&context, &[], location)); } function_block.append_operation(scf::r#for( + &context, zero.result(0).unwrap().into(), dim.result(0).unwrap().into(), one.result(0).unwrap().into(), @@ -195,7 +201,7 @@ mod tests { location, )); - function_block.append_operation(func::r#return(&[], location)); + function_block.append_operation(func::r#return(&context, &[], location)); let function_region = Region::new(); function_region.append_block(function_block); @@ -251,6 +257,7 @@ mod tests { let block = Block::new(&[(integer_type, location), (integer_type, location)]); block.append_operation(func::r#return( + &context, &[compile_add( &context, &block, From cb0ef5337149e50e309c0c52d647899ee3dab87a Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 18:59:16 +1100 Subject: [PATCH 08/22] Fix --- macro/src/dialect/operation/builder.rs | 2 +- macro/src/operation.rs | 6 +++--- melior/src/dialect/func.rs | 6 +++--- melior/src/dialect/index.rs | 4 ++-- melior/src/dialect/llvm.rs | 22 +++++++++++----------- melior/src/dialect/memref.rs | 12 ++++++------ melior/src/dialect/scf.rs | 10 +++++----- melior/src/ir/block.rs | 16 ++++++++-------- melior/src/ir/module.rs | 4 ++-- melior/src/ir/operation.rs | 26 +++++++++++++------------- melior/src/ir/operation/builder.rs | 14 +++++++------- melior/src/ir/value.rs | 12 ++++++------ melior/src/lib.rs | 5 +++-- melior/src/pass/external.rs | 1 + melior/src/pass/manager.rs | 2 ++ 15 files changed, 73 insertions(+), 69 deletions(-) diff --git a/macro/src/dialect/operation/builder.rs b/macro/src/dialect/operation/builder.rs index 5e4209c4ea..6417a0086d 100644 --- a/macro/src/dialect/operation/builder.rs +++ b/macro/src/dialect/operation/builder.rs @@ -183,7 +183,7 @@ impl<'o> OperationBuilder<'o> { pub fn new(location: ::melior::ir::Location<'c>) -> Self { Self { context: location.context(), - builder: ::melior::ir::operation::OperationBuilder::new(#name, location), + builder: ::melior::ir::operation::OperationBuilder::new(&context, #name, location), #(#phantoms),* } } diff --git a/macro/src/operation.rs b/macro/src/operation.rs index 283c7eff37..b494adca21 100644 --- a/macro/src/operation.rs +++ b/macro/src/operation.rs @@ -31,7 +31,7 @@ pub fn generate_binary(dialect: &Ident, names: &[Ident]) -> Result, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - crate::ir::operation::OperationBuilder::new(context, name, location) + crate::ir::operation::OperationBuilder::new(&context, context, name, location) .add_operands(&[lhs, rhs]) .enable_result_type_inference() .build() @@ -67,7 +67,7 @@ pub fn generate_unary(dialect: &Ident, names: &[Ident]) -> Result, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - crate::ir::operation::OperationBuilder::new(context, name, location) + crate::ir::operation::OperationBuilder::new(&context, context, name, location) .add_operands(&[value]) .enable_result_type_inference() .build() @@ -108,7 +108,7 @@ pub fn generate_typed_unary( r#type: crate::ir::Type<'c>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - crate::ir::operation::OperationBuilder::new(context, name, location) + crate::ir::operation::OperationBuilder::new(&context, context, name, location) .add_operands(&[value]) .add_results(&[r#type]) .build() diff --git a/melior/src/dialect/func.rs b/melior/src/dialect/func.rs index 42a9a8336e..0def68f608 100644 --- a/melior/src/dialect/func.rs +++ b/melior/src/dialect/func.rs @@ -47,7 +47,7 @@ pub fn constant<'c>( r#type: FunctionType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "func.constant", location) + OperationBuilder::new(&context, context, "func.constant", location) .add_attributes(&[(Identifier::new(context, "value"), function.into())]) .add_results(&[r#type.into()]) .build() @@ -62,7 +62,7 @@ pub fn func<'c>( attributes: &[(Identifier<'c>, Attribute<'c>)], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "func.func", location) + OperationBuilder::new(&context, context, "func.func", location) .add_attributes(&[ (Identifier::new(context, "sym_name"), name.into()), (Identifier::new(context, "function_type"), r#type.into()), @@ -78,7 +78,7 @@ pub fn r#return<'c>( operands: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "func.return", location) + OperationBuilder::new(&context, context, "func.return", location) .add_operands(operands) .build() } diff --git a/melior/src/dialect/index.rs b/melior/src/dialect/index.rs index 49b66271ab..07806a2686 100644 --- a/melior/src/dialect/index.rs +++ b/melior/src/dialect/index.rs @@ -17,7 +17,7 @@ pub fn constant<'c>( value: IntegerAttribute<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "index.constant", location) + OperationBuilder::new(&context, context, "index.constant", location) .add_attributes(&[(Identifier::new(context, "value"), value.into())]) .enable_result_type_inference() .build() @@ -31,7 +31,7 @@ pub fn cmp<'c>( rhs: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "index.cmp", location) + OperationBuilder::new(&context, context, "index.cmp", location) .add_attributes(&[( Identifier::new(context, "pred"), Attribute::parse( diff --git a/melior/src/dialect/llvm.rs b/melior/src/dialect/llvm.rs index 4ee4c94e16..cb55b60a25 100644 --- a/melior/src/dialect/llvm.rs +++ b/melior/src/dialect/llvm.rs @@ -30,7 +30,7 @@ pub fn extract_value<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "llvm.extractvalue", location) + OperationBuilder::new(&context, context, "llvm.extractvalue", location) .add_attributes(&[(Identifier::new(context, "position"), position.into())]) .add_operands(&[container]) .add_results(&[result_type]) @@ -46,7 +46,7 @@ pub fn get_element_ptr<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "llvm.getelementptr", location) + OperationBuilder::new(&context, context, "llvm.getelementptr", location) .add_attributes(&[ ( Identifier::new(context, "rawConstantIndices"), @@ -71,7 +71,7 @@ pub fn get_element_ptr_dynamic<'c, const N: usize>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "llvm.getelementptr", location) + OperationBuilder::new(&context, context, "llvm.getelementptr", location) .add_attributes(&[ ( Identifier::new(context, "rawConstantIndices"), @@ -96,7 +96,7 @@ pub fn insert_value<'c>( value: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "llvm.insertvalue", location) + OperationBuilder::new(&context, context, "llvm.insertvalue", location) .add_attributes(&[(Identifier::new(context, "position"), position.into())]) .add_operands(&[container, value]) .enable_result_type_inference() @@ -109,7 +109,7 @@ pub fn undef<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "llvm.mlir.undef", location) + OperationBuilder::new(&context, context, "llvm.mlir.undef", location) .add_results(&[result_type]) .build() } @@ -120,7 +120,7 @@ pub fn poison<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "llvm.mlir.poison", location) + OperationBuilder::new(&context, context, "llvm.mlir.poison", location) .add_results(&[result_type]) .build() } @@ -131,14 +131,14 @@ pub fn nullptr<'c>( ptr_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "llvm.mlir.null", location) + OperationBuilder::new(&context, context, "llvm.mlir.null", location) .add_results(&[ptr_type]) .build() } /// Creates a `llvm.unreachable` operation. pub fn unreachable<'c>(context: &'c Context, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new(context, "llvm.unreachable", location).build() + OperationBuilder::new(&context, context, "llvm.unreachable", location).build() } /// Creates a `llvm.bitcast` operation. @@ -148,7 +148,7 @@ pub fn bitcast<'c>( result: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "llvm.bitcast", location) + OperationBuilder::new(&context, context, "llvm.bitcast", location) .add_operands(&[argument]) .add_results(&[result]) .build() @@ -162,7 +162,7 @@ pub fn alloca<'c>( location: Location<'c>, extra_options: AllocaOptions<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "llvm.alloca", location) + OperationBuilder::new(&context, context, "llvm.alloca", location) .add_operands(&[array_size]) .add_attributes(&extra_options.into_attributes(context)) .add_results(&[ptr_type]) @@ -177,7 +177,7 @@ pub fn store<'c>( location: Location<'c>, extra_options: LoadStoreOptions<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "llvm.store", location) + OperationBuilder::new(&context, context, "llvm.store", location) .add_operands(&[value, addr]) .add_attributes(&extra_options.into_attributes(context)) .build() diff --git a/melior/src/dialect/memref.rs b/melior/src/dialect/memref.rs index 7c1f83b380..9b49076691 100644 --- a/melior/src/dialect/memref.rs +++ b/melior/src/dialect/memref.rs @@ -62,7 +62,7 @@ fn allocate<'c>( alignment: Option>, location: Location<'c>, ) -> Operation<'c> { - let mut builder = OperationBuilder::new(context, name, location); + let mut builder = OperationBuilder::new(&context, context, name, location); builder = builder.add_attributes(&[( Identifier::new(context, "operand_segment_sizes"), @@ -86,7 +86,7 @@ pub fn cast<'c>( r#type: MemRefType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "memref.cast", location) + OperationBuilder::new(&context, context, "memref.cast", location) .add_operands(&[value]) .add_results(&[r#type.into()]) .build() @@ -123,7 +123,7 @@ pub fn get_global<'c>( r#type: MemRefType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "memref.get_global", location) + OperationBuilder::new(&context, context, "memref.get_global", location) .add_attributes(&[( Identifier::new(context, "name"), FlatSymbolRefAttribute::new(context, name).into(), @@ -188,7 +188,7 @@ pub fn load<'c>( indices: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "memref.load", location) + OperationBuilder::new(&context, context, "memref.load", location) .add_operands(&[memref]) .add_operands(indices) .enable_result_type_inference() @@ -201,7 +201,7 @@ pub fn rank<'c>( value: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "memref.rank", location) + OperationBuilder::new(&context, context, "memref.rank", location) .add_operands(&[value]) .enable_result_type_inference() .build() @@ -215,7 +215,7 @@ pub fn store<'c>( indices: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "memref.store", location) + OperationBuilder::new(&context, context, "memref.store", location) .add_operands(&[value, memref]) .add_operands(indices) .build() diff --git a/melior/src/dialect/scf.rs b/melior/src/dialect/scf.rs index dcb7a2a049..21eaa4647a 100644 --- a/melior/src/dialect/scf.rs +++ b/melior/src/dialect/scf.rs @@ -15,7 +15,7 @@ pub fn condition<'c>( values: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "scf.condition", location) + OperationBuilder::new(&context, context, "scf.condition", location) .add_operands(&[condition]) .add_operands(values) .build() @@ -28,7 +28,7 @@ pub fn execute_region<'c>( region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "scf.execute_region", location) + OperationBuilder::new(&context, context, "scf.execute_region", location) .add_results(result_types) .add_regions(vec![region]) .build() @@ -58,7 +58,7 @@ pub fn r#if<'c>( else_region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "scf.if", location) + OperationBuilder::new(&context, context, "scf.if", location) .add_operands(&[condition]) .add_results(result_types) .add_regions(vec![then_region, else_region]) @@ -74,7 +74,7 @@ pub fn index_switch<'c>( regions: Vec>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "scf.index_switch", location) + OperationBuilder::new(&context, context, "scf.index_switch", location) .add_operands(&[condition]) .add_results(result_types) .add_attributes(&[(Identifier::new(context, "cases"), cases.into())]) @@ -104,7 +104,7 @@ pub fn r#yield<'c>( values: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(context, "scf.yield", location) + OperationBuilder::new(&context, context, "scf.yield", location) .add_operands(values) .build() } diff --git a/melior/src/ir/block.rs b/melior/src/ir/block.rs index 0d1879a6c7..ffa88e1fa2 100644 --- a/melior/src/ir/block.rs +++ b/melior/src/ir/block.rs @@ -404,7 +404,7 @@ mod tests { let block = Block::new(&[]); let operation = block.append_operation( - OperationBuilder::new("func.return", Location::unknown(&context)).build(), + OperationBuilder::new(&context, "func.return", Location::unknown(&context)).build(), ); assert_eq!(block.terminator(), Some(operation)); @@ -422,7 +422,7 @@ mod tests { let block = Block::new(&[]); let operation = block - .append_operation(OperationBuilder::new("foo", Location::unknown(&context)).build()); + .append_operation(OperationBuilder::new(&context, "foo", Location::unknown(&context)).build()); assert_eq!(block.first_operation(), Some(operation)); } @@ -440,7 +440,7 @@ mod tests { context.set_allow_unregistered_dialects(true); let block = Block::new(&[]); - block.append_operation(OperationBuilder::new("foo", Location::unknown(&context)).build()); + block.append_operation(OperationBuilder::new(&context, "foo", Location::unknown(&context)).build()); } #[test] @@ -451,7 +451,7 @@ mod tests { block.insert_operation( 0, - OperationBuilder::new("foo", Location::unknown(&context)).build(), + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), ); } @@ -462,10 +462,10 @@ mod tests { let block = Block::new(&[]); let first_operation = block - .append_operation(OperationBuilder::new("foo", Location::unknown(&context)).build()); + .append_operation(OperationBuilder::new(&context, "foo", Location::unknown(&context)).build()); let second_operation = block.insert_operation_after( first_operation, - OperationBuilder::new("foo", Location::unknown(&context)).build(), + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), ); assert_eq!(block.first_operation(), Some(first_operation)); @@ -482,10 +482,10 @@ mod tests { let block = Block::new(&[]); let second_operation = block - .append_operation(OperationBuilder::new("foo", Location::unknown(&context)).build()); + .append_operation(OperationBuilder::new(&context, "foo", Location::unknown(&context)).build()); let first_operation = block.insert_operation_before( second_operation, - OperationBuilder::new("foo", Location::unknown(&context)).build(), + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), ); assert_eq!(block.first_operation(), Some(first_operation)); diff --git a/melior/src/ir/module.rs b/melior/src/ir/module.rs index 8a21948a1b..87053a02a4 100644 --- a/melior/src/ir/module.rs +++ b/melior/src/ir/module.rs @@ -126,7 +126,7 @@ mod tests { region.append_block(Block::new(&[])); let module = Module::from_operation( - OperationBuilder::new("builtin.module", Location::unknown(&context)) + OperationBuilder::new(&context, "builtin.module", Location::unknown(&context)) .add_regions(vec![region]) .build(), ) @@ -141,7 +141,7 @@ mod tests { let context = create_test_context(); assert!(Module::from_operation( - OperationBuilder::new("func.func", Location::unknown(&context),).build() + OperationBuilder::new(&context, "func.func", Location::unknown(&context),).build() ) .is_none()); } diff --git a/melior/src/ir/operation.rs b/melior/src/ir/operation.rs index 0f01cd53e3..bb71208325 100644 --- a/melior/src/ir/operation.rs +++ b/melior/src/ir/operation.rs @@ -431,7 +431,7 @@ mod tests { fn new() { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - OperationBuilder::new("foo", Location::unknown(&context)).build(); + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(); } #[test] @@ -440,7 +440,7 @@ mod tests { context.set_allow_unregistered_dialects(true); assert_eq!( - OperationBuilder::new("foo", Location::unknown(&context),) + OperationBuilder::new(&context, "foo", Location::unknown(&context),) .build() .name(), Identifier::new(&context, "foo") @@ -453,7 +453,7 @@ mod tests { context.set_allow_unregistered_dialects(true); let block = Block::new(&[]); let operation = block - .append_operation(OperationBuilder::new("foo", Location::unknown(&context)).build()); + .append_operation(OperationBuilder::new(&context, "foo", Location::unknown(&context)).build()); assert_eq!(operation.block().as_deref(), Some(&block)); } @@ -463,7 +463,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); assert_eq!( - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .build() .block(), None @@ -475,7 +475,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); assert_eq!( - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .build() .result(0) .unwrap_err(), @@ -492,7 +492,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); assert_eq!( - OperationBuilder::new("foo", Location::unknown(&context),) + OperationBuilder::new(&context, "foo", Location::unknown(&context),) .build() .region(0), Err(Error::PositionOutOfBounds { @@ -514,7 +514,7 @@ mod tests { let argument: Value = block.argument(0).unwrap().into(); let operands = vec![argument, argument, argument]; - let operation = OperationBuilder::new("foo", Location::unknown(&context)) + let operation = OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_operands(&operands) .build(); @@ -529,7 +529,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - let operation = OperationBuilder::new("foo", Location::unknown(&context)) + let operation = OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_regions(vec![Region::new()]) .build(); @@ -544,7 +544,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - let mut operation = OperationBuilder::new("foo", Location::unknown(&context)) + let mut operation = OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_attributes(&[( Identifier::new(&context, "foo"), StringAttribute::new(&context, "bar").into(), @@ -575,7 +575,7 @@ mod tests { fn clone() { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - let operation = OperationBuilder::new("foo", Location::unknown(&context)).build(); + let operation = OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(); let _ = operation.clone(); } @@ -586,7 +586,7 @@ mod tests { context.set_allow_unregistered_dialects(true); assert_eq!( - OperationBuilder::new("foo", Location::unknown(&context),) + OperationBuilder::new(&context, "foo", Location::unknown(&context),) .build() .to_string(), "\"foo\"() : () -> ()\n" @@ -601,7 +601,7 @@ mod tests { assert_eq!( format!( "{:?}", - OperationBuilder::new("foo", Location::unknown(&context)).build() + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build() ), "Operation(\n\"foo\"() : () -> ()\n)" ); @@ -613,7 +613,7 @@ mod tests { context.set_allow_unregistered_dialects(true); assert_eq!( - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .build() .to_string_with_flags( OperationPrintingFlags::new() diff --git a/melior/src/ir/operation/builder.rs b/melior/src/ir/operation/builder.rs index dadcb0e94e..d45f6bd6ae 100644 --- a/melior/src/ir/operation/builder.rs +++ b/melior/src/ir/operation/builder.rs @@ -137,7 +137,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - OperationBuilder::new("foo", Location::unknown(&context)).build(); + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(); } #[test] @@ -150,7 +150,7 @@ mod tests { let block = Block::new(&[(r#type, location)]); let argument = block.argument(0).unwrap().into(); - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_operands(&[argument]) .build(); } @@ -160,7 +160,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_results(&[Type::parse(&context, "i1").unwrap()]) .build(); } @@ -170,7 +170,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_regions(vec![Region::new()]) .build(); } @@ -180,7 +180,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_successors(&[&Block::new(&[])]) .build(); } @@ -190,7 +190,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_attributes(&[( Identifier::new(&context, "foo"), Attribute::parse(&context, "unit").unwrap(), @@ -209,7 +209,7 @@ mod tests { let argument = block.argument(0).unwrap().into(); assert_eq!( - OperationBuilder::new("arith.addi", location) + OperationBuilder::new(&context, "arith.addi", location) .add_operands(&[argument, argument]) .enable_result_type_inference() .build() diff --git a/melior/src/ir/value.rs b/melior/src/ir/value.rs index b97451fde3..dcc6293a1e 100644 --- a/melior/src/ir/value.rs +++ b/melior/src/ir/value.rs @@ -133,7 +133,7 @@ mod tests { let location = Location::unknown(&context); let index_type = Type::index(&context); - let value = OperationBuilder::new("arith.constant", location) + let value = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), @@ -150,7 +150,7 @@ mod tests { let location = Location::unknown(&context); let index_type = Type::index(&context); - let operation = OperationBuilder::new("arith.constant", location) + let operation = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), @@ -169,7 +169,7 @@ mod tests { let index_type = Type::index(&context); let operation = || { - OperationBuilder::new("arith.constant", location) + OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), @@ -192,7 +192,7 @@ mod tests { let location = Location::unknown(&context); let index_type = Type::index(&context); - let operation = OperationBuilder::new("arith.constant", location) + let operation = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), @@ -213,7 +213,7 @@ mod tests { let location = Location::unknown(&context); let index_type = Type::index(&context); - let operation = OperationBuilder::new("arith.constant", location) + let operation = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), @@ -234,7 +234,7 @@ mod tests { let location = Location::unknown(&context); let index_type = Type::index(&context); - let operation = OperationBuilder::new("arith.constant", location) + let operation = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), diff --git a/melior/src/lib.rs b/melior/src/lib.rs index 1fb86a35b6..5291f3b22f 100644 --- a/melior/src/lib.rs +++ b/melior/src/lib.rs @@ -170,13 +170,14 @@ mod tests { ); let add = loop_block.append_operation(arith::addf( + &context, lhs.result(0).unwrap().into(), rhs.result(0).unwrap().into(), location, )); loop_block.append_operation( - OperationBuilder::new("memref.store", location) + OperationBuilder::new(&context, "memref.store", location) .add_operands(&[ add.result(0).unwrap().into(), function_block.argument(0).unwrap().into(), @@ -241,7 +242,7 @@ mod tests { rhs: Value<'c, '_>, ) -> Value<'c, 'a> { block - .append_operation(arith::addi(lhs, rhs, Location::unknown(context))) + .append_operation(arith::addi(&context, lhs, rhs, Location::unknown(context))) .result(0) .unwrap() .into() diff --git a/melior/src/pass/external.rs b/melior/src/pass/external.rs index 0eb2b4f914..db0928f47b 100644 --- a/melior/src/pass/external.rs +++ b/melior/src/pass/external.rs @@ -307,6 +307,7 @@ mod tests { let pass_manager = PassManager::new(&context); pass_manager.add_pass(create_external( + &context, |operation: OperationRef, pass: ExternalPass<'_>| { assert!(operation.verify()); assert!( diff --git a/melior/src/pass/manager.rs b/melior/src/pass/manager.rs index 57f8d91084..4f99588852 100644 --- a/melior/src/pass/manager.rs +++ b/melior/src/pass/manager.rs @@ -216,6 +216,7 @@ mod tests { let manager = PassManager::new(&context); insta::assert_display_snapshot!(parse_pass_pipeline( + &context, manager.as_operation_pass_manager(), "builtin.module(func.func(print-op-stats{json=false}),\ func.func(print-op-stats{json=false}))" @@ -226,6 +227,7 @@ mod tests { assert_eq!( parse_pass_pipeline( + &context, manager.as_operation_pass_manager(), "builtin.module(func.func(print-op-stats{json=false}),\ func.func(print-op-stats{json=false}))" From 4190f4b896de3fbe82e6442fe6ffe8fa7f462145 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 19:03:22 +1100 Subject: [PATCH 09/22] Fix --- melior/src/execution_engine.rs | 3 ++- melior/src/ir/operation.rs | 21 +++++++++++++-------- melior/src/ir/operation/result.rs | 2 +- melior/src/ir/value.rs | 4 ++-- melior/src/pass/external.rs | 3 ++- 5 files changed, 20 insertions(+), 13 deletions(-) diff --git a/melior/src/execution_engine.rs b/melior/src/execution_engine.rs index c200a3cafd..a84bf97ac9 100644 --- a/melior/src/execution_engine.rs +++ b/melior/src/execution_engine.rs @@ -133,6 +133,7 @@ mod tests { assert_eq!( unsafe { engine.invoke_packed( + &context, "add", &mut [ &mut argument as *mut i32 as *mut (), @@ -174,6 +175,6 @@ mod tests { assert_eq!(pass_manager.run(&mut module), Ok(())); ExecutionEngine::new(&context, &module, 2, &[], true) - .dump_to_object_file("/tmp/melior/test.o"); + .dump_to_object_file(&context, "/tmp/melior/test.o"); } } diff --git a/melior/src/ir/operation.rs b/melior/src/ir/operation.rs index bb71208325..d93d1450ee 100644 --- a/melior/src/ir/operation.rs +++ b/melior/src/ir/operation.rs @@ -452,8 +452,9 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); let block = Block::new(&[]); - let operation = block - .append_operation(OperationBuilder::new(&context, "foo", Location::unknown(&context)).build()); + let operation = block.append_operation( + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), + ); assert_eq!(operation.block().as_deref(), Some(&block)); } @@ -550,16 +551,20 @@ mod tests { StringAttribute::new(&context, "bar").into(), )]) .build(); - assert!(operation.has_attribute("foo")); + assert!(operation.has_attribute(&context, "foo")); assert_eq!( - operation.attribute("foo").map(|a| a.to_string()), + operation.attribute(&context, "foo").map(|a| a.to_string()), Ok("\"bar\"".into()) ); - assert!(operation.remove_attribute("foo").is_ok()); - assert!(operation.remove_attribute("foo").is_err()); - operation.set_attribute("foo", &StringAttribute::new(&context, "foo").into()); + assert!(operation.remove_attribute(&context, "foo").is_ok()); + assert!(operation.remove_attribute(&context, "foo").is_err()); + operation.set_attribute( + &context, + "foo", + &StringAttribute::new(&context, "foo").into(), + ); assert_eq!( - operation.attribute("foo").map(|a| a.to_string()), + operation.attribute(&context, "foo").map(|a| a.to_string()), Ok("\"foo\"".into()) ); assert_eq!( diff --git a/melior/src/ir/operation/result.rs b/melior/src/ir/operation/result.rs index f0067f23a3..cdcf4db63a 100644 --- a/melior/src/ir/operation/result.rs +++ b/melior/src/ir/operation/result.rs @@ -71,7 +71,7 @@ mod tests { context.set_allow_unregistered_dialects(true); let r#type = Type::parse(&context, "index").unwrap(); - let operation = OperationBuilder::new("foo", Location::unknown(&context)) + let operation = OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_results(&[r#type]) .build(); diff --git a/melior/src/ir/value.rs b/melior/src/ir/value.rs index dcc6293a1e..073e987747 100644 --- a/melior/src/ir/value.rs +++ b/melior/src/ir/value.rs @@ -90,7 +90,7 @@ mod tests { let location = Location::unknown(&context); let index_type = Type::index(&context); - let operation = OperationBuilder::new("arith.constant", location) + let operation = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), @@ -107,7 +107,7 @@ mod tests { let location = Location::unknown(&context); let r#type = Type::index(&context); - let operation = OperationBuilder::new("arith.constant", location) + let operation = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[r#type]) .add_attributes(&[( Identifier::new(&context, "value"), diff --git a/melior/src/pass/external.rs b/melior/src/pass/external.rs index db0928f47b..609b36ff93 100644 --- a/melior/src/pass/external.rs +++ b/melior/src/pass/external.rs @@ -220,7 +220,7 @@ mod tests { TypeAttribute::new(FunctionType::new(context, &[], &[]).into()), { let block = Block::new(&[]); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -276,6 +276,7 @@ mod tests { impl TestPass { fn create(self) -> Pass { create_external( + &context, self, TypeId::create(&TEST_PASS), "test pass", From 329773d8e6dc8bf03974f6b380c87481dd3f5f77 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 19:04:06 +1100 Subject: [PATCH 10/22] Fix --- melior/src/dialect/ods.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/melior/src/dialect/ods.rs b/melior/src/dialect/ods.rs index db19b1e685..c16a94a4b4 100644 --- a/melior/src/dialect/ods.rs +++ b/melior/src/dialect/ods.rs @@ -199,7 +199,7 @@ mod tests { .into(), ); - block.append_operation(func::r#return(&[], location)); + block.append_operation(&context, func::r#return(&[], location)); }); } @@ -224,7 +224,7 @@ mod tests { .into(), ); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); }); } } From a6a6c7749d3bc29d1fc68d3d37cfbe9bd6d55b25 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 19:05:28 +1100 Subject: [PATCH 11/22] Fix --- melior/src/dialect/ods.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/melior/src/dialect/ods.rs b/melior/src/dialect/ods.rs index c16a94a4b4..b88321db84 100644 --- a/melior/src/dialect/ods.rs +++ b/melior/src/dialect/ods.rs @@ -199,7 +199,7 @@ mod tests { .into(), ); - block.append_operation(&context, func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); }); } From 32150cb98d0f3ee3d0040fc533566e4e6f333862 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 19:10:53 +1100 Subject: [PATCH 12/22] Fix --- macro/src/dialect/operation/builder.rs | 6 +++--- melior/src/dialect/ods.rs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/macro/src/dialect/operation/builder.rs b/macro/src/dialect/operation/builder.rs index 6417a0086d..fb8064c50c 100644 --- a/macro/src/dialect/operation/builder.rs +++ b/macro/src/dialect/operation/builder.rs @@ -142,7 +142,7 @@ impl<'o> OperationBuilder<'o> { #[doc = #doc] pub struct #builder_ident <'c, #(#iter_arguments),* > { builder: ::melior::ir::operation::OperationBuilder<'c>, - context: ::melior::ContextRef<'c>, + context: &'c ::melior::Context, #(#phantom_fields),* } @@ -180,9 +180,9 @@ impl<'o> OperationBuilder<'o> { quote! { impl<'c> #builder_ident<'c, #(#arguments),*> { - pub fn new(location: ::melior::ir::Location<'c>) -> Self { + pub fn new(context: &'c ::melior::Context, location: ::melior::ir::Location<'c>) -> Self { Self { - context: location.context(), + context, builder: ::melior::ir::operation::OperationBuilder::new(&context, #name, location), #(#phantoms),* } diff --git a/melior/src/dialect/ods.rs b/melior/src/dialect/ods.rs index b88321db84..ec69e28043 100644 --- a/melior/src/dialect/ods.rs +++ b/melior/src/dialect/ods.rs @@ -215,7 +215,7 @@ mod tests { let i64_type = IntegerType::new(&context, 64); block.append_operation( - llvm::AllocaOpBuilder::new(location) + llvm::AllocaOpBuilder::new(&context, location) .alignment(IntegerAttribute::new(8, i64_type.into())) .elem_type(TypeAttribute::new(i64_type.into())) .array_size(alloca_size) From f83ccf76f348209ae119b9e1e424b6fafa723eeb Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 19:57:25 +1100 Subject: [PATCH 13/22] Fix attribute accessor --- macro/src/dialect/operation/accessors.rs | 18 +++++++++--------- macro/src/dialect/operation/builder.rs | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/macro/src/dialect/operation/accessors.rs b/macro/src/dialect/operation/accessors.rs index 857ec6f38e..314fb9c3fe 100644 --- a/macro/src/dialect/operation/accessors.rs +++ b/macro/src/dialect/operation/accessors.rs @@ -154,11 +154,11 @@ impl<'a> OperationField<'a> { let name = &self.name; Some(if constraint.is_unit()? { - quote! { self.operation.attribute(#name).is_some() } + quote! { self.operation.attribute(context, #name).is_some() } } else { quote! { self.operation - .attribute(#name)? + .attribute(context, #name)? .try_into() .map_err(::melior::Error::from) } @@ -174,7 +174,7 @@ impl<'a> OperationField<'a> { if constraint.is_unit()? || constraint.is_optional()? { Some(quote! { - self.operation.remove_attribute(#name) + self.operation.remove_attribute(context, #name) }) } else { None @@ -193,14 +193,14 @@ impl<'a> OperationField<'a> { Ok(Some(if constraint.is_unit()? { quote! { if value { - self.operation.set_attribute(#name, Attribute::unit(&self.operation.context())); + self.operation.set_attribute(context, #name, Attribute::unit(&self.operation.context())); } else { - self.operation.remove_attribute(#name) + self.operation.remove_attribute(context, #name) } } } else { quote! { - self.operation.set_attribute(#name, &value.into()); + self.operation.set_attribute(context, #name, &value.into()); } })) } @@ -213,7 +213,7 @@ impl<'a> OperationField<'a> { let parameter_type = &self.kind.parameter_type()?; quote! { - pub fn #ident(&mut self, value: #parameter_type) { + pub fn #ident(&mut self, context: &'c ::melior::Context, value: #parameter_type) { #body } } @@ -225,7 +225,7 @@ impl<'a> OperationField<'a> { let ident = sanitize_snake_case_name(&format!("remove_{}", self.name))?; self.remover_impl()?.map(|body| { quote! { - pub fn #ident(&mut self) -> Result<(), ::melior::Error> { + pub fn #ident(&mut self, context: &'c ::melior::Context) -> Result<(), ::melior::Error> { #body } } @@ -236,7 +236,7 @@ impl<'a> OperationField<'a> { let return_type = &self.kind.return_type()?; self.getter_impl()?.map(|body| { quote! { - pub fn #ident(&self) -> #return_type { + pub fn #ident(&self, context: &'c ::melior::Context) -> #return_type { #body } } diff --git a/macro/src/dialect/operation/builder.rs b/macro/src/dialect/operation/builder.rs index fb8064c50c..7a38d8372a 100644 --- a/macro/src/dialect/operation/builder.rs +++ b/macro/src/dialect/operation/builder.rs @@ -51,7 +51,7 @@ impl<'o> OperationBuilder<'o> { quote! { &[( - ::melior::ir::Identifier::new(unsafe { self.context.to_ref() }, #name_string), + ::melior::ir::Identifier::new(self.context, #name_string), #name.into(), )] } From 81a5ccafca41eec157b7bf4c3243d2fe0b3bb9ff Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 20:06:19 +1100 Subject: [PATCH 14/22] Fix --- macro/src/dialect/operation/builder.rs | 7 ++++--- melior/src/dialect/func.rs | 6 +++--- melior/src/dialect/index.rs | 4 ++-- melior/src/dialect/llvm.rs | 22 +++++++++++----------- melior/src/dialect/memref.rs | 12 ++++++------ melior/src/dialect/ods.rs | 1 + melior/src/dialect/scf.rs | 20 +++++++++++++------- 7 files changed, 40 insertions(+), 32 deletions(-) diff --git a/macro/src/dialect/operation/builder.rs b/macro/src/dialect/operation/builder.rs index 7a38d8372a..8cf30b7f94 100644 --- a/macro/src/dialect/operation/builder.rs +++ b/macro/src/dialect/operation/builder.rs @@ -196,9 +196,10 @@ impl<'o> OperationBuilder<'o> { let arguments = self.type_state.arguments_all_set(false); quote! { pub fn builder( + context: &'c ::melior::Context, location: ::melior::ir::Location<'c> ) -> #builder_ident<'c, #(#arguments),*> { - #builder_ident::new(location) + #builder_ident::new(context, location) } } } @@ -229,8 +230,8 @@ impl<'o> OperationBuilder<'o> { Ok(quote! { #[allow(clippy::too_many_arguments)] #[doc = #doc] - pub fn #name<'c>(#(#arguments),*) -> #class_name<'c> { - #class_name::builder(location)#(#builder_calls)*.build() + pub fn #name<'c>(context: &'c ::melior::Context, #(#arguments),*) -> #class_name<'c> { + #class_name::builder(context, location)#(#builder_calls)*.build() } }) } diff --git a/melior/src/dialect/func.rs b/melior/src/dialect/func.rs index 0def68f608..9367ee80b4 100644 --- a/melior/src/dialect/func.rs +++ b/melior/src/dialect/func.rs @@ -47,7 +47,7 @@ pub fn constant<'c>( r#type: FunctionType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "func.constant", location) + OperationBuilder::new(&context, "func.constant", location) .add_attributes(&[(Identifier::new(context, "value"), function.into())]) .add_results(&[r#type.into()]) .build() @@ -62,7 +62,7 @@ pub fn func<'c>( attributes: &[(Identifier<'c>, Attribute<'c>)], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "func.func", location) + OperationBuilder::new(&context, "func.func", location) .add_attributes(&[ (Identifier::new(context, "sym_name"), name.into()), (Identifier::new(context, "function_type"), r#type.into()), @@ -78,7 +78,7 @@ pub fn r#return<'c>( operands: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "func.return", location) + OperationBuilder::new(&context, "func.return", location) .add_operands(operands) .build() } diff --git a/melior/src/dialect/index.rs b/melior/src/dialect/index.rs index 07806a2686..a39c8fa1ce 100644 --- a/melior/src/dialect/index.rs +++ b/melior/src/dialect/index.rs @@ -17,7 +17,7 @@ pub fn constant<'c>( value: IntegerAttribute<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "index.constant", location) + OperationBuilder::new(&context, "index.constant", location) .add_attributes(&[(Identifier::new(context, "value"), value.into())]) .enable_result_type_inference() .build() @@ -31,7 +31,7 @@ pub fn cmp<'c>( rhs: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "index.cmp", location) + OperationBuilder::new(&context, "index.cmp", location) .add_attributes(&[( Identifier::new(context, "pred"), Attribute::parse( diff --git a/melior/src/dialect/llvm.rs b/melior/src/dialect/llvm.rs index cb55b60a25..2c8faa0490 100644 --- a/melior/src/dialect/llvm.rs +++ b/melior/src/dialect/llvm.rs @@ -30,7 +30,7 @@ pub fn extract_value<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "llvm.extractvalue", location) + OperationBuilder::new(&context, "llvm.extractvalue", location) .add_attributes(&[(Identifier::new(context, "position"), position.into())]) .add_operands(&[container]) .add_results(&[result_type]) @@ -46,7 +46,7 @@ pub fn get_element_ptr<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "llvm.getelementptr", location) + OperationBuilder::new(&context, "llvm.getelementptr", location) .add_attributes(&[ ( Identifier::new(context, "rawConstantIndices"), @@ -71,7 +71,7 @@ pub fn get_element_ptr_dynamic<'c, const N: usize>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "llvm.getelementptr", location) + OperationBuilder::new(&context, "llvm.getelementptr", location) .add_attributes(&[ ( Identifier::new(context, "rawConstantIndices"), @@ -96,7 +96,7 @@ pub fn insert_value<'c>( value: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "llvm.insertvalue", location) + OperationBuilder::new(&context, "llvm.insertvalue", location) .add_attributes(&[(Identifier::new(context, "position"), position.into())]) .add_operands(&[container, value]) .enable_result_type_inference() @@ -109,7 +109,7 @@ pub fn undef<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "llvm.mlir.undef", location) + OperationBuilder::new(&context, "llvm.mlir.undef", location) .add_results(&[result_type]) .build() } @@ -120,7 +120,7 @@ pub fn poison<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "llvm.mlir.poison", location) + OperationBuilder::new(&context, "llvm.mlir.poison", location) .add_results(&[result_type]) .build() } @@ -131,14 +131,14 @@ pub fn nullptr<'c>( ptr_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "llvm.mlir.null", location) + OperationBuilder::new(&context, "llvm.mlir.null", location) .add_results(&[ptr_type]) .build() } /// Creates a `llvm.unreachable` operation. pub fn unreachable<'c>(context: &'c Context, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new(&context, context, "llvm.unreachable", location).build() + OperationBuilder::new(&context, "llvm.unreachable", location).build() } /// Creates a `llvm.bitcast` operation. @@ -148,7 +148,7 @@ pub fn bitcast<'c>( result: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "llvm.bitcast", location) + OperationBuilder::new(&context, "llvm.bitcast", location) .add_operands(&[argument]) .add_results(&[result]) .build() @@ -162,7 +162,7 @@ pub fn alloca<'c>( location: Location<'c>, extra_options: AllocaOptions<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "llvm.alloca", location) + OperationBuilder::new(&context, "llvm.alloca", location) .add_operands(&[array_size]) .add_attributes(&extra_options.into_attributes(context)) .add_results(&[ptr_type]) @@ -177,7 +177,7 @@ pub fn store<'c>( location: Location<'c>, extra_options: LoadStoreOptions<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "llvm.store", location) + OperationBuilder::new(&context, "llvm.store", location) .add_operands(&[value, addr]) .add_attributes(&extra_options.into_attributes(context)) .build() diff --git a/melior/src/dialect/memref.rs b/melior/src/dialect/memref.rs index 9b49076691..9dbeff077d 100644 --- a/melior/src/dialect/memref.rs +++ b/melior/src/dialect/memref.rs @@ -62,7 +62,7 @@ fn allocate<'c>( alignment: Option>, location: Location<'c>, ) -> Operation<'c> { - let mut builder = OperationBuilder::new(&context, context, name, location); + let mut builder = OperationBuilder::new(&context, name, location); builder = builder.add_attributes(&[( Identifier::new(context, "operand_segment_sizes"), @@ -86,7 +86,7 @@ pub fn cast<'c>( r#type: MemRefType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "memref.cast", location) + OperationBuilder::new(&context, "memref.cast", location) .add_operands(&[value]) .add_results(&[r#type.into()]) .build() @@ -123,7 +123,7 @@ pub fn get_global<'c>( r#type: MemRefType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "memref.get_global", location) + OperationBuilder::new(&context, "memref.get_global", location) .add_attributes(&[( Identifier::new(context, "name"), FlatSymbolRefAttribute::new(context, name).into(), @@ -188,7 +188,7 @@ pub fn load<'c>( indices: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "memref.load", location) + OperationBuilder::new(&context, "memref.load", location) .add_operands(&[memref]) .add_operands(indices) .enable_result_type_inference() @@ -201,7 +201,7 @@ pub fn rank<'c>( value: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "memref.rank", location) + OperationBuilder::new(&context, "memref.rank", location) .add_operands(&[value]) .enable_result_type_inference() .build() @@ -215,7 +215,7 @@ pub fn store<'c>( indices: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "memref.store", location) + OperationBuilder::new(&context, "memref.store", location) .add_operands(&[value, memref]) .add_operands(indices) .build() diff --git a/melior/src/dialect/ods.rs b/melior/src/dialect/ods.rs index ec69e28043..1c764ce274 100644 --- a/melior/src/dialect/ods.rs +++ b/melior/src/dialect/ods.rs @@ -192,6 +192,7 @@ mod tests { block.append_operation( llvm::alloca( + &context, dialect::llvm::r#type::pointer(i64_type.into(), 0).into(), alloca_size, location, diff --git a/melior/src/dialect/scf.rs b/melior/src/dialect/scf.rs index 21eaa4647a..ae49fd7931 100644 --- a/melior/src/dialect/scf.rs +++ b/melior/src/dialect/scf.rs @@ -15,7 +15,7 @@ pub fn condition<'c>( values: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "scf.condition", location) + OperationBuilder::new(&context, "scf.condition", location) .add_operands(&[condition]) .add_operands(values) .build() @@ -28,7 +28,7 @@ pub fn execute_region<'c>( region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "scf.execute_region", location) + OperationBuilder::new(&context, "scf.execute_region", location) .add_results(result_types) .add_regions(vec![region]) .build() @@ -58,7 +58,7 @@ pub fn r#if<'c>( else_region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "scf.if", location) + OperationBuilder::new(&context, "scf.if", location) .add_operands(&[condition]) .add_results(result_types) .add_regions(vec![then_region, else_region]) @@ -74,7 +74,7 @@ pub fn index_switch<'c>( regions: Vec>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "scf.index_switch", location) + OperationBuilder::new(&context, "scf.index_switch", location) .add_operands(&[condition]) .add_results(result_types) .add_attributes(&[(Identifier::new(context, "cases"), cases.into())]) @@ -104,7 +104,7 @@ pub fn r#yield<'c>( values: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, context, "scf.yield", location) + OperationBuilder::new(&context, "scf.yield", location) .add_operands(values) .build() } @@ -559,6 +559,7 @@ mod tests { )); block.append_operation(r#while( + &context, &[initial.result(0).unwrap().into()], &[float_type], { @@ -578,6 +579,7 @@ mod tests { )); block.append_operation(super::condition( + &context, condition.result(0).unwrap().into(), &[result.result(0).unwrap().into()], location, @@ -597,6 +599,7 @@ mod tests { )); block.append_operation(r#yield( + &context, &[result.result(0).unwrap().into()], location, )); @@ -608,7 +611,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -645,6 +648,7 @@ mod tests { )); block.append_operation(r#while( + &context, &[ initial.result(0).unwrap().into(), initial.result(0).unwrap().into(), @@ -668,6 +672,7 @@ mod tests { )); block.append_operation(super::condition( + &context, condition.result(0).unwrap().into(), &[ result.result(0).unwrap().into(), @@ -691,6 +696,7 @@ mod tests { )); block.append_operation(r#yield( + &context, &[ result.result(0).unwrap().into(), result.result(0).unwrap().into(), @@ -705,7 +711,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); From 99204aecef5d2cf5865fd12a2bba9d247dfe9fbf Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 20:08:40 +1100 Subject: [PATCH 15/22] Fix --- melior/src/dialect/scf.rs | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/melior/src/dialect/scf.rs b/melior/src/dialect/scf.rs index ae49fd7931..8ea3e19db7 100644 --- a/melior/src/dialect/scf.rs +++ b/melior/src/dialect/scf.rs @@ -140,6 +140,7 @@ mod tests { let block = Block::new(&[]); block.append_operation(execute_region( + &context, &[index_type], { let block = Block::new(&[]); @@ -151,6 +152,7 @@ mod tests { )); block.append_operation(r#yield( + &context, &[value.result(0).unwrap().into()], location, )); @@ -162,7 +164,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -210,12 +212,13 @@ mod tests { )); block.append_operation(r#for( + &context, start.result(0).unwrap().into(), end.result(0).unwrap().into(), step.result(0).unwrap().into(), { let block = Block::new(&[(Type::index(&context), location)]); - block.append_operation(r#yield(&[], location)); + block.append_operation(r#yield(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -224,7 +227,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -264,6 +267,7 @@ mod tests { )); let result = block.append_operation(r#if( + &context, condition.result(0).unwrap().into(), &[index_type], { @@ -276,6 +280,7 @@ mod tests { )); block.append_operation(r#yield( + &context, &[result.result(0).unwrap().into()], location, )); @@ -294,6 +299,7 @@ mod tests { )); block.append_operation(r#yield( + &context, &[result.result(0).unwrap().into()], location, )); @@ -306,6 +312,7 @@ mod tests { )); block.append_operation(func::r#return( + &context, &[result.result(0).unwrap().into()], location, )); @@ -344,12 +351,13 @@ mod tests { )); block.append_operation(r#if( + &context, condition.result(0).unwrap().into(), &[], { let block = Block::new(&[]); - block.append_operation(r#yield(&[], location)); + block.append_operation(r#yield(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -359,7 +367,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -404,7 +412,7 @@ mod tests { { let block = Block::new(&[]); - block.append_operation(r#yield(&[], location)); + block.append_operation(r#yield(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -413,7 +421,7 @@ mod tests { { let block = Block::new(&[]); - block.append_operation(r#yield(&[], location)); + block.append_operation(r#yield(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -422,7 +430,7 @@ mod tests { { let block = Block::new(&[]); - block.append_operation(r#yield(&[], location)); + block.append_operation(r#yield(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -432,7 +440,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -472,6 +480,7 @@ mod tests { )); block.append_operation(r#while( + &context, &[initial.result(0).unwrap().into()], &[index_type], { @@ -491,6 +500,7 @@ mod tests { )); block.append_operation(super::condition( + &context, condition.result(0).unwrap().into(), &[result.result(0).unwrap().into()], location, @@ -510,6 +520,7 @@ mod tests { )); block.append_operation(r#yield( + &context, &[result.result(0).unwrap().into()], location, )); @@ -521,7 +532,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); From d7d173e8c7d6f8d523c44aec30a0255914861f10 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 20:12:35 +1100 Subject: [PATCH 16/22] Fix --- melior/src/dialect/index.rs | 7 ++++- melior/src/dialect/llvm.rs | 50 +++++++++++++++++++++--------------- melior/src/dialect/memref.rs | 24 +++++++++++++---- 3 files changed, 54 insertions(+), 27 deletions(-) diff --git a/melior/src/dialect/index.rs b/melior/src/dialect/index.rs index a39c8fa1ce..967cac3bfe 100644 --- a/melior/src/dialect/index.rs +++ b/melior/src/dialect/index.rs @@ -229,12 +229,17 @@ mod tests { let block = Block::new(&[(integer_type, location), (integer_type, location)]); let sum = block.append_operation(add( + &context, block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )); - block.append_operation(func::r#return(&[sum.result(0).unwrap().into()], location)); + block.append_operation(func::r#return( + &context, + &[sum.result(0).unwrap().into()], + location, + )); let region = Region::new(); region.append_block(block); diff --git a/melior/src/dialect/llvm.rs b/melior/src/dialect/llvm.rs index 2c8faa0490..198e13994a 100644 --- a/melior/src/dialect/llvm.rs +++ b/melior/src/dialect/llvm.rs @@ -421,7 +421,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -462,7 +462,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -513,7 +513,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -562,7 +562,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -594,9 +594,9 @@ mod tests { { let block = Block::new(&[(struct_type, location)]); - block.append_operation(undef(struct_type, location)); + block.append_operation(undef(&context, struct_type, location)); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -628,9 +628,9 @@ mod tests { { let block = Block::new(&[(struct_type, location)]); - block.append_operation(poison(struct_type, location)); + block.append_operation(poison(&context, struct_type, location)); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -670,7 +670,7 @@ mod tests { AllocaOptions::new().elem_type(Some(TypeAttribute::new(integer_type))), )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -710,7 +710,7 @@ mod tests { Default::default(), )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -750,7 +750,7 @@ mod tests { Default::default(), )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -793,7 +793,7 @@ mod tests { .nontemporal(true), )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -851,7 +851,7 @@ mod tests { { let block = Block::new(&[]); - block.append_operation(r#return(None, location)); + block.append_operation(r#return(&context, None, location)); let region = Region::new(); region.append_block(block); @@ -868,7 +868,11 @@ mod tests { { let block = Block::new(&[(struct_type, location)]); - block.append_operation(r#return(Some(block.argument(0).unwrap().into()), location)); + block.append_operation(r#return( + &context, + Some(block.argument(0).unwrap().into()), + location, + )); let region = Region::new(); region.append_block(block); @@ -913,7 +917,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); @@ -958,7 +962,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); @@ -993,6 +997,7 @@ mod tests { let res = block .append_operation(intr_ctpop( + &context, block.argument(0).unwrap().into(), integer_type, location, @@ -1001,7 +1006,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); @@ -1036,6 +1041,7 @@ mod tests { let res = block .append_operation(intr_bswap( + &context, block.argument(0).unwrap().into(), integer_type, location, @@ -1044,7 +1050,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); @@ -1079,6 +1085,7 @@ mod tests { let res = block .append_operation(intr_bitreverse( + &context, block.argument(0).unwrap().into(), integer_type, location, @@ -1087,7 +1094,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); @@ -1132,7 +1139,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); @@ -1168,6 +1175,7 @@ mod tests { let res = block .append_operation(zext( + &context, block.argument(0).unwrap().into(), integer_double_type, location, @@ -1176,7 +1184,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); diff --git a/melior/src/dialect/memref.rs b/melior/src/dialect/memref.rs index 9dbeff077d..17bd3b98ae 100644 --- a/melior/src/dialect/memref.rs +++ b/melior/src/dialect/memref.rs @@ -267,7 +267,7 @@ mod tests { let block = Block::new(&[]); build_block(&block); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -302,7 +302,11 @@ mod tests { None, location, )); - block.append_operation(dealloc(memref.result(0).unwrap().into(), location)); + block.append_operation(dealloc( + &context, + memref.result(0).unwrap().into(), + location, + )); }) } @@ -364,6 +368,7 @@ mod tests { )); block.append_operation(cast( + &context, memref.result(0).unwrap().into(), Type::parse(&context, "memref") .unwrap() @@ -396,6 +401,7 @@ mod tests { )); block.append_operation(dim( + &context, memref.result(0).unwrap().into(), index.result(0).unwrap().into(), location, @@ -429,7 +435,7 @@ mod tests { let block = Block::new(&[]); block.append_operation(get_global(&context, "foo", mem_ref_type, location)); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -510,7 +516,12 @@ mod tests { None, location, )); - block.append_operation(load(memref.result(0).unwrap().into(), &[], location)); + block.append_operation(load( + &context, + memref.result(0).unwrap().into(), + &[], + location, + )); }) } @@ -536,6 +547,7 @@ mod tests { )); block.append_operation(load( + &context, memref.result(0).unwrap().into(), &[index.result(0).unwrap().into()], location, @@ -557,7 +569,7 @@ mod tests { None, location, )); - block.append_operation(rank(memref.result(0).unwrap().into(), location)); + block.append_operation(rank(&context, memref.result(0).unwrap().into(), location)); }) } @@ -583,6 +595,7 @@ mod tests { )); block.append_operation(store( + &context, value.result(0).unwrap().into(), memref.result(0).unwrap().into(), &[], @@ -619,6 +632,7 @@ mod tests { )); block.append_operation(store( + &context, value.result(0).unwrap().into(), memref.result(0).unwrap().into(), &[index.result(0).unwrap().into()], From 740b06f702426837b830bbbbb1d1417cff3d82a0 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 20:14:49 +1100 Subject: [PATCH 17/22] Fix --- melior/src/dialect/func.rs | 11 ++++++++--- melior/src/dialect/index.rs | 3 +++ melior/src/pass/external.rs | 6 +++--- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/melior/src/dialect/func.rs b/melior/src/dialect/func.rs index 9367ee80b4..6019eba412 100644 --- a/melior/src/dialect/func.rs +++ b/melior/src/dialect/func.rs @@ -118,7 +118,7 @@ mod tests { .result(0) .unwrap() .into(); - block.append_operation(r#return(&[value], location)); + block.append_operation(r#return(&context, &[value], location)); let region = Region::new(); region.append_block(block); @@ -158,6 +158,7 @@ mod tests { )); let value = block .append_operation(call_indirect( + &context, function.result(0).unwrap().into(), &[block.argument(0).unwrap().into()], &[index_type], @@ -166,7 +167,7 @@ mod tests { .result(0) .unwrap() .into(); - block.append_operation(r#return(&[value], location)); + block.append_operation(r#return(&context, &[value], location)); let region = Region::new(); region.append_block(block); @@ -194,7 +195,11 @@ mod tests { let function = { let block = Block::new(&[(integer_type, location)]); - block.append_operation(r#return(&[block.argument(0).unwrap().into()], location)); + block.append_operation(r#return( + &context, + &[block.argument(0).unwrap().into()], + location, + )); let region = Region::new(); region.append_block(block); diff --git a/melior/src/dialect/index.rs b/melior/src/dialect/index.rs index 967cac3bfe..53689fbd01 100644 --- a/melior/src/dialect/index.rs +++ b/melior/src/dialect/index.rs @@ -105,6 +105,7 @@ mod tests { let name = name.as_string_ref().as_str().unwrap(); block.append_operation(func::r#return( + &context, &[block.append_operation(operation).result(0).unwrap().into()], location, )); @@ -180,6 +181,7 @@ mod tests { &context, |block| { casts( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -201,6 +203,7 @@ mod tests { &context, |block| { castu( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), diff --git a/melior/src/pass/external.rs b/melior/src/pass/external.rs index 609b36ff93..5940066b4e 100644 --- a/melior/src/pass/external.rs +++ b/melior/src/pass/external.rs @@ -274,9 +274,9 @@ mod tests { } impl TestPass { - fn create(self) -> Pass { + fn into_pass(self, context: &Context) -> Pass { create_external( - &context, + context, self, TypeId::create(&TEST_PASS), "test pass", @@ -294,7 +294,7 @@ mod tests { let pass_manager = PassManager::new(&context); let test_pass = TestPass { value: 10 }; - pass_manager.add_pass(test_pass.create()); + pass_manager.add_pass(test_pass.into_pass(&context)); pass_manager.run(&mut module).unwrap(); } From 262dd0cac030c5a53a29b2a338b2eda677c73ee9 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 20:16:20 +1100 Subject: [PATCH 18/22] Fix --- macro/src/operation.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/macro/src/operation.rs b/macro/src/operation.rs index b494adca21..f071c6b834 100644 --- a/macro/src/operation.rs +++ b/macro/src/operation.rs @@ -31,7 +31,7 @@ pub fn generate_binary(dialect: &Ident, names: &[Ident]) -> Result, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - crate::ir::operation::OperationBuilder::new(&context, context, name, location) + crate::ir::operation::OperationBuilder::new(&context, name, location) .add_operands(&[lhs, rhs]) .enable_result_type_inference() .build() @@ -67,7 +67,7 @@ pub fn generate_unary(dialect: &Ident, names: &[Ident]) -> Result, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - crate::ir::operation::OperationBuilder::new(&context, context, name, location) + crate::ir::operation::OperationBuilder::new(&context, name, location) .add_operands(&[value]) .enable_result_type_inference() .build() @@ -108,7 +108,7 @@ pub fn generate_typed_unary( r#type: crate::ir::Type<'c>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - crate::ir::operation::OperationBuilder::new(&context, context, name, location) + crate::ir::operation::OperationBuilder::new(&context, name, location) .add_operands(&[value]) .add_results(&[r#type]) .build() From bcdab4a3a7af5deb9aa85cd992b17cf5c72826d3 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 20:20:08 +1100 Subject: [PATCH 19/22] Fix --- macro/src/dialect/operation/accessors.rs | 2 +- macro/tests/operand.rs | 23 +++++++++++++++------ macro/tests/region.rs | 26 ++++++++++++++++-------- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/macro/src/dialect/operation/accessors.rs b/macro/src/dialect/operation/accessors.rs index 314fb9c3fe..e9337a6539 100644 --- a/macro/src/dialect/operation/accessors.rs +++ b/macro/src/dialect/operation/accessors.rs @@ -89,7 +89,7 @@ impl<'a> OperationField<'a> { let attribute = ::melior::ir::attribute::DenseI32ArrayAttribute::<'c>::try_from( self.operation - .attribute(#attribute_name)? + .attribute(context, #attribute_name)? )?; let start = (0..#index) .map(|index| attribute.element(index)) diff --git a/macro/tests/operand.rs b/macro/tests/operand.rs index 27991a61fc..2843cc442b 100644 --- a/macro/tests/operand.rs +++ b/macro/tests/operand.rs @@ -18,14 +18,15 @@ fn simple() { let r#type = Type::parse(&context, "i32").unwrap(); let block = Block::new(&[(r#type, location), (r#type, location)]); let op = operand_test::simple( + &context, r#type, block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, ); - assert_eq!(op.lhs().unwrap(), block.argument(0).unwrap().into()); - assert_eq!(op.rhs().unwrap(), block.argument(1).unwrap().into()); + assert_eq!(op.lhs(&context).unwrap(), block.argument(0).unwrap().into()); + assert_eq!(op.rhs(&context).unwrap(), block.argument(1).unwrap().into()); assert_eq!(op.operation().operand_count(), 2); } @@ -39,6 +40,7 @@ fn variadic_after_single() { let r#type = Type::parse(&context, "i32").unwrap(); let block = Block::new(&[(r#type, location), (r#type, location), (r#type, location)]); let op = operand_test::variadic( + &context, r#type, block.argument(0).unwrap().into(), &[ @@ -48,9 +50,18 @@ fn variadic_after_single() { location, ); - assert_eq!(op.first().unwrap(), block.argument(0).unwrap().into()); - assert_eq!(op.others().next(), Some(block.argument(2).unwrap().into())); - assert_eq!(op.others().nth(1), Some(block.argument(1).unwrap().into())); + assert_eq!( + op.first(&context).unwrap(), + block.argument(0).unwrap().into() + ); + assert_eq!( + op.others(&context).next(), + Some(block.argument(2).unwrap().into()) + ); + assert_eq!( + op.others(&context).nth(1), + Some(block.argument(1).unwrap().into()) + ); assert_eq!(op.operation().operand_count(), 3); - assert_eq!(op.others().count(), 2); + assert_eq!(op.others(&context).count(), 2); } diff --git a/macro/tests/region.rs b/macro/tests/region.rs index a2bfd9b090..5a5fffcce2 100644 --- a/macro/tests/region.rs +++ b/macro/tests/region.rs @@ -19,10 +19,10 @@ fn single() { let block = Block::new(&[]); let r1 = Region::new(); r1.append_block(block); - region_test::single(r1, location) + region_test::single(&context, r1, location) }; - assert!(op.default_region().unwrap().first_block().is_some()); + assert!(op.default_region(&context).unwrap().first_block().is_some()); } #[test] @@ -36,14 +36,14 @@ fn variadic_after_single() { let block = Block::new(&[]); let (r1, r2, r3) = (Region::new(), Region::new(), Region::new()); r2.append_block(block); - region_test::variadic(r1, vec![r2, r3], location) + region_test::variadic(&context, r1, vec![r2, r3], location) }; let op2 = { let block = Block::new(&[]); let (r1, r2, r3) = (Region::new(), Region::new(), Region::new()); r2.append_block(block); - region_test::VariadicOp::builder(location) + region_test::VariadicOp::builder(&context, location) .default_region(r1) .other_regions(vec![r2, r3]) .build() @@ -51,8 +51,18 @@ fn variadic_after_single() { assert_eq!(op.operation().to_string(), op2.operation().to_string()); - assert!(op.default_region().unwrap().first_block().is_none()); - assert_eq!(op.other_regions().count(), 2); - assert!(op.other_regions().next().unwrap().first_block().is_some()); - assert!(op.other_regions().nth(1).unwrap().first_block().is_none()); + assert!(op.default_region(&context).unwrap().first_block().is_none()); + assert_eq!(op.other_regions(&context).count(), 2); + assert!(op + .other_regions(&context) + .next() + .unwrap() + .first_block() + .is_some()); + assert!(op + .other_regions(&context) + .nth(1) + .unwrap() + .first_block() + .is_none()); } From 585e02b8eeb6347850b3574f868b598040d5ef3d Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 20:41:16 +1100 Subject: [PATCH 20/22] Remove to_ref method from context --- README.md | 3 ++- melior/src/context.rs | 24 ------------------------ melior/src/pass/external.rs | 20 ++++++++++++++------ 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index abf3f4fe98..b68f66fe07 100644 --- a/README.md +++ b/README.md @@ -42,12 +42,13 @@ module.body().append_operation(func::func( let block = Block::new(&[(index_type, location), (index_type, location)]); let sum = block.append_operation(arith::addi( + &context, block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location )); - block.append_operation(func::r#return(&[sum.result(0).unwrap().into()], location)); + block.append_operation(func::r#return(&context, &[sum.result(0).unwrap().into()], location)); let region = Region::new(); region.append_block(block); diff --git a/melior/src/context.rs b/melior/src/context.rs index cd700f2f5c..049e611a0a 100644 --- a/melior/src/context.rs +++ b/melior/src/context.rs @@ -162,22 +162,6 @@ pub struct ContextRef<'c> { } impl<'c> ContextRef<'c> { - /// Gets a context. - /// - /// This function is different from `deref` because the correct lifetime is - /// kept for the return type. - /// - /// # Safety - /// - /// The returned reference is safe to use only in the lifetime scope of the - /// context reference. - pub unsafe fn to_ref(&self) -> &'c Context { - // As we can't deref ContextRef<'a> into `&'a Context`, we forcibly cast its - // lifetime here to extend it from the lifetime of `ObjectRef<'a>` itself into - // `'a`. - transmute(self) - } - /// Creates a context reference from a raw object. /// /// # Safety @@ -191,14 +175,6 @@ impl<'c> ContextRef<'c> { } } -impl<'a> Deref for ContextRef<'a> { - type Target = Context; - - fn deref(&self) -> &Self::Target { - unsafe { transmute(self) } - } -} - impl<'a> PartialEq for ContextRef<'a> { fn eq(&self, other: &Self) -> bool { unsafe { mlirContextEqual(self.raw, other.raw) } diff --git a/melior/src/pass/external.rs b/melior/src/pass/external.rs index 5940066b4e..9ae020f90f 100644 --- a/melior/src/pass/external.rs +++ b/melior/src/pass/external.rs @@ -140,6 +140,7 @@ impl<'c, F: FnMut(OperationRef<'c, '_>, ExternalPass<'_>) + Clone> RunExternalPa /// /// ``` /// use melior::{ +/// Context, /// ir::{r#type::TypeId, OperationRef}, /// pass::{create_external, ExternalPass}, /// }; @@ -149,7 +150,10 @@ impl<'c, F: FnMut(OperationRef<'c, '_>, ExternalPass<'_>) + Clone> RunExternalPa /// /// static EXAMPLE_PASS: PassId = PassId; /// +/// let context = Context::new(); +/// /// create_external( +/// &context, /// |operation: OperationRef, _pass: ExternalPass| { /// operation.dump(); /// }, @@ -237,11 +241,12 @@ mod tests { static TEST_PASS: PassId = PassId; #[derive(Clone, Debug)] - struct TestPass { + struct TestPass<'c> { + context: &'c Context, value: i32, } - impl<'c> RunExternalPass<'c> for TestPass { + impl<'c> RunExternalPass<'c> for TestPass<'c> { fn construct(&mut self) { assert_eq!(self.value, 10); } @@ -268,12 +273,12 @@ mod tests { .first_operation() .expect("body has a function") .name() - == Identifier::new(&operation.context(), "func.func") + == Identifier::new(self.context, "func.func") ); } } - impl TestPass { + impl<'c> TestPass<'c> { fn into_pass(self, context: &Context) -> Pass { create_external( context, @@ -293,7 +298,10 @@ mod tests { let mut module = create_module(&context); let pass_manager = PassManager::new(&context); - let test_pass = TestPass { value: 10 }; + let test_pass = TestPass { + context: &context, + value: 10, + }; pass_manager.add_pass(test_pass.into_pass(&context)); pass_manager.run(&mut module).unwrap(); } @@ -320,7 +328,7 @@ mod tests { .first_operation() .expect("body has a function") .name() - == Identifier::new(&operation.context(), "func.func") + == Identifier::new(&context, "func.func") ); pass.signal_failure(); }, From 9c402b6a017d9aa3576e6245e10e29aea7edd9da Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 20:43:15 +1100 Subject: [PATCH 21/22] Fix linting --- melior/benches/main.rs | 9 ++++++--- melior/src/context.rs | 4 +--- melior/src/dialect/func.rs | 6 +++--- melior/src/dialect/index.rs | 6 +++--- melior/src/dialect/llvm.rs | 26 +++++++++++++------------- melior/src/dialect/memref.rs | 14 +++++++------- melior/src/dialect/scf.rs | 10 +++++----- melior/src/lib.rs | 2 +- melior/src/pass/external.rs | 3 ++- melior/src/string_ref.rs | 2 +- 10 files changed, 42 insertions(+), 40 deletions(-) diff --git a/melior/benches/main.rs b/melior/benches/main.rs index eb30d61b32..642f834152 100644 --- a/melior/benches/main.rs +++ b/melior/benches/main.rs @@ -1,5 +1,5 @@ use criterion::{criterion_group, criterion_main, Bencher, Criterion}; -use melior::StringRef; +use melior::{Context, StringRef}; const ITERATION_COUNT: usize = 1000000; @@ -10,19 +10,22 @@ fn generate_strings() -> Vec { } fn string_ref_create(bencher: &mut Bencher) { + let context = Context::new(); let strings = generate_strings(); bencher.iter(|| { for string in &strings { - let _ = StringRef::from(string.as_str()); + let _ = StringRef::from_str(&context, string.as_str()); } }); } fn string_ref_create_cached(bencher: &mut Bencher) { + let context = Context::new(); + bencher.iter(|| { for _ in 0..ITERATION_COUNT { - let _ = StringRef::from("foo"); + let _ = StringRef::from_str(&context, "foo"); } }); } diff --git a/melior/src/context.rs b/melior/src/context.rs index 049e611a0a..addd690976 100644 --- a/melior/src/context.rs +++ b/melior/src/context.rs @@ -16,8 +16,6 @@ use mlir_sys::{ use std::{ ffi::{c_void, CString}, marker::PhantomData, - mem::transmute, - ops::Deref, }; /// A context of IR, dialects, and passes. @@ -56,7 +54,7 @@ impl Context { unsafe { Dialect::from_raw(mlirContextGetOrLoadDialect( self.raw, - StringRef::from_str(&self, name).to_raw(), + StringRef::from_str(self, name).to_raw(), )) } } diff --git a/melior/src/dialect/func.rs b/melior/src/dialect/func.rs index 6019eba412..fbb036b1db 100644 --- a/melior/src/dialect/func.rs +++ b/melior/src/dialect/func.rs @@ -47,7 +47,7 @@ pub fn constant<'c>( r#type: FunctionType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "func.constant", location) + OperationBuilder::new(context, "func.constant", location) .add_attributes(&[(Identifier::new(context, "value"), function.into())]) .add_results(&[r#type.into()]) .build() @@ -62,7 +62,7 @@ pub fn func<'c>( attributes: &[(Identifier<'c>, Attribute<'c>)], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "func.func", location) + OperationBuilder::new(context, "func.func", location) .add_attributes(&[ (Identifier::new(context, "sym_name"), name.into()), (Identifier::new(context, "function_type"), r#type.into()), @@ -78,7 +78,7 @@ pub fn r#return<'c>( operands: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "func.return", location) + OperationBuilder::new(context, "func.return", location) .add_operands(operands) .build() } diff --git a/melior/src/dialect/index.rs b/melior/src/dialect/index.rs index 53689fbd01..19fea786b3 100644 --- a/melior/src/dialect/index.rs +++ b/melior/src/dialect/index.rs @@ -17,7 +17,7 @@ pub fn constant<'c>( value: IntegerAttribute<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "index.constant", location) + OperationBuilder::new(context, "index.constant", location) .add_attributes(&[(Identifier::new(context, "value"), value.into())]) .enable_result_type_inference() .build() @@ -31,7 +31,7 @@ pub fn cmp<'c>( rhs: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "index.cmp", location) + OperationBuilder::new(context, "index.cmp", location) .add_attributes(&[( Identifier::new(context, "pred"), Attribute::parse( @@ -105,7 +105,7 @@ mod tests { let name = name.as_string_ref().as_str().unwrap(); block.append_operation(func::r#return( - &context, + context, &[block.append_operation(operation).result(0).unwrap().into()], location, )); diff --git a/melior/src/dialect/llvm.rs b/melior/src/dialect/llvm.rs index 198e13994a..2b99085360 100644 --- a/melior/src/dialect/llvm.rs +++ b/melior/src/dialect/llvm.rs @@ -30,7 +30,7 @@ pub fn extract_value<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "llvm.extractvalue", location) + OperationBuilder::new(context, "llvm.extractvalue", location) .add_attributes(&[(Identifier::new(context, "position"), position.into())]) .add_operands(&[container]) .add_results(&[result_type]) @@ -46,7 +46,7 @@ pub fn get_element_ptr<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "llvm.getelementptr", location) + OperationBuilder::new(context, "llvm.getelementptr", location) .add_attributes(&[ ( Identifier::new(context, "rawConstantIndices"), @@ -71,7 +71,7 @@ pub fn get_element_ptr_dynamic<'c, const N: usize>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "llvm.getelementptr", location) + OperationBuilder::new(context, "llvm.getelementptr", location) .add_attributes(&[ ( Identifier::new(context, "rawConstantIndices"), @@ -96,7 +96,7 @@ pub fn insert_value<'c>( value: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "llvm.insertvalue", location) + OperationBuilder::new(context, "llvm.insertvalue", location) .add_attributes(&[(Identifier::new(context, "position"), position.into())]) .add_operands(&[container, value]) .enable_result_type_inference() @@ -109,7 +109,7 @@ pub fn undef<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "llvm.mlir.undef", location) + OperationBuilder::new(context, "llvm.mlir.undef", location) .add_results(&[result_type]) .build() } @@ -120,7 +120,7 @@ pub fn poison<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "llvm.mlir.poison", location) + OperationBuilder::new(context, "llvm.mlir.poison", location) .add_results(&[result_type]) .build() } @@ -131,14 +131,14 @@ pub fn nullptr<'c>( ptr_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "llvm.mlir.null", location) + OperationBuilder::new(context, "llvm.mlir.null", location) .add_results(&[ptr_type]) .build() } /// Creates a `llvm.unreachable` operation. pub fn unreachable<'c>(context: &'c Context, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new(&context, "llvm.unreachable", location).build() + OperationBuilder::new(context, "llvm.unreachable", location).build() } /// Creates a `llvm.bitcast` operation. @@ -148,7 +148,7 @@ pub fn bitcast<'c>( result: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "llvm.bitcast", location) + OperationBuilder::new(context, "llvm.bitcast", location) .add_operands(&[argument]) .add_results(&[result]) .build() @@ -162,7 +162,7 @@ pub fn alloca<'c>( location: Location<'c>, extra_options: AllocaOptions<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "llvm.alloca", location) + OperationBuilder::new(context, "llvm.alloca", location) .add_operands(&[array_size]) .add_attributes(&extra_options.into_attributes(context)) .add_results(&[ptr_type]) @@ -177,7 +177,7 @@ pub fn store<'c>( location: Location<'c>, extra_options: LoadStoreOptions<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "llvm.store", location) + OperationBuilder::new(context, "llvm.store", location) .add_operands(&[value, addr]) .add_attributes(&extra_options.into_attributes(context)) .build() @@ -384,10 +384,10 @@ mod tests { pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager - .nested_under(&context, "func.func") + .nested_under(context, "func.func") .add_pass(pass::conversion::create_arith_to_llvm()); pass_manager - .nested_under(&context, "func.func") + .nested_under(context, "func.func") .add_pass(pass::conversion::create_index_to_llvm()); pass_manager.add_pass(pass::conversion::create_scf_to_control_flow()); pass_manager.add_pass(pass::conversion::create_control_flow_to_llvm()); diff --git a/melior/src/dialect/memref.rs b/melior/src/dialect/memref.rs index 17bd3b98ae..ba45efeabc 100644 --- a/melior/src/dialect/memref.rs +++ b/melior/src/dialect/memref.rs @@ -62,7 +62,7 @@ fn allocate<'c>( alignment: Option>, location: Location<'c>, ) -> Operation<'c> { - let mut builder = OperationBuilder::new(&context, name, location); + let mut builder = OperationBuilder::new(context, name, location); builder = builder.add_attributes(&[( Identifier::new(context, "operand_segment_sizes"), @@ -86,7 +86,7 @@ pub fn cast<'c>( r#type: MemRefType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "memref.cast", location) + OperationBuilder::new(context, "memref.cast", location) .add_operands(&[value]) .add_results(&[r#type.into()]) .build() @@ -123,7 +123,7 @@ pub fn get_global<'c>( r#type: MemRefType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "memref.get_global", location) + OperationBuilder::new(context, "memref.get_global", location) .add_attributes(&[( Identifier::new(context, "name"), FlatSymbolRefAttribute::new(context, name).into(), @@ -188,7 +188,7 @@ pub fn load<'c>( indices: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "memref.load", location) + OperationBuilder::new(context, "memref.load", location) .add_operands(&[memref]) .add_operands(indices) .enable_result_type_inference() @@ -201,7 +201,7 @@ pub fn rank<'c>( value: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "memref.rank", location) + OperationBuilder::new(context, "memref.rank", location) .add_operands(&[value]) .enable_result_type_inference() .build() @@ -215,7 +215,7 @@ pub fn store<'c>( indices: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "memref.store", location) + OperationBuilder::new(context, "memref.store", location) .add_operands(&[value, memref]) .add_operands(indices) .build() @@ -267,7 +267,7 @@ mod tests { let block = Block::new(&[]); build_block(&block); - block.append_operation(func::r#return(&context, &[], location)); + block.append_operation(func::r#return(context, &[], location)); let region = Region::new(); region.append_block(block); diff --git a/melior/src/dialect/scf.rs b/melior/src/dialect/scf.rs index 8ea3e19db7..8b3daa1bf5 100644 --- a/melior/src/dialect/scf.rs +++ b/melior/src/dialect/scf.rs @@ -15,7 +15,7 @@ pub fn condition<'c>( values: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "scf.condition", location) + OperationBuilder::new(context, "scf.condition", location) .add_operands(&[condition]) .add_operands(values) .build() @@ -28,7 +28,7 @@ pub fn execute_region<'c>( region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "scf.execute_region", location) + OperationBuilder::new(context, "scf.execute_region", location) .add_results(result_types) .add_regions(vec![region]) .build() @@ -58,7 +58,7 @@ pub fn r#if<'c>( else_region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "scf.if", location) + OperationBuilder::new(context, "scf.if", location) .add_operands(&[condition]) .add_results(result_types) .add_regions(vec![then_region, else_region]) @@ -74,7 +74,7 @@ pub fn index_switch<'c>( regions: Vec>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "scf.index_switch", location) + OperationBuilder::new(context, "scf.index_switch", location) .add_operands(&[condition]) .add_results(result_types) .add_attributes(&[(Identifier::new(context, "cases"), cases.into())]) @@ -104,7 +104,7 @@ pub fn r#yield<'c>( values: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(&context, "scf.yield", location) + OperationBuilder::new(context, "scf.yield", location) .add_operands(values) .build() } diff --git a/melior/src/lib.rs b/melior/src/lib.rs index 5291f3b22f..80f6b46a4b 100644 --- a/melior/src/lib.rs +++ b/melior/src/lib.rs @@ -242,7 +242,7 @@ mod tests { rhs: Value<'c, '_>, ) -> Value<'c, 'a> { block - .append_operation(arith::addi(&context, lhs, rhs, Location::unknown(context))) + .append_operation(arith::addi(context, lhs, rhs, Location::unknown(context))) .result(0) .unwrap() .into() diff --git a/melior/src/pass/external.rs b/melior/src/pass/external.rs index 9ae020f90f..482b425cbd 100644 --- a/melior/src/pass/external.rs +++ b/melior/src/pass/external.rs @@ -165,6 +165,7 @@ impl<'c, F: FnMut(OperationRef<'c, '_>, ExternalPass<'_>) + Clone> RunExternalPa /// &[], /// ); /// ``` +#[allow(clippy::too_many_arguments)] pub fn create_external<'c, T: RunExternalPass<'c>>( context: &'c Context, pass: T, @@ -224,7 +225,7 @@ mod tests { TypeAttribute::new(FunctionType::new(context, &[], &[]).into()), { let block = Block::new(&[]); - block.append_operation(func::r#return(&context, &[], location)); + block.append_operation(func::r#return(context, &[], location)); let region = Region::new(); region.append_block(block); diff --git a/melior/src/string_ref.rs b/melior/src/string_ref.rs index 553bda07f8..5ef68a1216 100644 --- a/melior/src/string_ref.rs +++ b/melior/src/string_ref.rs @@ -23,7 +23,7 @@ impl<'c> StringRef<'c> { let string = context .string_cache() .entry(CString::new(string).unwrap()) - .or_insert_with(Default::default) + .or_default() .key() .as_ptr(); From afd30a6def6066a66d3c6433e6fe8d1969575036 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Sat, 14 Oct 2023 20:43:30 +1100 Subject: [PATCH 22/22] Fix linting --- melior/src/ir/block.rs | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/melior/src/ir/block.rs b/melior/src/ir/block.rs index ffa88e1fa2..74b60a87bd 100644 --- a/melior/src/ir/block.rs +++ b/melior/src/ir/block.rs @@ -421,8 +421,9 @@ mod tests { context.set_allow_unregistered_dialects(true); let block = Block::new(&[]); - let operation = block - .append_operation(OperationBuilder::new(&context, "foo", Location::unknown(&context)).build()); + let operation = block.append_operation( + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), + ); assert_eq!(block.first_operation(), Some(operation)); } @@ -440,7 +441,9 @@ mod tests { context.set_allow_unregistered_dialects(true); let block = Block::new(&[]); - block.append_operation(OperationBuilder::new(&context, "foo", Location::unknown(&context)).build()); + block.append_operation( + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), + ); } #[test] @@ -461,8 +464,9 @@ mod tests { context.set_allow_unregistered_dialects(true); let block = Block::new(&[]); - let first_operation = block - .append_operation(OperationBuilder::new(&context, "foo", Location::unknown(&context)).build()); + let first_operation = block.append_operation( + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), + ); let second_operation = block.insert_operation_after( first_operation, OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), @@ -481,8 +485,9 @@ mod tests { context.set_allow_unregistered_dialects(true); let block = Block::new(&[]); - let second_operation = block - .append_operation(OperationBuilder::new(&context, "foo", Location::unknown(&context)).build()); + let second_operation = block.append_operation( + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), + ); let first_operation = block.insert_operation_before( second_operation, OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(),