Skip to content

Refactor dialect macros #375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions macro/src/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,28 @@ use tblgen::{record::Record, record_keeper::RecordKeeper, TableGenParser};
const LLVM_MAJOR_VERSION: usize = 17;

pub fn generate_dialect(input: DialectInput) -> Result<TokenStream, Box<dyn std::error::Error>> {
let mut td_parser = TableGenParser::new();
let mut parser = TableGenParser::new();

if let Some(source) = input.tablegen() {
td_parser = td_parser.add_source(source).map_err(create_syn_error)?;
if let Some(source) = input.table_gen() {
parser = parser.add_source(source).map_err(create_syn_error)?;
}

if let Some(file) = input.td_file() {
td_parser = td_parser.add_source_file(file).map_err(create_syn_error)?;
parser = parser.add_source_file(file).map_err(create_syn_error)?;
}

// spell-checker: disable-next-line
for include in input.includes().chain([&*llvm_config("--includedir")?]) {
td_parser = td_parser.add_include_path(include);
for path in input.includes().chain([&*llvm_config("--includedir")?]) {
parser = parser.add_include_path(path);
}

let keeper = td_parser.parse().map_err(Error::Parse)?;
let keeper = parser.parse().map_err(Error::Parse)?;

let dialect = dialect_module(
let dialect = generate_dialect_module(
input.name(),
keeper
.all_derived_definitions("Dialect")
.find(|def| def.str_value("name") == Ok(input.name()))
.find(|definition| definition.str_value("name") == Ok(input.name()))
.ok_or_else(|| create_syn_error("dialect not found"))?,
&keeper,
)
Expand All @@ -49,7 +49,7 @@ pub fn generate_dialect(input: DialectInput) -> Result<TokenStream, Box<dyn std:
Ok(quote! { #dialect }.into())
}

fn dialect_module(
fn generate_dialect_module(
name: &str,
dialect: Record,
record_keeper: &RecordKeeper,
Expand Down
14 changes: 7 additions & 7 deletions macro/src/dialect/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use syn::{bracketed, parse::Parse, punctuated::Punctuated, LitStr, Token};

pub struct DialectInput {
name: String,
tablegen: Option<String>,
table_gen: Option<String>,
td_file: Option<String>,
includes: Vec<String>,
}
Expand All @@ -15,8 +15,8 @@ impl DialectInput {
&self.name
}

pub fn tablegen(&self) -> Option<&str> {
self.tablegen.as_deref()
pub fn table_gen(&self) -> Option<&str> {
self.table_gen.as_deref()
}

pub fn td_file(&self) -> Option<&str> {
Expand All @@ -31,14 +31,14 @@ impl DialectInput {
impl Parse for DialectInput {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut name = None;
let mut tablegen = None;
let mut table_gen = None;
let mut td_file = None;
let mut includes = vec![];

for item in Punctuated::<InputField, Token![,]>::parse_terminated(input)? {
match item {
InputField::Name(field) => name = Some(field.value()),
InputField::TableGen(td) => tablegen = Some(td.value()),
InputField::TableGen(td) => table_gen = Some(td.value()),
InputField::TdFile(file) => td_file = Some(file.value()),
InputField::Includes(field) => {
includes = field.into_iter().map(|literal| literal.value()).collect()
Expand All @@ -48,7 +48,7 @@ impl Parse for DialectInput {

Ok(Self {
name: name.ok_or(input.error("dialect name required"))?,
tablegen,
table_gen,
td_file,
includes,
})
Expand All @@ -70,7 +70,7 @@ impl Parse for InputField {

if ident == format_ident!("name") {
Ok(Self::Name(input.parse()?))
} else if ident == format_ident!("tablegen") {
} else if ident == format_ident!("table_gen") {
Ok(Self::TableGen(input.parse()?))
} else if ident == format_ident!("td_file") {
Ok(Self::TdFile(input.parse()?))
Expand Down
5 changes: 3 additions & 2 deletions macro/src/dialect/utility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ pub fn sanitize_snake_case_name(name: &str) -> Result<Ident, Error> {
}

fn sanitize_name(name: &str) -> Result<Ident, Error> {
// Replace any "." with "_"
// Replace any "." with "_".
let mut name = name.replace('.', "_");

// Add "_" suffix to avoid conflicts with existing methods
// Add "_" suffix to avoid conflicts with existing methods.
if RESERVED_NAMES.contains(&name.as_str())
|| name
.chars()
Expand Down Expand Up @@ -44,6 +44,7 @@ pub fn sanitize_documentation(string: &str) -> Result<String, Error> {
};

if block.info.is_empty() {
// Mark them not in Rust to prevent documentation tests.
block.info = "text".into();
}
}
Expand Down
2 changes: 1 addition & 1 deletion macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use syn::parse_macro_input;
/// ```rust
/// melior::dialect! {
/// name: "func",
/// tablegen: r#"include "mlir/Dialect/Func/IR/FuncOps.td""#
/// table_gen: r#"include "mlir/Dialect/Func/IR/FuncOps.td""#
/// }
/// ```
#[proc_macro]
Expand Down
50 changes: 25 additions & 25 deletions melior/src/dialect/ods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,104 +9,104 @@ pub mod __private {

melior_macro::dialect! {
name: "affine",
tablegen: r#"include "mlir/Dialect/Affine/IR/AffineOps.td""#
table_gen: r#"include "mlir/Dialect/Affine/IR/AffineOps.td""#
}
melior_macro::dialect! {
name: "amdgpu",
tablegen: r#"include "mlir/Dialect/AMDGPU/IR/AMDGPU.td""#
table_gen: r#"include "mlir/Dialect/AMDGPU/IR/AMDGPU.td""#
}
melior_macro::dialect! {
name: "arith",
tablegen: r#"include "mlir/Dialect/Arith/IR/ArithOps.td""#
table_gen: r#"include "mlir/Dialect/Arith/IR/ArithOps.td""#
}
melior_macro::dialect! {
name: "arm_neon",
tablegen: r#"include "mlir/Dialect/ArmNeon/ArmNeon.td""#
table_gen: r#"include "mlir/Dialect/ArmNeon/ArmNeon.td""#
}
melior_macro::dialect! {
name: "arm_sve",
tablegen: r#"include "mlir/Dialect/ArmSVE/ArmSVE.td""#
table_gen: r#"include "mlir/Dialect/ArmSVE/ArmSVE.td""#
}
melior_macro::dialect! {
name: "async",
tablegen: r#"include "mlir/Dialect/Async/IR/AsyncOps.td""#
table_gen: r#"include "mlir/Dialect/Async/IR/AsyncOps.td""#
}
melior_macro::dialect! {
name: "bufferization",
tablegen: r#"include "mlir/Dialect/Bufferization/IR/BufferizationOps.td""#
table_gen: r#"include "mlir/Dialect/Bufferization/IR/BufferizationOps.td""#
}
melior_macro::dialect! {
name: "cf",
tablegen: r#"include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.td""#
table_gen: r#"include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.td""#
}
melior_macro::dialect! {
name: "func",
tablegen: r#"include "mlir/Dialect/Func/IR/FuncOps.td""#
table_gen: r#"include "mlir/Dialect/Func/IR/FuncOps.td""#
}
melior_macro::dialect! {
name: "index",
tablegen: r#"include "mlir/Dialect/Index/IR/IndexOps.td""#
table_gen: r#"include "mlir/Dialect/Index/IR/IndexOps.td""#
}
melior_macro::dialect! {
name: "llvm",
// spell-checker: disable-next-line
tablegen: r#"include "mlir/Dialect/LLVMIR/LLVMOps.td""#
table_gen: r#"include "mlir/Dialect/LLVMIR/LLVMOps.td""#
}
melior_macro::dialect! {
name: "memref",
tablegen: r#"include "mlir/Dialect/MemRef/IR/MemRefOps.td""#
table_gen: r#"include "mlir/Dialect/MemRef/IR/MemRefOps.td""#
}
melior_macro::dialect! {
name: "scf",
tablegen: r#"include "mlir/Dialect/SCF/IR/SCFOps.td""#
table_gen: r#"include "mlir/Dialect/SCF/IR/SCFOps.td""#
}
melior_macro::dialect! {
name: "pdl",
tablegen: r#"include "mlir/Dialect/PDL/IR/PDLOps.td""#
table_gen: r#"include "mlir/Dialect/PDL/IR/PDLOps.td""#
}
melior_macro::dialect! {
name: "pdl_interp",
tablegen: r#"include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.td""#
table_gen: r#"include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.td""#
}
melior_macro::dialect! {
name: "math",
tablegen: r#"include "mlir/Dialect/Math/IR/MathOps.td""#
table_gen: r#"include "mlir/Dialect/Math/IR/MathOps.td""#
}
melior_macro::dialect! {
name: "gpu",
tablegen: r#"include "mlir/Dialect/GPU/IR/GPUOps.td""#
table_gen: r#"include "mlir/Dialect/GPU/IR/GPUOps.td""#
}
melior_macro::dialect! {
name: "linalg",
tablegen: r#"include "mlir/Dialect/Linalg/IR/LinalgOps.td""#
table_gen: r#"include "mlir/Dialect/Linalg/IR/LinalgOps.td""#
}
melior_macro::dialect! {
name: "quant",
tablegen: r#"include "mlir/Dialect/Quant/QuantOps.td""#
table_gen: r#"include "mlir/Dialect/Quant/QuantOps.td""#
}
melior_macro::dialect! {
name: "shape",
tablegen: r#"include "mlir/Dialect/Shape/IR/ShapeOps.td""#
table_gen: r#"include "mlir/Dialect/Shape/IR/ShapeOps.td""#
}
melior_macro::dialect! {
name: "sparse_tensor",
tablegen: r#"include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.td""#
table_gen: r#"include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.td""#
}
melior_macro::dialect! {
name: "tensor",
tablegen: r#"include "mlir/Dialect/Tensor/IR/TensorOps.td""#
table_gen: r#"include "mlir/Dialect/Tensor/IR/TensorOps.td""#
}
melior_macro::dialect! {
name: "tosa",
tablegen: r#"include "mlir/Dialect/Tosa/IR/TosaOps.td""#
table_gen: r#"include "mlir/Dialect/Tosa/IR/TosaOps.td""#
}
melior_macro::dialect! {
name: "transform",
tablegen: r#"include "mlir/Dialect/Transform/IR/TransformOps.td""#
table_gen: r#"include "mlir/Dialect/Transform/IR/TransformOps.td""#
}
melior_macro::dialect! {
name: "vector",
tablegen: r#"include "mlir/Dialect/Vector/IR/VectorOps.td""#
table_gen: r#"include "mlir/Dialect/Vector/IR/VectorOps.td""#
}

#[cfg(test)]
Expand Down