diff --git a/README.md b/README.md index abf3f4fe98..b68f66fe07 100644 --- a/README.md +++ b/README.md @@ -42,12 +42,13 @@ module.body().append_operation(func::func( let block = Block::new(&[(index_type, location), (index_type, location)]); let sum = block.append_operation(arith::addi( + &context, block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location )); - block.append_operation(func::r#return(&[sum.result(0).unwrap().into()], location)); + block.append_operation(func::r#return(&context, &[sum.result(0).unwrap().into()], location)); let region = Region::new(); region.append_block(block); diff --git a/macro/src/dialect/operation/accessors.rs b/macro/src/dialect/operation/accessors.rs index 857ec6f38e..e9337a6539 100644 --- a/macro/src/dialect/operation/accessors.rs +++ b/macro/src/dialect/operation/accessors.rs @@ -89,7 +89,7 @@ impl<'a> OperationField<'a> { let attribute = ::melior::ir::attribute::DenseI32ArrayAttribute::<'c>::try_from( self.operation - .attribute(#attribute_name)? + .attribute(context, #attribute_name)? )?; let start = (0..#index) .map(|index| attribute.element(index)) @@ -154,11 +154,11 @@ impl<'a> OperationField<'a> { let name = &self.name; Some(if constraint.is_unit()? { - quote! { self.operation.attribute(#name).is_some() } + quote! { self.operation.attribute(context, #name).is_some() } } else { quote! { self.operation - .attribute(#name)? + .attribute(context, #name)? .try_into() .map_err(::melior::Error::from) } @@ -174,7 +174,7 @@ impl<'a> OperationField<'a> { if constraint.is_unit()? || constraint.is_optional()? { Some(quote! { - self.operation.remove_attribute(#name) + self.operation.remove_attribute(context, #name) }) } else { None @@ -193,14 +193,14 @@ impl<'a> OperationField<'a> { Ok(Some(if constraint.is_unit()? { quote! { if value { - self.operation.set_attribute(#name, Attribute::unit(&self.operation.context())); + self.operation.set_attribute(context, #name, Attribute::unit(&self.operation.context())); } else { - self.operation.remove_attribute(#name) + self.operation.remove_attribute(context, #name) } } } else { quote! { - self.operation.set_attribute(#name, &value.into()); + self.operation.set_attribute(context, #name, &value.into()); } })) } @@ -213,7 +213,7 @@ impl<'a> OperationField<'a> { let parameter_type = &self.kind.parameter_type()?; quote! { - pub fn #ident(&mut self, value: #parameter_type) { + pub fn #ident(&mut self, context: &'c ::melior::Context, value: #parameter_type) { #body } } @@ -225,7 +225,7 @@ impl<'a> OperationField<'a> { let ident = sanitize_snake_case_name(&format!("remove_{}", self.name))?; self.remover_impl()?.map(|body| { quote! { - pub fn #ident(&mut self) -> Result<(), ::melior::Error> { + pub fn #ident(&mut self, context: &'c ::melior::Context) -> Result<(), ::melior::Error> { #body } } @@ -236,7 +236,7 @@ impl<'a> OperationField<'a> { let return_type = &self.kind.return_type()?; self.getter_impl()?.map(|body| { quote! { - pub fn #ident(&self) -> #return_type { + pub fn #ident(&self, context: &'c ::melior::Context) -> #return_type { #body } } diff --git a/macro/src/dialect/operation/builder.rs b/macro/src/dialect/operation/builder.rs index 5e4209c4ea..8cf30b7f94 100644 --- a/macro/src/dialect/operation/builder.rs +++ b/macro/src/dialect/operation/builder.rs @@ -51,7 +51,7 @@ impl<'o> OperationBuilder<'o> { quote! { &[( - ::melior::ir::Identifier::new(unsafe { self.context.to_ref() }, #name_string), + ::melior::ir::Identifier::new(self.context, #name_string), #name.into(), )] } @@ -142,7 +142,7 @@ impl<'o> OperationBuilder<'o> { #[doc = #doc] pub struct #builder_ident <'c, #(#iter_arguments),* > { builder: ::melior::ir::operation::OperationBuilder<'c>, - context: ::melior::ContextRef<'c>, + context: &'c ::melior::Context, #(#phantom_fields),* } @@ -180,10 +180,10 @@ impl<'o> OperationBuilder<'o> { quote! { impl<'c> #builder_ident<'c, #(#arguments),*> { - pub fn new(location: ::melior::ir::Location<'c>) -> Self { + pub fn new(context: &'c ::melior::Context, location: ::melior::ir::Location<'c>) -> Self { Self { - context: location.context(), - builder: ::melior::ir::operation::OperationBuilder::new(#name, location), + context, + builder: ::melior::ir::operation::OperationBuilder::new(&context, #name, location), #(#phantoms),* } } @@ -196,9 +196,10 @@ impl<'o> OperationBuilder<'o> { let arguments = self.type_state.arguments_all_set(false); quote! { pub fn builder( + context: &'c ::melior::Context, location: ::melior::ir::Location<'c> ) -> #builder_ident<'c, #(#arguments),*> { - #builder_ident::new(location) + #builder_ident::new(context, location) } } } @@ -229,8 +230,8 @@ impl<'o> OperationBuilder<'o> { Ok(quote! { #[allow(clippy::too_many_arguments)] #[doc = #doc] - pub fn #name<'c>(#(#arguments),*) -> #class_name<'c> { - #class_name::builder(location)#(#builder_calls)*.build() + pub fn #name<'c>(context: &'c ::melior::Context, #(#arguments),*) -> #class_name<'c> { + #class_name::builder(context, location)#(#builder_calls)*.build() } }) } diff --git a/macro/src/operation.rs b/macro/src/operation.rs index 7c43d621a5..f071c6b834 100644 --- a/macro/src/operation.rs +++ b/macro/src/operation.rs @@ -13,23 +13,25 @@ pub fn generate_binary(dialect: &Ident, names: &[Ident]) -> Result( + context: &'c Context, lhs: crate::ir::Value<'c, '_>, rhs: crate::ir::Value<'c, '_>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - binary_operator(#operation_name, lhs, rhs, location) + binary_operator(context, #operation_name, lhs, rhs, location) } })); } stream.extend(TokenStream::from(quote! { fn binary_operator<'c>( + context: &'c Context, name: &str, lhs: crate::ir::Value<'c, '_>, rhs: crate::ir::Value<'c, '_>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - crate::ir::operation::OperationBuilder::new(name, location) + crate::ir::operation::OperationBuilder::new(&context, name, location) .add_operands(&[lhs, rhs]) .enable_result_type_inference() .build() @@ -49,21 +51,23 @@ pub fn generate_unary(dialect: &Ident, names: &[Ident]) -> Result( + context: &'c Context, value: crate::ir::Value<'c, '_>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - unary_operator(#operation_name, value, location) + unary_operator(context, #operation_name, value, location) } })); } stream.extend(TokenStream::from(quote! { fn unary_operator<'c>( + context: &'c Context, name: &str, value: crate::ir::Value<'c, '_>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - crate::ir::operation::OperationBuilder::new(name, location) + crate::ir::operation::OperationBuilder::new(&context, name, location) .add_operands(&[value]) .enable_result_type_inference() .build() @@ -86,23 +90,25 @@ pub fn generate_typed_unary( stream.extend(TokenStream::from(quote! { #[doc = #document] pub fn #name<'c>( + context: &'c Context, value: crate::ir::Value<'c, '_>, r#type: crate::ir::Type<'c>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - typed_unary_operator(#operation_name, value, r#type, location) + typed_unary_operator(context, #operation_name, value, r#type, location) } })); } stream.extend(TokenStream::from(quote! { fn typed_unary_operator<'c>( + context: &'c Context, name: &str, value: crate::ir::Value<'c, '_>, r#type: crate::ir::Type<'c>, location: crate::ir::Location<'c>, ) -> crate::ir::Operation<'c> { - crate::ir::operation::OperationBuilder::new(name, location) + crate::ir::operation::OperationBuilder::new(&context, name, location) .add_operands(&[value]) .add_results(&[r#type]) .build() diff --git a/macro/src/type.rs b/macro/src/type.rs index 62d01ade7e..d427f33a2c 100644 --- a/macro/src/type.rs +++ b/macro/src/type.rs @@ -1,6 +1,5 @@ use crate::utility::map_name; use convert_case::{Case, Casing}; - use proc_macro::TokenStream; use proc_macro2::Ident; use quote::quote; diff --git a/macro/tests/operand.rs b/macro/tests/operand.rs index 27991a61fc..2843cc442b 100644 --- a/macro/tests/operand.rs +++ b/macro/tests/operand.rs @@ -18,14 +18,15 @@ fn simple() { let r#type = Type::parse(&context, "i32").unwrap(); let block = Block::new(&[(r#type, location), (r#type, location)]); let op = operand_test::simple( + &context, r#type, block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, ); - assert_eq!(op.lhs().unwrap(), block.argument(0).unwrap().into()); - assert_eq!(op.rhs().unwrap(), block.argument(1).unwrap().into()); + assert_eq!(op.lhs(&context).unwrap(), block.argument(0).unwrap().into()); + assert_eq!(op.rhs(&context).unwrap(), block.argument(1).unwrap().into()); assert_eq!(op.operation().operand_count(), 2); } @@ -39,6 +40,7 @@ fn variadic_after_single() { let r#type = Type::parse(&context, "i32").unwrap(); let block = Block::new(&[(r#type, location), (r#type, location), (r#type, location)]); let op = operand_test::variadic( + &context, r#type, block.argument(0).unwrap().into(), &[ @@ -48,9 +50,18 @@ fn variadic_after_single() { location, ); - assert_eq!(op.first().unwrap(), block.argument(0).unwrap().into()); - assert_eq!(op.others().next(), Some(block.argument(2).unwrap().into())); - assert_eq!(op.others().nth(1), Some(block.argument(1).unwrap().into())); + assert_eq!( + op.first(&context).unwrap(), + block.argument(0).unwrap().into() + ); + assert_eq!( + op.others(&context).next(), + Some(block.argument(2).unwrap().into()) + ); + assert_eq!( + op.others(&context).nth(1), + Some(block.argument(1).unwrap().into()) + ); assert_eq!(op.operation().operand_count(), 3); - assert_eq!(op.others().count(), 2); + assert_eq!(op.others(&context).count(), 2); } diff --git a/macro/tests/region.rs b/macro/tests/region.rs index a2bfd9b090..5a5fffcce2 100644 --- a/macro/tests/region.rs +++ b/macro/tests/region.rs @@ -19,10 +19,10 @@ fn single() { let block = Block::new(&[]); let r1 = Region::new(); r1.append_block(block); - region_test::single(r1, location) + region_test::single(&context, r1, location) }; - assert!(op.default_region().unwrap().first_block().is_some()); + assert!(op.default_region(&context).unwrap().first_block().is_some()); } #[test] @@ -36,14 +36,14 @@ fn variadic_after_single() { let block = Block::new(&[]); let (r1, r2, r3) = (Region::new(), Region::new(), Region::new()); r2.append_block(block); - region_test::variadic(r1, vec![r2, r3], location) + region_test::variadic(&context, r1, vec![r2, r3], location) }; let op2 = { let block = Block::new(&[]); let (r1, r2, r3) = (Region::new(), Region::new(), Region::new()); r2.append_block(block); - region_test::VariadicOp::builder(location) + region_test::VariadicOp::builder(&context, location) .default_region(r1) .other_regions(vec![r2, r3]) .build() @@ -51,8 +51,18 @@ fn variadic_after_single() { assert_eq!(op.operation().to_string(), op2.operation().to_string()); - assert!(op.default_region().unwrap().first_block().is_none()); - assert_eq!(op.other_regions().count(), 2); - assert!(op.other_regions().next().unwrap().first_block().is_some()); - assert!(op.other_regions().nth(1).unwrap().first_block().is_none()); + assert!(op.default_region(&context).unwrap().first_block().is_none()); + assert_eq!(op.other_regions(&context).count(), 2); + assert!(op + .other_regions(&context) + .next() + .unwrap() + .first_block() + .is_some()); + assert!(op + .other_regions(&context) + .nth(1) + .unwrap() + .first_block() + .is_none()); } diff --git a/melior/benches/main.rs b/melior/benches/main.rs index eb30d61b32..642f834152 100644 --- a/melior/benches/main.rs +++ b/melior/benches/main.rs @@ -1,5 +1,5 @@ use criterion::{criterion_group, criterion_main, Bencher, Criterion}; -use melior::StringRef; +use melior::{Context, StringRef}; const ITERATION_COUNT: usize = 1000000; @@ -10,19 +10,22 @@ fn generate_strings() -> Vec { } fn string_ref_create(bencher: &mut Bencher) { + let context = Context::new(); let strings = generate_strings(); bencher.iter(|| { for string in &strings { - let _ = StringRef::from(string.as_str()); + let _ = StringRef::from_str(&context, string.as_str()); } }); } fn string_ref_create_cached(bencher: &mut Bencher) { + let context = Context::new(); + bencher.iter(|| { for _ in 0..ITERATION_COUNT { - let _ = StringRef::from("foo"); + let _ = StringRef::from_str(&context, "foo"); } }); } diff --git a/melior/src/context.rs b/melior/src/context.rs index ec14451448..addd690976 100644 --- a/melior/src/context.rs +++ b/melior/src/context.rs @@ -4,6 +4,7 @@ use crate::{ logical_result::LogicalResult, string_ref::StringRef, }; +use dashmap::DashMap; use mlir_sys::{ mlirContextAppendDialectRegistry, mlirContextAttachDiagnosticHandler, mlirContextCreate, mlirContextDestroy, mlirContextDetachDiagnosticHandler, mlirContextEnableMultithreading, @@ -12,7 +13,10 @@ use mlir_sys::{ mlirContextIsRegisteredOperation, mlirContextLoadAllAvailableDialects, mlirContextSetAllowUnregisteredDialects, MlirContext, MlirDiagnostic, MlirLogicalResult, }; -use std::{ffi::c_void, marker::PhantomData, mem::transmute, ops::Deref}; +use std::{ + ffi::{c_void, CString}, + marker::PhantomData, +}; /// A context of IR, dialects, and passes. /// @@ -21,6 +25,9 @@ use std::{ffi::c_void, marker::PhantomData, mem::transmute, ops::Deref}; #[derive(Debug)] pub struct Context { raw: MlirContext, + // We need to pass null-terminated strings to functions in the MLIR API although + // Rust's strings are not. + string_cache: DashMap, } impl Context { @@ -28,6 +35,7 @@ impl Context { pub fn new() -> Self { Self { raw: unsafe { mlirContextCreate() }, + string_cache: Default::default(), } } @@ -46,7 +54,7 @@ impl Context { unsafe { Dialect::from_raw(mlirContextGetOrLoadDialect( self.raw, - StringRef::from(name).to_raw(), + StringRef::from_str(self, name).to_raw(), )) } } @@ -78,7 +86,9 @@ impl Context { /// Returns `true` if a given operation is registered in a context. pub fn is_registered_operation(&self, name: &str) -> bool { - unsafe { mlirContextIsRegisteredOperation(self.raw, StringRef::from(name).to_raw()) } + unsafe { + mlirContextIsRegisteredOperation(self.raw, StringRef::from_str(self, name).to_raw()) + } } /// Converts a context into a raw object. @@ -116,6 +126,10 @@ impl Context { pub fn detach_diagnostic_handler(&self, id: DiagnosticHandlerId) { unsafe { mlirContextDetachDiagnosticHandler(self.to_raw(), id.to_raw()) } } + + pub(crate) fn string_cache(&self) -> &DashMap { + &self.string_cache + } } impl Drop for Context { @@ -146,22 +160,6 @@ pub struct ContextRef<'c> { } impl<'c> ContextRef<'c> { - /// Gets a context. - /// - /// This function is different from `deref` because the correct lifetime is - /// kept for the return type. - /// - /// # Safety - /// - /// The returned reference is safe to use only in the lifetime scope of the - /// context reference. - pub unsafe fn to_ref(&self) -> &'c Context { - // As we can't deref ContextRef<'a> into `&'a Context`, we forcibly cast its - // lifetime here to extend it from the lifetime of `ObjectRef<'a>` itself into - // `'a`. - transmute(self) - } - /// Creates a context reference from a raw object. /// /// # Safety @@ -175,14 +173,6 @@ impl<'c> ContextRef<'c> { } } -impl<'a> Deref for ContextRef<'a> { - type Target = Context; - - fn deref(&self) -> &Self::Target { - unsafe { transmute(self) } - } -} - impl<'a> PartialEq for ContextRef<'a> { fn eq(&self, other: &Self) -> bool { unsafe { mlirContextEqual(self.raw, other.raw) } diff --git a/melior/src/dialect/arith.rs b/melior/src/dialect/arith.rs index 38e06a57ec..f3de7ab665 100644 --- a/melior/src/dialect/arith.rs +++ b/melior/src/dialect/arith.rs @@ -16,7 +16,7 @@ pub fn constant<'c>( value: Attribute<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("arith.constant", location) + OperationBuilder::new(context, "arith.constant", location) .add_attributes(&[(Identifier::new(context, "value"), value)]) .enable_result_type_inference() .build() @@ -86,7 +86,7 @@ fn cmp<'c>( rhs: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new(name, location) + OperationBuilder::new(context, name, location) .add_attributes(&[( Identifier::new(context, "predicate"), IntegerAttribute::new(predicate, IntegerType::new(context, 64).into()).into(), @@ -98,12 +98,13 @@ fn cmp<'c>( /// Creates an `arith.select` operation. pub fn select<'c>( + context: &'c Context, condition: Value<'c, '_>, true_value: Value<'c, '_>, false_value: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("arith.select", location) + OperationBuilder::new(context, "arith.select", location) .add_operands(&[condition, true_value, false_value]) .add_results(&[true_value.r#type()]) .build() @@ -205,6 +206,7 @@ mod tests { let name = name.as_string_ref().as_str().unwrap(); block.append_operation(func::r#return( + context, &[block.append_operation(operation).result(0).unwrap().into()], location, )); @@ -255,6 +257,7 @@ mod tests { &context, |block| { negf( + &context, block.argument(0).unwrap().into(), Location::unknown(&context), ) @@ -331,6 +334,7 @@ mod tests { &context, |block| { bitcast( + &context, block.argument(0).unwrap().into(), float_type, Location::unknown(&context), @@ -349,6 +353,7 @@ mod tests { &context, |block| { extf( + &context, block.argument(0).unwrap().into(), Type::float64(&context), Location::unknown(&context), @@ -371,6 +376,7 @@ mod tests { &context, |block| { extsi( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -393,6 +399,7 @@ mod tests { &context, |block| { extui( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -415,6 +422,7 @@ mod tests { &context, |block| { fptosi( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -437,6 +445,7 @@ mod tests { &context, |block| { fptoui( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -459,6 +468,7 @@ mod tests { &context, |block| { index_cast( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -481,6 +491,7 @@ mod tests { &context, |block| { index_castui( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -503,6 +514,7 @@ mod tests { &context, |block| { sitofp( + &context, block.argument(0).unwrap().into(), Type::float64(&context), Location::unknown(&context), @@ -525,6 +537,7 @@ mod tests { &context, |block| { trunci( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 32).into(), Location::unknown(&context), @@ -547,6 +560,7 @@ mod tests { &context, |block| { uitofp( + &context, block.argument(0).unwrap().into(), Type::float64(&context), Location::unknown(&context), @@ -576,12 +590,17 @@ mod tests { let block = Block::new(&[(integer_type, location), (integer_type, location)]); let sum = block.append_operation(addi( + &context, block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )); - block.append_operation(func::r#return(&[sum.result(0).unwrap().into()], location)); + block.append_operation(func::r#return( + &context, + &[sum.result(0).unwrap().into()], + location, + )); let region = Region::new(); region.append_block(block); @@ -624,13 +643,18 @@ mod tests { ]); let val = block.append_operation(select( + &context, block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), block.argument(2).unwrap().into(), location, )); - block.append_operation(func::r#return(&[val.result(0).unwrap().into()], location)); + block.append_operation(func::r#return( + &context, + &[val.result(0).unwrap().into()], + location, + )); let region = Region::new(); region.append_block(block); diff --git a/melior/src/dialect/cf.rs b/melior/src/dialect/cf.rs index ce553edaed..64505d28b2 100644 --- a/melior/src/dialect/cf.rs +++ b/melior/src/dialect/cf.rs @@ -19,7 +19,7 @@ pub fn assert<'c>( message: &str, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("cf.assert", location) + OperationBuilder::new(context, "cf.assert", location) .add_attributes(&[( Identifier::new(context, "msg"), StringAttribute::new(context, message).into(), @@ -30,11 +30,12 @@ pub fn assert<'c>( /// Creates a `cf.br` operation. pub fn br<'c>( + context: &'c Context, successor: &Block<'c>, destination_operands: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("cf.br", location) + OperationBuilder::new(context, "cf.br", location) .add_operands(destination_operands) .add_successors(&[successor]) .build() @@ -50,7 +51,7 @@ pub fn cond_br<'c>( false_successor_operands: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("cf.cond_br", location) + OperationBuilder::new(context, "cf.cond_br", location) .add_attributes(&[( Identifier::new(context, "operand_segment_sizes"), DenseI32ArrayAttribute::new( @@ -89,7 +90,7 @@ pub fn switch<'c>( .chain(case_destinations.iter().copied()) .unzip(); - Ok(OperationBuilder::new("cf.switch", location) + Ok(OperationBuilder::new(context, "cf.switch", location) .add_attributes(&[ ( Identifier::new(context, "case_values"), @@ -183,7 +184,7 @@ mod tests { block.append_operation(assert(&context, operand, "assert message", location)); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -222,9 +223,9 @@ mod tests { .result(0) .unwrap(); - block.append_operation(br(&dest_block, &[operand.into()], location)); + block.append_operation(br(&context, &dest_block, &[operand.into()], location)); - dest_block.append_operation(func::r#return(&[], location)); + dest_block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -289,8 +290,8 @@ mod tests { location, )); - true_block.append_operation(func::r#return(&[], location)); - false_block.append_operation(func::r#return(&[], location)); + true_block.append_operation(func::r#return(&context, &[], location)); + false_block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -348,9 +349,9 @@ mod tests { .unwrap(), ); - default_block.append_operation(func::r#return(&[], location)); - first_block.append_operation(func::r#return(&[], location)); - second_block.append_operation(func::r#return(&[], location)); + default_block.append_operation(func::r#return(&context, &[], location)); + first_block.append_operation(func::r#return(&context, &[], location)); + second_block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); diff --git a/melior/src/dialect/func.rs b/melior/src/dialect/func.rs index deb5256512..fbb036b1db 100644 --- a/melior/src/dialect/func.rs +++ b/melior/src/dialect/func.rs @@ -18,7 +18,7 @@ pub fn call<'c>( result_types: &[Type<'c>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("func.call", location) + OperationBuilder::new(context, "func.call", location) .add_attributes(&[(Identifier::new(context, "callee"), function.into())]) .add_operands(arguments) .add_results(result_types) @@ -27,12 +27,13 @@ pub fn call<'c>( /// Create a `func.call_indirect` operation. pub fn call_indirect<'c>( + context: &'c Context, function: Value<'c, '_>, arguments: &[Value<'c, '_>], result_types: &[Type<'c>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("func.call_indirect", location) + OperationBuilder::new(context, "func.call_indirect", location) .add_operands(&[function]) .add_operands(arguments) .add_results(result_types) @@ -46,7 +47,7 @@ pub fn constant<'c>( r#type: FunctionType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("func.constant", location) + OperationBuilder::new(context, "func.constant", location) .add_attributes(&[(Identifier::new(context, "value"), function.into())]) .add_results(&[r#type.into()]) .build() @@ -61,7 +62,7 @@ pub fn func<'c>( attributes: &[(Identifier<'c>, Attribute<'c>)], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("func.func", location) + OperationBuilder::new(context, "func.func", location) .add_attributes(&[ (Identifier::new(context, "sym_name"), name.into()), (Identifier::new(context, "function_type"), r#type.into()), @@ -72,8 +73,12 @@ pub fn func<'c>( } /// Create a `func.return` operation. -pub fn r#return<'c>(operands: &[Value<'c, '_>], location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("func.return", location) +pub fn r#return<'c>( + context: &'c Context, + operands: &[Value<'c, '_>], + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "func.return", location) .add_operands(operands) .build() } @@ -113,7 +118,7 @@ mod tests { .result(0) .unwrap() .into(); - block.append_operation(r#return(&[value], location)); + block.append_operation(r#return(&context, &[value], location)); let region = Region::new(); region.append_block(block); @@ -153,6 +158,7 @@ mod tests { )); let value = block .append_operation(call_indirect( + &context, function.result(0).unwrap().into(), &[block.argument(0).unwrap().into()], &[index_type], @@ -161,7 +167,7 @@ mod tests { .result(0) .unwrap() .into(); - block.append_operation(r#return(&[value], location)); + block.append_operation(r#return(&context, &[value], location)); let region = Region::new(); region.append_block(block); @@ -189,7 +195,11 @@ mod tests { let function = { let block = Block::new(&[(integer_type, location)]); - block.append_operation(r#return(&[block.argument(0).unwrap().into()], location)); + block.append_operation(r#return( + &context, + &[block.argument(0).unwrap().into()], + location, + )); let region = Region::new(); region.append_block(block); diff --git a/melior/src/dialect/index.rs b/melior/src/dialect/index.rs index bd123f0c5f..19fea786b3 100644 --- a/melior/src/dialect/index.rs +++ b/melior/src/dialect/index.rs @@ -17,7 +17,7 @@ pub fn constant<'c>( value: IntegerAttribute<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("index.constant", location) + OperationBuilder::new(context, "index.constant", location) .add_attributes(&[(Identifier::new(context, "value"), value.into())]) .enable_result_type_inference() .build() @@ -31,7 +31,7 @@ pub fn cmp<'c>( rhs: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("index.cmp", location) + OperationBuilder::new(context, "index.cmp", location) .add_attributes(&[( Identifier::new(context, "pred"), Attribute::parse( @@ -105,6 +105,7 @@ mod tests { let name = name.as_string_ref().as_str().unwrap(); block.append_operation(func::r#return( + context, &[block.append_operation(operation).result(0).unwrap().into()], location, )); @@ -180,6 +181,7 @@ mod tests { &context, |block| { casts( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -201,6 +203,7 @@ mod tests { &context, |block| { castu( + &context, block.argument(0).unwrap().into(), IntegerType::new(&context, 64).into(), Location::unknown(&context), @@ -229,12 +232,17 @@ mod tests { let block = Block::new(&[(integer_type, location), (integer_type, location)]); let sum = block.append_operation(add( + &context, block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )); - block.append_operation(func::r#return(&[sum.result(0).unwrap().into()], location)); + block.append_operation(func::r#return( + &context, + &[sum.result(0).unwrap().into()], + location, + )); let region = Region::new(); region.append_block(block); diff --git a/melior/src/dialect/llvm.rs b/melior/src/dialect/llvm.rs index b86eed7158..2b99085360 100644 --- a/melior/src/dialect/llvm.rs +++ b/melior/src/dialect/llvm.rs @@ -30,7 +30,7 @@ pub fn extract_value<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.extractvalue", location) + OperationBuilder::new(context, "llvm.extractvalue", location) .add_attributes(&[(Identifier::new(context, "position"), position.into())]) .add_operands(&[container]) .add_results(&[result_type]) @@ -46,7 +46,7 @@ pub fn get_element_ptr<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.getelementptr", location) + OperationBuilder::new(context, "llvm.getelementptr", location) .add_attributes(&[ ( Identifier::new(context, "rawConstantIndices"), @@ -71,7 +71,7 @@ pub fn get_element_ptr_dynamic<'c, const N: usize>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.getelementptr", location) + OperationBuilder::new(context, "llvm.getelementptr", location) .add_attributes(&[ ( Identifier::new(context, "rawConstantIndices"), @@ -96,7 +96,7 @@ pub fn insert_value<'c>( value: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.insertvalue", location) + OperationBuilder::new(context, "llvm.insertvalue", location) .add_attributes(&[(Identifier::new(context, "position"), position.into())]) .add_operands(&[container, value]) .enable_result_type_inference() @@ -104,36 +104,53 @@ pub fn insert_value<'c>( } /// Creates a `llvm.mlir.undef` operation. -pub fn undef<'c>(result_type: Type<'c>, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("llvm.mlir.undef", location) +pub fn undef<'c>( + context: &'c Context, + result_type: Type<'c>, + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "llvm.mlir.undef", location) .add_results(&[result_type]) .build() } /// Creates a `llvm.mlir.poison` operation. -pub fn poison<'c>(result_type: Type<'c>, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("llvm.mlir.poison", location) +pub fn poison<'c>( + context: &'c Context, + result_type: Type<'c>, + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "llvm.mlir.poison", location) .add_results(&[result_type]) .build() } /// Creates a `llvm.mlir.null` operation. A null pointer. -pub fn nullptr<'c>(ptr_type: Type<'c>, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("llvm.mlir.null", location) +pub fn nullptr<'c>( + context: &'c Context, + ptr_type: Type<'c>, + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "llvm.mlir.null", location) .add_results(&[ptr_type]) .build() } /// Creates a `llvm.unreachable` operation. -pub fn unreachable(location: Location) -> Operation { - OperationBuilder::new("llvm.unreachable", location).build() +pub fn unreachable<'c>(context: &'c Context, location: Location<'c>) -> Operation<'c> { + OperationBuilder::new(context, "llvm.unreachable", location).build() } /// Creates a `llvm.bitcast` operation. -pub fn bitcast<'c>(arg: Value<'c, '_>, res: Type<'c>, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("llvm.bitcast", location) - .add_operands(&[arg]) - .add_results(&[res]) +pub fn bitcast<'c>( + context: &'c Context, + argument: Value<'c, '_>, + result: Type<'c>, + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "llvm.bitcast", location) + .add_operands(&[argument]) + .add_results(&[result]) .build() } @@ -145,7 +162,7 @@ pub fn alloca<'c>( location: Location<'c>, extra_options: AllocaOptions<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.alloca", location) + OperationBuilder::new(context, "llvm.alloca", location) .add_operands(&[array_size]) .add_attributes(&extra_options.into_attributes(context)) .add_results(&[ptr_type]) @@ -160,7 +177,7 @@ pub fn store<'c>( location: Location<'c>, extra_options: LoadStoreOptions<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.store", location) + OperationBuilder::new(context, "llvm.store", location) .add_operands(&[value, addr]) .add_attributes(&extra_options.into_attributes(context)) .build() @@ -174,7 +191,7 @@ pub fn load<'c>( location: Location<'c>, extra_options: LoadStoreOptions<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.load", location) + OperationBuilder::new(context, "llvm.load", location) .add_operands(&[addr]) .add_attributes(&extra_options.into_attributes(context)) .add_results(&[r#type]) @@ -190,7 +207,7 @@ pub fn func<'c>( attributes: &[(Identifier<'c>, Attribute<'c>)], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.func", location) + OperationBuilder::new(context, "llvm.func", location) .add_attributes(&[ (Identifier::new(context, "sym_name"), name.into()), (Identifier::new(context, "function_type"), r#type.into()), @@ -201,8 +218,12 @@ pub fn func<'c>( } // Creates a `llvm.return` operation. -pub fn r#return<'c>(value: Option>, location: Location<'c>) -> Operation<'c> { - let mut builder = OperationBuilder::new("llvm.return", location); +pub fn r#return<'c>( + context: &'c Context, + value: Option>, + location: Location<'c>, +) -> Operation<'c> { + let mut builder = OperationBuilder::new(context, "llvm.return", location); if let Some(value) = value { builder = builder.add_operands(&[value]); @@ -219,7 +240,7 @@ pub fn call_intrinsic<'c>( results: &[Type<'c>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.call_intrinsic", location) + OperationBuilder::new(context, "llvm.call_intrinsic", location) .add_operands(args) .add_attributes(&[(Identifier::new(context, "intrin"), intrin.into())]) .add_results(results) @@ -234,7 +255,7 @@ pub fn intr_ctlz<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.intr.ctlz", location) + OperationBuilder::new(context, "llvm.intr.ctlz", location) .add_attributes(&[( Identifier::new(context, "is_zero_poison"), IntegerAttribute::new(is_zero_poison.into(), IntegerType::new(context, 1).into()) @@ -253,7 +274,7 @@ pub fn intr_cttz<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.intr.cttz", location) + OperationBuilder::new(context, "llvm.intr.cttz", location) .add_attributes(&[( Identifier::new(context, "is_zero_poison"), IntegerAttribute::new(is_zero_poison.into(), IntegerType::new(context, 1).into()) @@ -266,11 +287,12 @@ pub fn intr_cttz<'c>( /// Creates a `llvm.intr.ctlz` operation. pub fn intr_ctpop<'c>( + context: &'c Context, value: Value<'c, '_>, result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.intr.ctpop", location) + OperationBuilder::new(context, "llvm.intr.ctpop", location) .add_operands(&[value]) .add_results(&[result_type]) .build() @@ -278,11 +300,12 @@ pub fn intr_ctpop<'c>( /// Creates a `llvm.intr.bswap` operation. pub fn intr_bswap<'c>( + context: &'c Context, value: Value<'c, '_>, result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.intr.bswap", location) + OperationBuilder::new(context, "llvm.intr.bswap", location) .add_operands(&[value]) .add_results(&[result_type]) .build() @@ -290,11 +313,12 @@ pub fn intr_bswap<'c>( /// Creates a `llvm.intr.bitreverse` operation. pub fn intr_bitreverse<'c>( + context: &'c Context, value: Value<'c, '_>, result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.intr.bitreverse", location) + OperationBuilder::new(context, "llvm.intr.bitreverse", location) .add_operands(&[value]) .add_results(&[result_type]) .build() @@ -308,7 +332,7 @@ pub fn intr_abs<'c>( result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.intr.abs", location) + OperationBuilder::new(context, "llvm.intr.abs", location) .add_attributes(&[( Identifier::new(context, "is_int_min_poison"), IntegerAttribute::new( @@ -324,11 +348,12 @@ pub fn intr_abs<'c>( /// Creates a `llvm.zext` operation. pub fn zext<'c>( + context: &'c Context, value: Value<'c, '_>, result_type: Type<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("llvm.zext", location) + OperationBuilder::new(context, "llvm.zext", location) .add_operands(&[value]) .add_results(&[result_type]) .build() @@ -359,10 +384,10 @@ mod tests { pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager - .nested_under("func.func") + .nested_under(context, "func.func") .add_pass(pass::conversion::create_arith_to_llvm()); pass_manager - .nested_under("func.func") + .nested_under(context, "func.func") .add_pass(pass::conversion::create_index_to_llvm()); pass_manager.add_pass(pass::conversion::create_scf_to_control_flow()); pass_manager.add_pass(pass::conversion::create_control_flow_to_llvm()); @@ -396,7 +421,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -437,7 +462,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -488,7 +513,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -537,7 +562,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -569,9 +594,9 @@ mod tests { { let block = Block::new(&[(struct_type, location)]); - block.append_operation(undef(struct_type, location)); + block.append_operation(undef(&context, struct_type, location)); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -603,9 +628,9 @@ mod tests { { let block = Block::new(&[(struct_type, location)]); - block.append_operation(poison(struct_type, location)); + block.append_operation(poison(&context, struct_type, location)); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -645,7 +670,7 @@ mod tests { AllocaOptions::new().elem_type(Some(TypeAttribute::new(integer_type))), )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -685,7 +710,7 @@ mod tests { Default::default(), )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -725,7 +750,7 @@ mod tests { Default::default(), )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -768,7 +793,7 @@ mod tests { .nontemporal(true), )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -826,7 +851,7 @@ mod tests { { let block = Block::new(&[]); - block.append_operation(r#return(None, location)); + block.append_operation(r#return(&context, None, location)); let region = Region::new(); region.append_block(block); @@ -843,7 +868,11 @@ mod tests { { let block = Block::new(&[(struct_type, location)]); - block.append_operation(r#return(Some(block.argument(0).unwrap().into()), location)); + block.append_operation(r#return( + &context, + Some(block.argument(0).unwrap().into()), + location, + )); let region = Region::new(); region.append_block(block); @@ -888,7 +917,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); @@ -933,7 +962,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); @@ -968,6 +997,7 @@ mod tests { let res = block .append_operation(intr_ctpop( + &context, block.argument(0).unwrap().into(), integer_type, location, @@ -976,7 +1006,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); @@ -1011,6 +1041,7 @@ mod tests { let res = block .append_operation(intr_bswap( + &context, block.argument(0).unwrap().into(), integer_type, location, @@ -1019,7 +1050,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); @@ -1054,6 +1085,7 @@ mod tests { let res = block .append_operation(intr_bitreverse( + &context, block.argument(0).unwrap().into(), integer_type, location, @@ -1062,7 +1094,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); @@ -1107,7 +1139,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); @@ -1143,6 +1175,7 @@ mod tests { let res = block .append_operation(zext( + &context, block.argument(0).unwrap().into(), integer_double_type, location, @@ -1151,7 +1184,7 @@ mod tests { .unwrap() .into(); - block.append_operation(func::r#return(&[res], location)); + block.append_operation(func::r#return(&context, &[res], location)); let region = Region::new(); region.append_block(block); diff --git a/melior/src/dialect/memref.rs b/melior/src/dialect/memref.rs index 7f6351ec4f..ba45efeabc 100644 --- a/melior/src/dialect/memref.rs +++ b/melior/src/dialect/memref.rs @@ -62,7 +62,7 @@ fn allocate<'c>( alignment: Option>, location: Location<'c>, ) -> Operation<'c> { - let mut builder = OperationBuilder::new(name, location); + let mut builder = OperationBuilder::new(context, name, location); builder = builder.add_attributes(&[( Identifier::new(context, "operand_segment_sizes"), @@ -81,30 +81,36 @@ fn allocate<'c>( /// Create a `memref.cast` operation. pub fn cast<'c>( + context: &'c Context, value: Value<'c, '_>, r#type: MemRefType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("memref.cast", location) + OperationBuilder::new(context, "memref.cast", location) .add_operands(&[value]) .add_results(&[r#type.into()]) .build() } /// Create a `memref.dealloc` operation. -pub fn dealloc<'c>(value: Value<'c, '_>, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("memref.dealloc", location) +pub fn dealloc<'c>( + context: &'c Context, + value: Value<'c, '_>, + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "memref.dealloc", location) .add_operands(&[value]) .build() } /// Create a `memref.dim` operation. pub fn dim<'c>( + context: &'c Context, value: Value<'c, '_>, index: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("memref.dim", location) + OperationBuilder::new(context, "memref.dim", location) .add_operands(&[value, index]) .enable_result_type_inference() .build() @@ -117,7 +123,7 @@ pub fn get_global<'c>( r#type: MemRefType<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("memref.get_global", location) + OperationBuilder::new(context, "memref.get_global", location) .add_attributes(&[( Identifier::new(context, "name"), FlatSymbolRefAttribute::new(context, name).into(), @@ -138,7 +144,7 @@ pub fn global<'c>( alignment: Option>, location: Location<'c>, ) -> Operation<'c> { - let mut builder = OperationBuilder::new("memref.global", location).add_attributes(&[ + let mut builder = OperationBuilder::new(context, "memref.global", location).add_attributes(&[ ( Identifier::new(context, "sym_name"), StringAttribute::new(context, name).into(), @@ -177,11 +183,12 @@ pub fn global<'c>( /// Create a `memref.load` operation. pub fn load<'c>( + context: &'c Context, memref: Value<'c, '_>, indices: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("memref.load", location) + OperationBuilder::new(context, "memref.load", location) .add_operands(&[memref]) .add_operands(indices) .enable_result_type_inference() @@ -189,8 +196,12 @@ pub fn load<'c>( } /// Create a `memref.rank` operation. -pub fn rank<'c>(value: Value<'c, '_>, location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("memref.rank", location) +pub fn rank<'c>( + context: &'c Context, + value: Value<'c, '_>, + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "memref.rank", location) .add_operands(&[value]) .enable_result_type_inference() .build() @@ -198,12 +209,13 @@ pub fn rank<'c>(value: Value<'c, '_>, location: Location<'c>) -> Operation<'c> { /// Create a `memref.store` operation. pub fn store<'c>( + context: &'c Context, value: Value<'c, '_>, memref: Value<'c, '_>, indices: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("memref.store", location) + OperationBuilder::new(context, "memref.store", location) .add_operands(&[value, memref]) .add_operands(indices) .build() @@ -218,7 +230,7 @@ pub fn realloc<'c>( alignment: Option>, location: Location<'c>, ) -> Operation<'c> { - let mut builder = OperationBuilder::new("memref.realloc", location) + let mut builder = OperationBuilder::new(context, "memref.realloc", location) .add_operands(&[value]) .add_results(&[r#type.into()]); @@ -255,7 +267,7 @@ mod tests { let block = Block::new(&[]); build_block(&block); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(context, &[], location)); let region = Region::new(); region.append_block(block); @@ -290,7 +302,11 @@ mod tests { None, location, )); - block.append_operation(dealloc(memref.result(0).unwrap().into(), location)); + block.append_operation(dealloc( + &context, + memref.result(0).unwrap().into(), + location, + )); }) } @@ -352,6 +368,7 @@ mod tests { )); block.append_operation(cast( + &context, memref.result(0).unwrap().into(), Type::parse(&context, "memref") .unwrap() @@ -384,6 +401,7 @@ mod tests { )); block.append_operation(dim( + &context, memref.result(0).unwrap().into(), index.result(0).unwrap().into(), location, @@ -417,7 +435,7 @@ mod tests { let block = Block::new(&[]); block.append_operation(get_global(&context, "foo", mem_ref_type, location)); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -498,7 +516,12 @@ mod tests { None, location, )); - block.append_operation(load(memref.result(0).unwrap().into(), &[], location)); + block.append_operation(load( + &context, + memref.result(0).unwrap().into(), + &[], + location, + )); }) } @@ -524,6 +547,7 @@ mod tests { )); block.append_operation(load( + &context, memref.result(0).unwrap().into(), &[index.result(0).unwrap().into()], location, @@ -545,7 +569,7 @@ mod tests { None, location, )); - block.append_operation(rank(memref.result(0).unwrap().into(), location)); + block.append_operation(rank(&context, memref.result(0).unwrap().into(), location)); }) } @@ -571,6 +595,7 @@ mod tests { )); block.append_operation(store( + &context, value.result(0).unwrap().into(), memref.result(0).unwrap().into(), &[], @@ -607,6 +632,7 @@ mod tests { )); block.append_operation(store( + &context, value.result(0).unwrap().into(), memref.result(0).unwrap().into(), &[index.result(0).unwrap().into()], diff --git a/melior/src/dialect/ods.rs b/melior/src/dialect/ods.rs index 20eae779c4..1c764ce274 100644 --- a/melior/src/dialect/ods.rs +++ b/melior/src/dialect/ods.rs @@ -129,10 +129,10 @@ mod tests { pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager - .nested_under("func.func") + .nested_under(context, "func.func") .add_pass(pass::conversion::create_arith_to_llvm()); pass_manager - .nested_under("func.func") + .nested_under(context, "func.func") .add_pass(pass::conversion::create_index_to_llvm()); pass_manager.add_pass(pass::conversion::create_scf_to_control_flow()); pass_manager.add_pass(pass::conversion::create_control_flow_to_llvm()); @@ -192,6 +192,7 @@ mod tests { block.append_operation( llvm::alloca( + &context, dialect::llvm::r#type::pointer(i64_type.into(), 0).into(), alloca_size, location, @@ -199,7 +200,7 @@ mod tests { .into(), ); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); }); } @@ -215,7 +216,7 @@ mod tests { let i64_type = IntegerType::new(&context, 64); block.append_operation( - llvm::AllocaOpBuilder::new(location) + llvm::AllocaOpBuilder::new(&context, location) .alignment(IntegerAttribute::new(8, i64_type.into())) .elem_type(TypeAttribute::new(i64_type.into())) .array_size(alloca_size) @@ -224,7 +225,7 @@ mod tests { .into(), ); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); }); } } diff --git a/melior/src/dialect/scf.rs b/melior/src/dialect/scf.rs index f83bb4e859..8b3daa1bf5 100644 --- a/melior/src/dialect/scf.rs +++ b/melior/src/dialect/scf.rs @@ -10,11 +10,12 @@ use crate::{ /// Creates a `scf.condition` operation. pub fn condition<'c>( + context: &'c Context, condition: Value<'c, '_>, values: &[Value<'c, '_>], location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("scf.condition", location) + OperationBuilder::new(context, "scf.condition", location) .add_operands(&[condition]) .add_operands(values) .build() @@ -22,11 +23,12 @@ pub fn condition<'c>( /// Creates a `scf.execute_region` operation. pub fn execute_region<'c>( + context: &'c Context, result_types: &[Type<'c>], region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("scf.execute_region", location) + OperationBuilder::new(context, "scf.execute_region", location) .add_results(result_types) .add_regions(vec![region]) .build() @@ -34,13 +36,14 @@ pub fn execute_region<'c>( /// Creates a `scf.for` operation. pub fn r#for<'c>( + context: &'c Context, start: Value<'c, '_>, end: Value<'c, '_>, step: Value<'c, '_>, region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("scf.for", location) + OperationBuilder::new(context, "scf.for", location) .add_operands(&[start, end, step]) .add_regions(vec![region]) .build() @@ -48,13 +51,14 @@ pub fn r#for<'c>( /// Creates a `scf.if` operation. pub fn r#if<'c>( + context: &'c Context, condition: Value<'c, '_>, result_types: &[Type<'c>], then_region: Region<'c>, else_region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("scf.if", location) + OperationBuilder::new(context, "scf.if", location) .add_operands(&[condition]) .add_results(result_types) .add_regions(vec![then_region, else_region]) @@ -70,7 +74,7 @@ pub fn index_switch<'c>( regions: Vec>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("scf.index_switch", location) + OperationBuilder::new(context, "scf.index_switch", location) .add_operands(&[condition]) .add_results(result_types) .add_attributes(&[(Identifier::new(context, "cases"), cases.into())]) @@ -80,13 +84,14 @@ pub fn index_switch<'c>( /// Creates a `scf.while` operation. pub fn r#while<'c>( + context: &'c Context, initial_values: &[Value<'c, '_>], result_types: &[Type<'c>], before_region: Region<'c>, after_region: Region<'c>, location: Location<'c>, ) -> Operation<'c> { - OperationBuilder::new("scf.while", location) + OperationBuilder::new(context, "scf.while", location) .add_operands(initial_values) .add_results(result_types) .add_regions(vec![before_region, after_region]) @@ -94,8 +99,12 @@ pub fn r#while<'c>( } /// Creates a `scf.yield` operation. -pub fn r#yield<'c>(values: &[Value<'c, '_>], location: Location<'c>) -> Operation<'c> { - OperationBuilder::new("scf.yield", location) +pub fn r#yield<'c>( + context: &'c Context, + values: &[Value<'c, '_>], + location: Location<'c>, +) -> Operation<'c> { + OperationBuilder::new(context, "scf.yield", location) .add_operands(values) .build() } @@ -131,6 +140,7 @@ mod tests { let block = Block::new(&[]); block.append_operation(execute_region( + &context, &[index_type], { let block = Block::new(&[]); @@ -142,6 +152,7 @@ mod tests { )); block.append_operation(r#yield( + &context, &[value.result(0).unwrap().into()], location, )); @@ -153,7 +164,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -201,12 +212,13 @@ mod tests { )); block.append_operation(r#for( + &context, start.result(0).unwrap().into(), end.result(0).unwrap().into(), step.result(0).unwrap().into(), { let block = Block::new(&[(Type::index(&context), location)]); - block.append_operation(r#yield(&[], location)); + block.append_operation(r#yield(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -215,7 +227,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -255,6 +267,7 @@ mod tests { )); let result = block.append_operation(r#if( + &context, condition.result(0).unwrap().into(), &[index_type], { @@ -267,6 +280,7 @@ mod tests { )); block.append_operation(r#yield( + &context, &[result.result(0).unwrap().into()], location, )); @@ -285,6 +299,7 @@ mod tests { )); block.append_operation(r#yield( + &context, &[result.result(0).unwrap().into()], location, )); @@ -297,6 +312,7 @@ mod tests { )); block.append_operation(func::r#return( + &context, &[result.result(0).unwrap().into()], location, )); @@ -335,12 +351,13 @@ mod tests { )); block.append_operation(r#if( + &context, condition.result(0).unwrap().into(), &[], { let block = Block::new(&[]); - block.append_operation(r#yield(&[], location)); + block.append_operation(r#yield(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -350,7 +367,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -395,7 +412,7 @@ mod tests { { let block = Block::new(&[]); - block.append_operation(r#yield(&[], location)); + block.append_operation(r#yield(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -404,7 +421,7 @@ mod tests { { let block = Block::new(&[]); - block.append_operation(r#yield(&[], location)); + block.append_operation(r#yield(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -413,7 +430,7 @@ mod tests { { let block = Block::new(&[]); - block.append_operation(r#yield(&[], location)); + block.append_operation(r#yield(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -423,7 +440,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -463,6 +480,7 @@ mod tests { )); block.append_operation(r#while( + &context, &[initial.result(0).unwrap().into()], &[index_type], { @@ -482,6 +500,7 @@ mod tests { )); block.append_operation(super::condition( + &context, condition.result(0).unwrap().into(), &[result.result(0).unwrap().into()], location, @@ -501,6 +520,7 @@ mod tests { )); block.append_operation(r#yield( + &context, &[result.result(0).unwrap().into()], location, )); @@ -512,7 +532,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -550,6 +570,7 @@ mod tests { )); block.append_operation(r#while( + &context, &[initial.result(0).unwrap().into()], &[float_type], { @@ -569,6 +590,7 @@ mod tests { )); block.append_operation(super::condition( + &context, condition.result(0).unwrap().into(), &[result.result(0).unwrap().into()], location, @@ -588,6 +610,7 @@ mod tests { )); block.append_operation(r#yield( + &context, &[result.result(0).unwrap().into()], location, )); @@ -599,7 +622,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); @@ -636,6 +659,7 @@ mod tests { )); block.append_operation(r#while( + &context, &[ initial.result(0).unwrap().into(), initial.result(0).unwrap().into(), @@ -659,6 +683,7 @@ mod tests { )); block.append_operation(super::condition( + &context, condition.result(0).unwrap().into(), &[ result.result(0).unwrap().into(), @@ -682,6 +707,7 @@ mod tests { )); block.append_operation(r#yield( + &context, &[ result.result(0).unwrap().into(), result.result(0).unwrap().into(), @@ -696,7 +722,7 @@ mod tests { location, )); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(&context, &[], location)); let region = Region::new(); region.append_block(block); diff --git a/melior/src/execution_engine.rs b/melior/src/execution_engine.rs index d37c42e89f..a84bf97ac9 100644 --- a/melior/src/execution_engine.rs +++ b/melior/src/execution_engine.rs @@ -1,4 +1,4 @@ -use crate::{ir::Module, logical_result::LogicalResult, string_ref::StringRef, Error}; +use crate::{ir::Module, logical_result::LogicalResult, string_ref::StringRef, Context, Error}; use mlir_sys::{ mlirExecutionEngineCreate, mlirExecutionEngineDestroy, mlirExecutionEngineDumpToObjectFile, mlirExecutionEngineInvokePacked, mlirExecutionEngineRegisterSymbol, MlirExecutionEngine, @@ -11,8 +11,9 @@ pub struct ExecutionEngine { impl ExecutionEngine { /// Creates an execution engine. - pub fn new( - module: &Module, + pub fn new<'c>( + context: &'c Context, + module: &Module<'c>, optimization_level: usize, shared_library_paths: &[&str], enable_object_dump: bool, @@ -25,7 +26,7 @@ impl ExecutionEngine { shared_library_paths.len() as i32, shared_library_paths .iter() - .map(|&string| StringRef::from(string).to_raw()) + .map(|&string| StringRef::from_str(context, string).to_raw()) .collect::>() .as_ptr(), enable_object_dump, @@ -42,10 +43,15 @@ impl ExecutionEngine { /// This function modifies memory locations pointed by the `arguments` /// argument. If those pointers are invalid or misaligned, calling this /// function might result in undefined behavior. - pub unsafe fn invoke_packed(&self, name: &str, arguments: &mut [*mut ()]) -> Result<(), Error> { + pub unsafe fn invoke_packed( + &self, + context: &Context, + name: &str, + arguments: &mut [*mut ()], + ) -> Result<(), Error> { let result = LogicalResult::from_raw(mlirExecutionEngineInvokePacked( self.raw, - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), arguments.as_mut_ptr() as _, )); @@ -63,13 +69,22 @@ impl ExecutionEngine { /// This function makes a pointer accessible to the execution engine. If a /// given pointer is invalid or misaligned, calling this function might /// result in undefined behavior. - pub unsafe fn register_symbol(&self, name: &str, ptr: *mut ()) { - mlirExecutionEngineRegisterSymbol(self.raw, StringRef::from(name).to_raw(), ptr as _); + pub unsafe fn register_symbol(&self, context: &Context, name: &str, ptr: *mut ()) { + mlirExecutionEngineRegisterSymbol( + self.raw, + StringRef::from_str(context, name).to_raw(), + ptr as _, + ); } /// Dumps a module to an object file. - pub fn dump_to_object_file(&self, path: &str) { - unsafe { mlirExecutionEngineDumpToObjectFile(self.raw, StringRef::from(path).to_raw()) } + pub fn dump_to_object_file(&self, context: &Context, path: &str) { + unsafe { + mlirExecutionEngineDumpToObjectFile( + self.raw, + StringRef::from_str(context, path).to_raw(), + ) + } } } @@ -105,12 +120,12 @@ mod tests { pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager - .nested_under("func.func") + .nested_under(&context, "func.func") .add_pass(pass::conversion::create_arith_to_llvm()); assert_eq!(pass_manager.run(&mut module), Ok(())); - let engine = ExecutionEngine::new(&module, 2, &[], false); + let engine = ExecutionEngine::new(&context, &module, 2, &[], false); let mut argument = 42; let mut result = -1; @@ -118,6 +133,7 @@ mod tests { assert_eq!( unsafe { engine.invoke_packed( + &context, "add", &mut [ &mut argument as *mut i32 as *mut (), @@ -153,11 +169,12 @@ mod tests { pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager - .nested_under("func.func") + .nested_under(&context, "func.func") .add_pass(pass::conversion::create_arith_to_llvm()); assert_eq!(pass_manager.run(&mut module), Ok(())); - ExecutionEngine::new(&module, 2, &[], true).dump_to_object_file("/tmp/melior/test.o"); + ExecutionEngine::new(&context, &module, 2, &[], true) + .dump_to_object_file(&context, "/tmp/melior/test.o"); } } diff --git a/melior/src/ir/attribute.rs b/melior/src/ir/attribute.rs index aaa09f8168..2f5766d3b6 100644 --- a/melior/src/ir/attribute.rs +++ b/melior/src/ir/attribute.rs @@ -44,7 +44,7 @@ impl<'c> Attribute<'c> { unsafe { Self::from_option_raw(mlirAttributeParseGet( context.to_raw(), - StringRef::from(source).to_raw(), + StringRef::from_str(context, source).to_raw(), )) } } diff --git a/melior/src/ir/attribute/flat_symbol_ref.rs b/melior/src/ir/attribute/flat_symbol_ref.rs index ccb346d7cf..c98933de04 100644 --- a/melior/src/ir/attribute/flat_symbol_ref.rs +++ b/melior/src/ir/attribute/flat_symbol_ref.rs @@ -14,7 +14,7 @@ impl<'c> FlatSymbolRefAttribute<'c> { unsafe { Self::from_raw(mlirFlatSymbolRefAttrGet( context.to_raw(), - StringRef::from(symbol).to_raw(), + StringRef::from_str(context, symbol).to_raw(), )) } } diff --git a/melior/src/ir/attribute/string.rs b/melior/src/ir/attribute/string.rs index 4c2eec8591..22fa0c542d 100644 --- a/melior/src/ir/attribute/string.rs +++ b/melior/src/ir/attribute/string.rs @@ -14,7 +14,7 @@ impl<'c> StringAttribute<'c> { unsafe { Self::from_raw(mlirStringAttrGet( context.to_raw(), - StringRef::from(string).to_raw(), + StringRef::from_str(context, string).to_raw(), )) } } diff --git a/melior/src/ir/block.rs b/melior/src/ir/block.rs index 0d1879a6c7..74b60a87bd 100644 --- a/melior/src/ir/block.rs +++ b/melior/src/ir/block.rs @@ -404,7 +404,7 @@ mod tests { let block = Block::new(&[]); let operation = block.append_operation( - OperationBuilder::new("func.return", Location::unknown(&context)).build(), + OperationBuilder::new(&context, "func.return", Location::unknown(&context)).build(), ); assert_eq!(block.terminator(), Some(operation)); @@ -421,8 +421,9 @@ mod tests { context.set_allow_unregistered_dialects(true); let block = Block::new(&[]); - let operation = block - .append_operation(OperationBuilder::new("foo", Location::unknown(&context)).build()); + let operation = block.append_operation( + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), + ); assert_eq!(block.first_operation(), Some(operation)); } @@ -440,7 +441,9 @@ mod tests { context.set_allow_unregistered_dialects(true); let block = Block::new(&[]); - block.append_operation(OperationBuilder::new("foo", Location::unknown(&context)).build()); + block.append_operation( + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), + ); } #[test] @@ -451,7 +454,7 @@ mod tests { block.insert_operation( 0, - OperationBuilder::new("foo", Location::unknown(&context)).build(), + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), ); } @@ -461,11 +464,12 @@ mod tests { context.set_allow_unregistered_dialects(true); let block = Block::new(&[]); - let first_operation = block - .append_operation(OperationBuilder::new("foo", Location::unknown(&context)).build()); + let first_operation = block.append_operation( + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), + ); let second_operation = block.insert_operation_after( first_operation, - OperationBuilder::new("foo", Location::unknown(&context)).build(), + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), ); assert_eq!(block.first_operation(), Some(first_operation)); @@ -481,11 +485,12 @@ mod tests { context.set_allow_unregistered_dialects(true); let block = Block::new(&[]); - let second_operation = block - .append_operation(OperationBuilder::new("foo", Location::unknown(&context)).build()); + let second_operation = block.append_operation( + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), + ); let first_operation = block.insert_operation_before( second_operation, - OperationBuilder::new("foo", Location::unknown(&context)).build(), + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), ); assert_eq!(block.first_operation(), Some(first_operation)); diff --git a/melior/src/ir/identifier.rs b/melior/src/ir/identifier.rs index 6dd118a8a2..ea84e458f8 100644 --- a/melior/src/ir/identifier.rs +++ b/melior/src/ir/identifier.rs @@ -17,11 +17,11 @@ pub struct Identifier<'c> { impl<'c> Identifier<'c> { /// Creates an identifier. - pub fn new(context: &Context, name: &str) -> Self { + pub fn new(context: &'c Context, name: &str) -> Self { unsafe { Self::from_raw(mlirIdentifierGet( context.to_raw(), - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), )) } } diff --git a/melior/src/ir/location.rs b/melior/src/ir/location.rs index bb9e584fee..3179842f2c 100644 --- a/melior/src/ir/location.rs +++ b/melior/src/ir/location.rs @@ -27,7 +27,7 @@ impl<'c> Location<'c> { unsafe { Self::from_raw(mlirLocationFileLineColGet( context.to_raw(), - StringRef::from(filename).to_raw(), + StringRef::from_str(context, filename).to_raw(), line as u32, column as u32, )) @@ -51,7 +51,7 @@ impl<'c> Location<'c> { unsafe { Self::from_raw(mlirLocationNameGet( context.to_raw(), - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), child.to_raw(), )) } diff --git a/melior/src/ir/module.rs b/melior/src/ir/module.rs index cf0c1d60f7..87053a02a4 100644 --- a/melior/src/ir/module.rs +++ b/melior/src/ir/module.rs @@ -28,7 +28,7 @@ impl<'c> Module<'c> { unsafe { Self::from_option_raw(mlirModuleCreateParse( context.to_raw(), - StringRef::from(source).to_raw(), + StringRef::from_str(context, source).to_raw(), )) } } @@ -126,7 +126,7 @@ mod tests { region.append_block(Block::new(&[])); let module = Module::from_operation( - OperationBuilder::new("builtin.module", Location::unknown(&context)) + OperationBuilder::new(&context, "builtin.module", Location::unknown(&context)) .add_regions(vec![region]) .build(), ) @@ -141,7 +141,7 @@ mod tests { let context = create_test_context(); assert!(Module::from_operation( - OperationBuilder::new("func.func", Location::unknown(&context),).build() + OperationBuilder::new(&context, "func.func", Location::unknown(&context),).build() ) .is_none()); } diff --git a/melior/src/ir/operation.rs b/melior/src/ir/operation.rs index 3d37f6809e..d93d1450ee 100644 --- a/melior/src/ir/operation.rs +++ b/melior/src/ir/operation.rs @@ -184,37 +184,42 @@ impl<'c> Operation<'c> { } /// Gets a attribute with the given name. - pub fn attribute(&self, name: &str) -> Result, Error> { + pub fn attribute(&self, context: &'c Context, name: &str) -> Result, Error> { unsafe { Attribute::from_option_raw(mlirOperationGetAttributeByName( self.raw, - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), )) } .ok_or(Error::AttributeNotFound(name.into())) } /// Checks if the operation has a attribute with the given name. - pub fn has_attribute(&self, name: &str) -> bool { - self.attribute(name).is_ok() + pub fn has_attribute(&self, context: &'c Context, name: &str) -> bool { + self.attribute(context, name).is_ok() } /// Sets the attribute with the given name to the given attribute. - pub fn set_attribute(&mut self, name: &str, attribute: &Attribute<'c>) { + pub fn set_attribute(&mut self, context: &'c Context, name: &str, attribute: &Attribute<'c>) { unsafe { mlirOperationSetAttributeByName( self.raw, - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), attribute.to_raw(), ) } } /// Removes the attribute with the given name. - pub fn remove_attribute(&mut self, name: &str) -> Result<(), Error> { - unsafe { mlirOperationRemoveAttributeByName(self.raw, StringRef::from(name).to_raw()) } - .then_some(()) - .ok_or(Error::AttributeNotFound(name.into())) + pub fn remove_attribute(&mut self, context: &'c Context, name: &str) -> Result<(), Error> { + unsafe { + mlirOperationRemoveAttributeByName( + self.raw, + StringRef::from_str(context, name).to_raw(), + ) + } + .then_some(()) + .ok_or(Error::AttributeNotFound(name.into())) } /// Gets the next operation in the same block. @@ -426,7 +431,7 @@ mod tests { fn new() { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - OperationBuilder::new("foo", Location::unknown(&context)).build(); + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(); } #[test] @@ -435,7 +440,7 @@ mod tests { context.set_allow_unregistered_dialects(true); assert_eq!( - OperationBuilder::new("foo", Location::unknown(&context),) + OperationBuilder::new(&context, "foo", Location::unknown(&context),) .build() .name(), Identifier::new(&context, "foo") @@ -447,8 +452,9 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); let block = Block::new(&[]); - let operation = block - .append_operation(OperationBuilder::new("foo", Location::unknown(&context)).build()); + let operation = block.append_operation( + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(), + ); assert_eq!(operation.block().as_deref(), Some(&block)); } @@ -458,7 +464,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); assert_eq!( - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .build() .block(), None @@ -470,7 +476,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); assert_eq!( - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .build() .result(0) .unwrap_err(), @@ -487,7 +493,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); assert_eq!( - OperationBuilder::new("foo", Location::unknown(&context),) + OperationBuilder::new(&context, "foo", Location::unknown(&context),) .build() .region(0), Err(Error::PositionOutOfBounds { @@ -509,7 +515,7 @@ mod tests { let argument: Value = block.argument(0).unwrap().into(); let operands = vec![argument, argument, argument]; - let operation = OperationBuilder::new("foo", Location::unknown(&context)) + let operation = OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_operands(&operands) .build(); @@ -524,7 +530,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - let operation = OperationBuilder::new("foo", Location::unknown(&context)) + let operation = OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_regions(vec![Region::new()]) .build(); @@ -539,22 +545,26 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - let mut operation = OperationBuilder::new("foo", Location::unknown(&context)) + let mut operation = OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_attributes(&[( Identifier::new(&context, "foo"), StringAttribute::new(&context, "bar").into(), )]) .build(); - assert!(operation.has_attribute("foo")); + assert!(operation.has_attribute(&context, "foo")); assert_eq!( - operation.attribute("foo").map(|a| a.to_string()), + operation.attribute(&context, "foo").map(|a| a.to_string()), Ok("\"bar\"".into()) ); - assert!(operation.remove_attribute("foo").is_ok()); - assert!(operation.remove_attribute("foo").is_err()); - operation.set_attribute("foo", &StringAttribute::new(&context, "foo").into()); + assert!(operation.remove_attribute(&context, "foo").is_ok()); + assert!(operation.remove_attribute(&context, "foo").is_err()); + operation.set_attribute( + &context, + "foo", + &StringAttribute::new(&context, "foo").into(), + ); assert_eq!( - operation.attribute("foo").map(|a| a.to_string()), + operation.attribute(&context, "foo").map(|a| a.to_string()), Ok("\"foo\"".into()) ); assert_eq!( @@ -570,7 +580,7 @@ mod tests { fn clone() { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - let operation = OperationBuilder::new("foo", Location::unknown(&context)).build(); + let operation = OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(); let _ = operation.clone(); } @@ -581,7 +591,7 @@ mod tests { context.set_allow_unregistered_dialects(true); assert_eq!( - OperationBuilder::new("foo", Location::unknown(&context),) + OperationBuilder::new(&context, "foo", Location::unknown(&context),) .build() .to_string(), "\"foo\"() : () -> ()\n" @@ -596,7 +606,7 @@ mod tests { assert_eq!( format!( "{:?}", - OperationBuilder::new("foo", Location::unknown(&context)).build() + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build() ), "Operation(\n\"foo\"() : () -> ()\n)" ); @@ -608,7 +618,7 @@ mod tests { context.set_allow_unregistered_dialects(true); assert_eq!( - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .build() .to_string_with_flags( OperationPrintingFlags::new() diff --git a/melior/src/ir/operation/builder.rs b/melior/src/ir/operation/builder.rs index 954673e8c9..d45f6bd6ae 100644 --- a/melior/src/ir/operation/builder.rs +++ b/melior/src/ir/operation/builder.rs @@ -20,10 +20,13 @@ pub struct OperationBuilder<'c> { impl<'c> OperationBuilder<'c> { /// Creates an operation builder. - pub fn new(name: &str, location: Location<'c>) -> Self { + pub fn new(context: &'c Context, name: &str, location: Location<'c>) -> Self { Self { raw: unsafe { - mlirOperationStateGet(StringRef::from(name).to_raw(), location.to_raw()) + mlirOperationStateGet( + StringRef::from_str(context, name).to_raw(), + location.to_raw(), + ) }, _context: Default::default(), } @@ -134,7 +137,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - OperationBuilder::new("foo", Location::unknown(&context)).build(); + OperationBuilder::new(&context, "foo", Location::unknown(&context)).build(); } #[test] @@ -147,7 +150,7 @@ mod tests { let block = Block::new(&[(r#type, location)]); let argument = block.argument(0).unwrap().into(); - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_operands(&[argument]) .build(); } @@ -157,7 +160,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_results(&[Type::parse(&context, "i1").unwrap()]) .build(); } @@ -167,7 +170,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_regions(vec![Region::new()]) .build(); } @@ -177,7 +180,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_successors(&[&Block::new(&[])]) .build(); } @@ -187,7 +190,7 @@ mod tests { let context = create_test_context(); context.set_allow_unregistered_dialects(true); - OperationBuilder::new("foo", Location::unknown(&context)) + OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_attributes(&[( Identifier::new(&context, "foo"), Attribute::parse(&context, "unit").unwrap(), @@ -206,7 +209,7 @@ mod tests { let argument = block.argument(0).unwrap().into(); assert_eq!( - OperationBuilder::new("arith.addi", location) + OperationBuilder::new(&context, "arith.addi", location) .add_operands(&[argument, argument]) .enable_result_type_inference() .build() diff --git a/melior/src/ir/operation/result.rs b/melior/src/ir/operation/result.rs index f0067f23a3..cdcf4db63a 100644 --- a/melior/src/ir/operation/result.rs +++ b/melior/src/ir/operation/result.rs @@ -71,7 +71,7 @@ mod tests { context.set_allow_unregistered_dialects(true); let r#type = Type::parse(&context, "index").unwrap(); - let operation = OperationBuilder::new("foo", Location::unknown(&context)) + let operation = OperationBuilder::new(&context, "foo", Location::unknown(&context)) .add_results(&[r#type]) .build(); diff --git a/melior/src/ir/type.rs b/melior/src/ir/type.rs index 0a3e882136..479cd0881a 100644 --- a/melior/src/ir/type.rs +++ b/melior/src/ir/type.rs @@ -43,7 +43,7 @@ impl<'c> Type<'c> { unsafe { Self::from_option_raw(mlirTypeParseGet( context.to_raw(), - StringRef::from(source).to_raw(), + StringRef::from_str(context, source).to_raw(), )) } } diff --git a/melior/src/ir/value.rs b/melior/src/ir/value.rs index b97451fde3..073e987747 100644 --- a/melior/src/ir/value.rs +++ b/melior/src/ir/value.rs @@ -90,7 +90,7 @@ mod tests { let location = Location::unknown(&context); let index_type = Type::index(&context); - let operation = OperationBuilder::new("arith.constant", location) + let operation = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), @@ -107,7 +107,7 @@ mod tests { let location = Location::unknown(&context); let r#type = Type::index(&context); - let operation = OperationBuilder::new("arith.constant", location) + let operation = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[r#type]) .add_attributes(&[( Identifier::new(&context, "value"), @@ -133,7 +133,7 @@ mod tests { let location = Location::unknown(&context); let index_type = Type::index(&context); - let value = OperationBuilder::new("arith.constant", location) + let value = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), @@ -150,7 +150,7 @@ mod tests { let location = Location::unknown(&context); let index_type = Type::index(&context); - let operation = OperationBuilder::new("arith.constant", location) + let operation = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), @@ -169,7 +169,7 @@ mod tests { let index_type = Type::index(&context); let operation = || { - OperationBuilder::new("arith.constant", location) + OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), @@ -192,7 +192,7 @@ mod tests { let location = Location::unknown(&context); let index_type = Type::index(&context); - let operation = OperationBuilder::new("arith.constant", location) + let operation = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), @@ -213,7 +213,7 @@ mod tests { let location = Location::unknown(&context); let index_type = Type::index(&context); - let operation = OperationBuilder::new("arith.constant", location) + let operation = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), @@ -234,7 +234,7 @@ mod tests { let location = Location::unknown(&context); let index_type = Type::index(&context); - let operation = OperationBuilder::new("arith.constant", location) + let operation = OperationBuilder::new(&context, "arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( Identifier::new(&context, "value"), diff --git a/melior/src/lib.rs b/melior/src/lib.rs index eda303821f..80f6b46a4b 100644 --- a/melior/src/lib.rs +++ b/melior/src/lib.rs @@ -74,12 +74,17 @@ mod tests { let block = Block::new(&[(integer_type, location), (integer_type, location)]); let sum = block.append_operation(arith::addi( + &context, block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )); - block.append_operation(func::r#return(&[sum.result(0).unwrap().into()], location)); + block.append_operation(func::r#return( + &context, + &[sum.result(0).unwrap().into()], + location, + )); let region = Region::new(); region.append_block(block); @@ -124,7 +129,7 @@ mod tests { )); let dim = function_block.append_operation( - OperationBuilder::new("memref.dim", location) + OperationBuilder::new(&context, "memref.dim", location) .add_operands(&[ function_block.argument(0).unwrap().into(), zero.result(0).unwrap().into(), @@ -145,7 +150,7 @@ mod tests { let f32_type = Type::float32(&context); let lhs = loop_block.append_operation( - OperationBuilder::new("memref.load", location) + OperationBuilder::new(&context, "memref.load", location) .add_operands(&[ function_block.argument(0).unwrap().into(), loop_block.argument(0).unwrap().into(), @@ -155,7 +160,7 @@ mod tests { ); let rhs = loop_block.append_operation( - OperationBuilder::new("memref.load", location) + OperationBuilder::new(&context, "memref.load", location) .add_operands(&[ function_block.argument(1).unwrap().into(), loop_block.argument(0).unwrap().into(), @@ -165,13 +170,14 @@ mod tests { ); let add = loop_block.append_operation(arith::addf( + &context, lhs.result(0).unwrap().into(), rhs.result(0).unwrap().into(), location, )); loop_block.append_operation( - OperationBuilder::new("memref.store", location) + OperationBuilder::new(&context, "memref.store", location) .add_operands(&[ add.result(0).unwrap().into(), function_block.argument(0).unwrap().into(), @@ -180,10 +186,11 @@ mod tests { .build(), ); - loop_block.append_operation(scf::r#yield(&[], location)); + loop_block.append_operation(scf::r#yield(&context, &[], location)); } function_block.append_operation(scf::r#for( + &context, zero.result(0).unwrap().into(), dim.result(0).unwrap().into(), one.result(0).unwrap().into(), @@ -195,7 +202,7 @@ mod tests { location, )); - function_block.append_operation(func::r#return(&[], location)); + function_block.append_operation(func::r#return(&context, &[], location)); let function_region = Region::new(); function_region.append_block(function_block); @@ -235,7 +242,7 @@ mod tests { rhs: Value<'c, '_>, ) -> Value<'c, 'a> { block - .append_operation(arith::addi(lhs, rhs, Location::unknown(context))) + .append_operation(arith::addi(context, lhs, rhs, Location::unknown(context))) .result(0) .unwrap() .into() @@ -251,6 +258,7 @@ mod tests { let block = Block::new(&[(integer_type, location), (integer_type, location)]); block.append_operation(func::r#return( + &context, &[compile_add( &context, &block, diff --git a/melior/src/pass/external.rs b/melior/src/pass/external.rs index 0fa3df2eac..482b425cbd 100644 --- a/melior/src/pass/external.rs +++ b/melior/src/pass/external.rs @@ -4,7 +4,7 @@ use super::Pass; use crate::{ dialect::DialectHandle, ir::{r#type::TypeId, OperationRef}, - ContextRef, StringRef, + Context, ContextRef, StringRef, }; use mlir_sys::{ mlirCreateExternalPass, mlirExternalPassSignalFailure, MlirContext, MlirExternalPass, @@ -140,6 +140,7 @@ impl<'c, F: FnMut(OperationRef<'c, '_>, ExternalPass<'_>) + Clone> RunExternalPa /// /// ``` /// use melior::{ +/// Context, /// ir::{r#type::TypeId, OperationRef}, /// pass::{create_external, ExternalPass}, /// }; @@ -149,7 +150,10 @@ impl<'c, F: FnMut(OperationRef<'c, '_>, ExternalPass<'_>) + Clone> RunExternalPa /// /// static EXAMPLE_PASS: PassId = PassId; /// +/// let context = Context::new(); +/// /// create_external( +/// &context, /// |operation: OperationRef, _pass: ExternalPass| { /// operation.dump(); /// }, @@ -161,7 +165,9 @@ impl<'c, F: FnMut(OperationRef<'c, '_>, ExternalPass<'_>) + Clone> RunExternalPa /// &[], /// ); /// ``` +#[allow(clippy::too_many_arguments)] pub fn create_external<'c, T: RunExternalPass<'c>>( + context: &'c Context, pass: T, pass_id: TypeId, name: &str, @@ -173,10 +179,10 @@ pub fn create_external<'c, T: RunExternalPass<'c>>( unsafe { Pass::from_raw(mlirCreateExternalPass( pass_id.to_raw(), - StringRef::from(name).to_raw(), - StringRef::from(argument).to_raw(), - StringRef::from(description).to_raw(), - StringRef::from(op_name).to_raw(), + StringRef::from_str(context, name).to_raw(), + StringRef::from_str(context, argument).to_raw(), + StringRef::from_str(context, description).to_raw(), + StringRef::from_str(context, op_name).to_raw(), dependent_dialects.len() as isize, dependent_dialects.as_ptr() as _, MlirExternalPassCallbacks { @@ -219,7 +225,7 @@ mod tests { TypeAttribute::new(FunctionType::new(context, &[], &[]).into()), { let block = Block::new(&[]); - block.append_operation(func::r#return(&[], location)); + block.append_operation(func::r#return(context, &[], location)); let region = Region::new(); region.append_block(block); @@ -236,11 +242,12 @@ mod tests { static TEST_PASS: PassId = PassId; #[derive(Clone, Debug)] - struct TestPass { + struct TestPass<'c> { + context: &'c Context, value: i32, } - impl<'c> RunExternalPass<'c> for TestPass { + impl<'c> RunExternalPass<'c> for TestPass<'c> { fn construct(&mut self) { assert_eq!(self.value, 10); } @@ -267,14 +274,15 @@ mod tests { .first_operation() .expect("body has a function") .name() - == Identifier::new(&operation.context(), "func.func") + == Identifier::new(self.context, "func.func") ); } } - impl TestPass { - fn create(self) -> Pass { + impl<'c> TestPass<'c> { + fn into_pass(self, context: &Context) -> Pass { create_external( + context, self, TypeId::create(&TEST_PASS), "test pass", @@ -291,8 +299,11 @@ mod tests { let mut module = create_module(&context); let pass_manager = PassManager::new(&context); - let test_pass = TestPass { value: 10 }; - pass_manager.add_pass(test_pass.create()); + let test_pass = TestPass { + context: &context, + value: 10, + }; + pass_manager.add_pass(test_pass.into_pass(&context)); pass_manager.run(&mut module).unwrap(); } @@ -306,6 +317,7 @@ mod tests { let pass_manager = PassManager::new(&context); pass_manager.add_pass(create_external( + &context, |operation: OperationRef, pass: ExternalPass<'_>| { assert!(operation.verify()); assert!( @@ -317,7 +329,7 @@ mod tests { .first_operation() .expect("body has a function") .name() - == Identifier::new(&operation.context(), "func.func") + == Identifier::new(&context, "func.func") ); pass.signal_failure(); }, diff --git a/melior/src/pass/manager.rs b/melior/src/pass/manager.rs index fafbb409f1..4f99588852 100644 --- a/melior/src/pass/manager.rs +++ b/melior/src/pass/manager.rs @@ -28,11 +28,11 @@ impl<'c> PassManager<'c> { /// Gets an operation pass manager for nested operations corresponding to a /// given name. - pub fn nested_under(&self, name: &str) -> OperationPassManager { + pub fn nested_under(&self, context: &'c Context, name: &str) -> OperationPassManager { unsafe { OperationPassManager::from_raw(mlirPassManagerGetNestedUnder( self.raw, - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), )) } } @@ -178,15 +178,15 @@ mod tests { let manager = PassManager::new(&context); manager - .nested_under("func.func") + .nested_under(&context, "func.func") .add_pass(pass::transform::create_print_op_stats()); assert_eq!(manager.run(&mut module), Ok(())); let manager = PassManager::new(&context); manager - .nested_under("builtin.module") - .nested_under("func.func") + .nested_under(&context, "builtin.module") + .nested_under(&context, "func.func") .add_pass(pass::transform::create_print_op_stats()); assert_eq!(manager.run(&mut module), Ok(())); @@ -196,7 +196,7 @@ mod tests { fn print_pass_pipeline() { let context = create_test_context(); let manager = PassManager::new(&context); - let function_manager = manager.nested_under("func.func"); + let function_manager = manager.nested_under(&context, "func.func"); function_manager.add_pass(pass::transform::create_print_op_stats()); @@ -216,6 +216,7 @@ mod tests { let manager = PassManager::new(&context); insta::assert_display_snapshot!(parse_pass_pipeline( + &context, manager.as_operation_pass_manager(), "builtin.module(func.func(print-op-stats{json=false}),\ func.func(print-op-stats{json=false}))" @@ -226,6 +227,7 @@ mod tests { assert_eq!( parse_pass_pipeline( + &context, manager.as_operation_pass_manager(), "builtin.module(func.func(print-op-stats{json=false}),\ func.func(print-op-stats{json=false}))" diff --git a/melior/src/pass/operation_manager.rs b/melior/src/pass/operation_manager.rs index b27b0240db..e4ec956bdb 100644 --- a/melior/src/pass/operation_manager.rs +++ b/melior/src/pass/operation_manager.rs @@ -1,5 +1,5 @@ use super::PassManager; -use crate::{pass::Pass, string_ref::StringRef}; +use crate::{pass::Pass, string_ref::StringRef, Context}; use mlir_sys::{ mlirOpPassManagerAddOwnedPass, mlirOpPassManagerGetNestedUnder, mlirPrintPassPipeline, MlirOpPassManager, MlirStringRef, @@ -12,19 +12,19 @@ use std::{ /// An operation pass manager. #[derive(Clone, Copy, Debug)] -pub struct OperationPassManager<'a> { +pub struct OperationPassManager<'c, 'a> { raw: MlirOpPassManager, - _parent: PhantomData<&'a PassManager<'a>>, + _parent: PhantomData<&'a PassManager<'c>>, } -impl<'a> OperationPassManager<'a> { +impl<'c, 'a> OperationPassManager<'c, 'a> { /// Gets an operation pass manager for nested operations corresponding to a /// given name. - pub fn nested_under(&self, name: &str) -> Self { + pub fn nested_under(&self, context: &'c Context, name: &str) -> Self { unsafe { Self::from_raw(mlirOpPassManagerGetNestedUnder( self.raw, - StringRef::from(name).to_raw(), + StringRef::from_str(context, name).to_raw(), )) } } @@ -52,7 +52,7 @@ impl<'a> OperationPassManager<'a> { } } -impl<'a> Display for OperationPassManager<'a> { +impl<'c, 'a> Display for OperationPassManager<'c, 'a> { fn fmt(&self, formatter: &mut Formatter) -> fmt::Result { let mut data = (formatter, Ok(())); diff --git a/melior/src/string_ref.rs b/melior/src/string_ref.rs index 9ee6859fde..5ef68a1216 100644 --- a/melior/src/string_ref.rs +++ b/melior/src/string_ref.rs @@ -1,6 +1,5 @@ -use dashmap::DashMap; +use crate::Context; use mlir_sys::{mlirStringRefCreateFromCString, mlirStringRefEqual, MlirStringRef}; -use once_cell::sync::Lazy; use std::{ ffi::CString, marker::PhantomData, @@ -8,24 +7,31 @@ use std::{ str::{self, Utf8Error}, }; -// We need to pass null-terminated strings to functions in the MLIR API although -// Rust's strings are not. -static STRING_CACHE: Lazy> = Lazy::new(Default::default); - /// A string reference. // https://mlir.llvm.org/docs/CAPI/#stringref // // TODO The documentation says string refs do not have to be null-terminated. // But it looks like some functions do not handle strings not null-terminated? #[derive(Clone, Copy, Debug)] -pub struct StringRef<'a> { +pub struct StringRef<'c> { raw: MlirStringRef, - _parent: PhantomData<&'a ()>, + _parent: PhantomData<&'c Context>, } -impl<'a> StringRef<'a> { +impl<'c> StringRef<'c> { + pub fn from_str(context: &'c Context, string: &str) -> Self { + let string = context + .string_cache() + .entry(CString::new(string).unwrap()) + .or_default() + .key() + .as_ptr(); + + unsafe { Self::from_raw(mlirStringRefCreateFromCString(string)) } + } + /// Converts a string reference into a `str`. - pub fn as_str(&self) -> Result<&'a str, Utf8Error> { + pub fn as_str(&self) -> Result<&'c str, Utf8Error> { unsafe { let bytes = slice::from_raw_parts(self.raw.data as *mut u8, self.raw.length); @@ -63,32 +69,37 @@ impl<'a> PartialEq for StringRef<'a> { impl<'a> Eq for StringRef<'a> {} -impl From<&str> for StringRef<'static> { - fn from(string: &str) -> Self { - let entry = STRING_CACHE - .entry(CString::new(string).unwrap()) - .or_insert_with(Default::default); - - unsafe { Self::from_raw(mlirStringRefCreateFromCString(entry.key().as_ptr())) } - } -} - #[cfg(test)] mod tests { use super::*; #[test] fn equal() { - assert_eq!(StringRef::from("foo"), StringRef::from("foo")); + let context = Context::new(); + + assert_eq!( + StringRef::from_str(&context, "foo"), + StringRef::from_str(&context, "foo") + ); } #[test] fn equal_str() { - assert_eq!(StringRef::from("foo").as_str().unwrap(), "foo"); + let context = Context::new(); + + assert_eq!( + StringRef::from_str(&context, "foo").as_str().unwrap(), + "foo" + ); } #[test] fn not_equal() { - assert_ne!(StringRef::from("foo"), StringRef::from("bar")); + let context = Context::new(); + + assert_ne!( + StringRef::from_str(&context, "foo"), + StringRef::from_str(&context, "bar") + ); } } diff --git a/melior/src/utility.rs b/melior/src/utility.rs index cb632b6334..a93ecc8bc4 100644 --- a/melior/src/utility.rs +++ b/melior/src/utility.rs @@ -33,13 +33,17 @@ pub fn register_all_passes() { } /// Parses a pass pipeline. -pub fn parse_pass_pipeline(manager: pass::OperationPassManager, source: &str) -> Result<(), Error> { +pub fn parse_pass_pipeline( + context: &Context, + manager: pass::OperationPassManager, + source: &str, +) -> Result<(), Error> { let mut error_message = None; let result = LogicalResult::from_raw(unsafe { mlirParsePassPipeline( manager.to_raw(), - StringRef::from(source).to_raw(), + StringRef::from_str(context, source).to_raw(), Some(handle_parse_error), &mut error_message as *mut _ as *mut _, )