From 9f10c9ed69f3d6eecd4f6f4ed82e2ff028783558 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Tue, 5 Dec 2023 23:33:05 +0800 Subject: [PATCH] Refactor dialect macros (#378) --- macro/src/dialect.rs | 10 ++- macro/src/dialect/dialect.rs | 104 ----------------------- macro/src/dialect/error.rs | 11 ++- macro/src/dialect/input.rs | 48 ++--------- macro/src/dialect/input/input_field.rs | 34 ++++++++ macro/src/dialect/operation.rs | 65 +++++++------- macro/src/dialect/operation/accessors.rs | 9 +- macro/src/dialect/operation/builder.rs | 2 +- macro/src/dialect/types.rs | 34 ++++---- 9 files changed, 116 insertions(+), 201 deletions(-) delete mode 100644 macro/src/dialect/dialect.rs create mode 100644 macro/src/dialect/input/input_field.rs diff --git a/macro/src/dialect.rs b/macro/src/dialect.rs index 67833e57fd..9a5563e470 100644 --- a/macro/src/dialect.rs +++ b/macro/src/dialect.rs @@ -30,7 +30,12 @@ pub fn generate_dialect(input: DialectInput) -> Result, _>>()? .into_iter() .filter(|operation| operation.dialect_name() == dialect_name) - .collect::>(); + .map(|operation| operation.to_tokens()) + .collect::, _>>()?; let doc = format!( "`{name}` dialect.\n\n{}", diff --git a/macro/src/dialect/dialect.rs b/macro/src/dialect/dialect.rs deleted file mode 100644 index 67833e57fd..0000000000 --- a/macro/src/dialect/dialect.rs +++ /dev/null @@ -1,104 +0,0 @@ -mod error; -mod input; -mod operation; -mod types; -mod utility; - -use self::{ - error::Error, - utility::{sanitize_documentation, sanitize_snake_case_name}, -}; -pub use input::DialectInput; -use operation::Operation; -use proc_macro::TokenStream; -use proc_macro2::Span; -use quote::quote; -use std::{env, fmt::Display, path::Path, process::Command, str}; -use tblgen::{record::Record, record_keeper::RecordKeeper, TableGenParser}; - -const LLVM_MAJOR_VERSION: usize = 17; - -pub fn generate_dialect(input: DialectInput) -> Result> { - let mut parser = TableGenParser::new(); - - if let Some(source) = input.table_gen() { - parser = parser.add_source(source).map_err(create_syn_error)?; - } - - if let Some(file) = input.td_file() { - parser = parser.add_source_file(file).map_err(create_syn_error)?; - } - - // spell-checker: disable-next-line - for path in input.includes().chain([&*llvm_config("--includedir")?]) { - parser = parser.add_include_path(path); - } - - let keeper = parser.parse().map_err(Error::Parse)?; - - let dialect = generate_dialect_module( - input.name(), - keeper - .all_derived_definitions("Dialect") - .find(|definition| definition.str_value("name") == Ok(input.name())) - .ok_or_else(|| create_syn_error("dialect not found"))?, - &keeper, - ) - .map_err(|error| error.add_source_info(keeper.source_info()))?; - - Ok(quote! { #dialect }.into()) -} - -fn generate_dialect_module( - name: &str, - dialect: Record, - record_keeper: &RecordKeeper, -) -> Result { - let dialect_name = dialect.name()?; - let operations = record_keeper - .all_derived_definitions("Op") - .map(Operation::new) - .collect::, _>>()? - .into_iter() - .filter(|operation| operation.dialect_name() == dialect_name) - .collect::>(); - - let doc = format!( - "`{name}` dialect.\n\n{}", - sanitize_documentation(dialect.str_value("description").unwrap_or(""),)? - ); - let name = sanitize_snake_case_name(name)?; - - Ok(quote! { - #[doc = #doc] - pub mod #name { - #(#operations)* - } - }) -} - -fn llvm_config(argument: &str) -> Result> { - let prefix = env::var(format!("MLIR_SYS_{}0_PREFIX", LLVM_MAJOR_VERSION)) - .map(|path| Path::new(&path).join("bin")) - .unwrap_or_default(); - let call = format!( - "{} --link-static {}", - prefix.join("llvm-config").display(), - argument - ); - - Ok(str::from_utf8( - &if cfg!(target_os = "windows") { - Command::new("cmd").args(["/C", &call]).output()? - } else { - Command::new("sh").arg("-c").arg(&call).output()? - } - .stdout, - )? - .trim() - .to_string()) -} - -fn create_syn_error(error: impl Display) -> syn::Error { - syn::Error::new(Span::call_site(), format!("{}", error)) -} diff --git a/macro/src/dialect/error.rs b/macro/src/dialect/error.rs index 2f7b67f93d..4c4a13b367 100644 --- a/macro/src/dialect/error.rs +++ b/macro/src/dialect/error.rs @@ -83,16 +83,19 @@ impl From for Error { pub enum OdsError { ExpectedSuperClass(&'static str), InvalidTrait, + UnexpectedSuperClass(&'static str), } impl Display for OdsError { fn fmt(&self, formatter: &mut Formatter) -> fmt::Result { match self { - Self::ExpectedSuperClass(class) => write!( - formatter, - "expected this record to be a subclass of {class}", - ), + Self::ExpectedSuperClass(class) => { + write!(formatter, "record should be a sub-class of {class}",) + } Self::InvalidTrait => write!(formatter, "record is not a supported trait"), + Self::UnexpectedSuperClass(class) => { + write!(formatter, "record should not be a sub-class of {class}",) + } } } } diff --git a/macro/src/dialect/input.rs b/macro/src/dialect/input.rs index 0304083fb0..bb11df5fb7 100644 --- a/macro/src/dialect/input.rs +++ b/macro/src/dialect/input.rs @@ -1,13 +1,14 @@ -use proc_macro2::Ident; -use quote::format_ident; +mod input_field; + +use self::input_field::InputField; use std::ops::Deref; -use syn::{bracketed, parse::Parse, punctuated::Punctuated, LitStr, Token}; +use syn::{parse::Parse, punctuated::Punctuated, Token}; pub struct DialectInput { name: String, table_gen: Option, td_file: Option, - includes: Vec, + include_directories: Vec, } impl DialectInput { @@ -23,8 +24,8 @@ impl DialectInput { self.td_file.as_deref() } - pub fn includes(&self) -> impl Iterator { - self.includes.iter().map(Deref::deref) + pub fn include_directories(&self) -> impl Iterator { + self.include_directories.iter().map(Deref::deref) } } @@ -40,7 +41,7 @@ impl Parse for DialectInput { InputField::Name(field) => name = Some(field.value()), InputField::TableGen(td) => table_gen = Some(td.value()), InputField::TdFile(file) => td_file = Some(file.value()), - InputField::Includes(field) => { + InputField::IncludeDirectories(field) => { includes = field.into_iter().map(|literal| literal.value()).collect() } } @@ -50,38 +51,7 @@ impl Parse for DialectInput { name: name.ok_or(input.error("dialect name required"))?, table_gen, td_file, - includes, + include_directories: includes, }) } } - -enum InputField { - Name(LitStr), - TableGen(LitStr), - TdFile(LitStr), - Includes(Punctuated), -} - -impl Parse for InputField { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let ident = input.parse::()?; - - input.parse::()?; - - if ident == format_ident!("name") { - Ok(Self::Name(input.parse()?)) - } else if ident == format_ident!("table_gen") { - Ok(Self::TableGen(input.parse()?)) - } else if ident == format_ident!("td_file") { - Ok(Self::TdFile(input.parse()?)) - } else if ident == format_ident!("include_dirs") { - let content; - bracketed!(content in input); - Ok(Self::Includes( - Punctuated::::parse_terminated(&content)?, - )) - } else { - Err(input.error(format!("invalid field {}", ident))) - } - } -} diff --git a/macro/src/dialect/input/input_field.rs b/macro/src/dialect/input/input_field.rs new file mode 100644 index 0000000000..61a161fca6 --- /dev/null +++ b/macro/src/dialect/input/input_field.rs @@ -0,0 +1,34 @@ +use proc_macro2::Ident; +use quote::format_ident; +use syn::{bracketed, parse::Parse, punctuated::Punctuated, LitStr, Token}; + +pub enum InputField { + Name(LitStr), + TableGen(LitStr), + TdFile(LitStr), + IncludeDirectories(Punctuated), +} + +impl Parse for InputField { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let ident = input.parse::()?; + + input.parse::()?; + + if ident == format_ident!("name") { + Ok(Self::Name(input.parse()?)) + } else if ident == format_ident!("table_gen") { + Ok(Self::TableGen(input.parse()?)) + } else if ident == format_ident!("td_file") { + Ok(Self::TdFile(input.parse()?)) + } else if ident == format_ident!("include_dirs") { + let content; + bracketed!(content in input); + Ok(Self::IncludeDirectories( + Punctuated::::parse_terminated(&content)?, + )) + } else { + Err(input.error(format!("invalid field {}", ident))) + } + } +} diff --git a/macro/src/dialect/operation.rs b/macro/src/dialect/operation.rs index 90e5005aa5..fdee4641ad 100644 --- a/macro/src/dialect/operation.rs +++ b/macro/src/dialect/operation.rs @@ -16,7 +16,7 @@ use crate::dialect::{ types::{AttributeConstraint, RegionConstraint, SuccessorConstraint, Trait, TypeConstraint}, }; use proc_macro2::TokenStream; -use quote::{format_ident, quote, ToTokens, TokenStreamExt}; +use quote::{format_ident, quote}; use tblgen::{error::WithLocation, record::Record}; #[derive(Clone, Debug)] @@ -44,7 +44,7 @@ impl<'a> Operation<'a> { let arguments = Self::dag_constraints(definition, "arguments")?; let regions = Self::collect_regions(definition)?; - let (results, unfixed_results_count) = Self::collect_results( + let (results, unfixed_result_count) = Self::collect_results( definition, has_trait("::mlir::OpTrait::SameVariadicResultSize"), has_trait("::mlir::OpTrait::AttrSizedResultSegments"), @@ -86,7 +86,7 @@ impl<'a> Operation<'a> { can_infer_type: traits.iter().any(|r#trait| { (r#trait.has_name("::mlir::OpTrait::FirstAttrDerivedResultType") || r#trait.has_name("::mlir::OpTrait::SameOperandsAndResultType")) - && unfixed_results_count == 0 + && unfixed_result_count == 0 || r#trait.has_name("::mlir::InferTypeOpInterface::Trait") && regions.is_empty() }), summary: { @@ -114,16 +114,17 @@ impl<'a> Operation<'a> { pub fn fields(&self) -> impl Iterator> + Clone { self.results .iter() - .chain(self.operands.iter()) - .chain(self.regions.iter()) - .chain(self.successors.iter()) - .chain(self.attributes.iter()) - .chain(self.derived_attributes.iter()) + .chain(&self.operands) + .chain(&self.regions) + .chain(&self.successors) + .chain(&self.attributes) + .chain(&self.derived_attributes) } fn collect_successors(definition: Record<'a>) -> Result, Error> { let successors_dag = definition.dag_value("successors")?; let len = successors_dag.num_args(); + successors_dag .args() .enumerate() @@ -144,6 +145,7 @@ impl<'a> Operation<'a> { fn collect_regions(definition: Record<'a>) -> Result, Error> { let regions_dag = definition.dag_value("regions")?; let len = regions_dag.num_args(); + regions_dag .args() .enumerate() @@ -296,53 +298,58 @@ impl<'a> Operation<'a> { .iter() .filter(|(_, definition)| definition.subclass_of("Attr")) .map(|(name, definition)| { - // TODO: Replace assert! with Result - assert!(!definition.subclass_of("DerivedAttr")); - - OperationField::new_attribute(name, AttributeConstraint::new(*definition)) + if definition.subclass_of("DerivedAttr") { + Err(OdsError::UnexpectedSuperClass("DerivedAttr") + .with_location(*definition) + .into()) + } else { + OperationField::new_attribute(name, AttributeConstraint::new(*definition)) + } }) .collect() } - fn collect_derived_attributes(def: Record<'a>) -> Result>, Error> { - def.values() + fn collect_derived_attributes( + definition: Record<'a>, + ) -> Result>, Error> { + definition + .values() .filter_map(|value| { let Ok(def) = Record::try_from(value) else { return None; }; def.subclass_of("Attr").then_some(def) }) - .map(|def| { - if def.subclass_of("DerivedAttr") { - OperationField::new_attribute(def.name()?, AttributeConstraint::new(def)) + .map(|definition| { + if definition.subclass_of("DerivedAttr") { + OperationField::new_attribute( + definition.name()?, + AttributeConstraint::new(definition), + ) } else { Err(OdsError::ExpectedSuperClass("DerivedAttr") - .with_location(def) + .with_location(definition) .into()) } }) .collect() } -} -impl<'a> ToTokens for Operation<'a> { - // TODO Compile values for proper error handling and remove `Result::expect()`. - fn to_tokens(&self, tokens: &mut TokenStream) { + pub fn to_tokens(&self) -> Result { let class_name = format_ident!("{}", &self.class_name); let name = &self.full_name; let accessors = self .fields() - .map(|field| field.accessors().expect("valid accessors")); - let builder = OperationBuilder::new(self).expect("valid builder generator"); - let builder_tokens = builder.builder().expect("valid builder"); + .map(|field| field.accessors()) + .collect::, _>>()?; + let builder = OperationBuilder::new(self)?; + let builder_tokens = builder.to_tokens()?; let builder_fn = builder.create_op_builder_fn(); - let default_constructor = builder - .create_default_constructor() - .expect("valid constructor"); + let default_constructor = builder.create_default_constructor()?; let summary = &self.summary; let description = &self.description; - tokens.append_all(quote! { + Ok(quote! { #[doc = #summary] #[doc = "\n\n"] #[doc = #description] diff --git a/macro/src/dialect/operation/accessors.rs b/macro/src/dialect/operation/accessors.rs index 7ebf4ecd0a..e2f1cd0f62 100644 --- a/macro/src/dialect/operation/accessors.rs +++ b/macro/src/dialect/operation/accessors.rs @@ -14,10 +14,9 @@ impl<'a> OperationField<'a> { sequence_info: SequenceInfo { index, len }, variadic_kind, } => { - let kind_str = kind.as_str(); - let kind_ident = format_ident!("{}", kind_str); - let plural = format_ident!("{}s", kind_str); - let count = format_ident!("{}_count", kind_str); + let kind_ident = format_ident!("{}", kind.as_str()); + let plural = format_ident!("{}s", kind.as_str()); + let count = format_ident!("{}_count", kind.as_str()); let error_variant = match kind { ElementKind::Operand => quote!(OperandNotFound), ElementKind::Result => quote!(ResultNotFound), @@ -82,7 +81,7 @@ impl<'a> OperationField<'a> { quote! { #compute_start_length #get_elements } } VariadicKind::AttributeSized => { - let attribute_name = format!("{}_segment_sizes", kind_str); + let attribute_name = format!("{}_segment_sizes", kind.as_str()); let compute_start_length = quote! { let attribute = ::melior::ir::attribute::DenseI32ArrayAttribute::<'c>::try_from( diff --git a/macro/src/dialect/operation/builder.rs b/macro/src/dialect/operation/builder.rs index a7b9efc7f8..5f5989ccfb 100644 --- a/macro/src/dialect/operation/builder.rs +++ b/macro/src/dialect/operation/builder.rs @@ -105,7 +105,7 @@ impl<'o> OperationBuilder<'o> { }) } - pub fn builder(&self) -> Result { + pub fn to_tokens(&self) -> Result { let field_names = self .type_state .field_names() diff --git a/macro/src/dialect/types.rs b/macro/src/dialect/types.rs index 6b4db80c62..0549503a17 100644 --- a/macro/src/dialect/types.rs +++ b/macro/src/dialect/types.rs @@ -173,7 +173,7 @@ enum TraitKind { #[allow(unused)] structural: bool, }, - Pred {}, + Predicate, Internal { name: String, }, @@ -188,25 +188,25 @@ pub struct Trait { } impl Trait { - pub fn new(def: Record) -> Result { + pub fn new(definition: Record) -> Result { Ok(Self { - kind: if def.subclass_of("PredTrait") { - TraitKind::Pred {} - } else if def.subclass_of("InterfaceTrait") { + kind: if definition.subclass_of("PredTrait") { + TraitKind::Predicate + } else if definition.subclass_of("InterfaceTrait") { TraitKind::Interface { - name: Self::name(def)?, + name: Self::name(definition)?, } - } else if def.subclass_of("NativeTrait") { + } else if definition.subclass_of("NativeTrait") { TraitKind::Native { - name: Self::name(def)?, - structural: def.subclass_of("StructuralOpTrait"), + name: Self::name(definition)?, + structural: definition.subclass_of("StructuralOpTrait"), } - } else if def.subclass_of("GenInternalTrait") { + } else if definition.subclass_of("GenInternalTrait") { TraitKind::Internal { - name: def.string_value("trait")?, + name: definition.string_value("trait")?, } } else { - return Err(OdsError::InvalidTrait.with_location(def).into()); + return Err(OdsError::InvalidTrait.with_location(definition).into()); }, }) } @@ -215,14 +215,14 @@ impl Trait { match &self.kind { TraitKind::Native { name, .. } | TraitKind::Internal { name } - | TraitKind::Interface { name } => expected_name == name, - TraitKind::Pred {} => false, + | TraitKind::Interface { name } => name == expected_name, + TraitKind::Predicate => false, } } - fn name(def: Record) -> Result { - let r#trait = def.string_value("trait")?; - let namespace = def.string_value("cppNamespace")?; + fn name(definition: Record) -> Result { + let r#trait = definition.string_value("trait")?; + let namespace = definition.string_value("cppNamespace")?; Ok(if namespace.is_empty() { r#trait