diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index a3f9a7e61e7d..a9f8c9b1dec6 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -193,6 +193,7 @@ def PolygeistCanonicalize : Pass<"canonicalize-polygeist"> { "arith::ArithDialect", "cf::ControlFlowDialect", "scf::SCFDialect", + "polygeist::PolygeistDialect", ]; let options = [ Option<"topDownProcessingEnabled", "top-down", "bool", diff --git a/lib/polygeist/Passes/PolygeistCanonicalize.cpp b/lib/polygeist/Passes/PolygeistCanonicalize.cpp index a02fc2a34f7f..1d12aaf2ad00 100644 --- a/lib/polygeist/Passes/PolygeistCanonicalize.cpp +++ b/lib/polygeist/Passes/PolygeistCanonicalize.cpp @@ -27,6 +27,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "polygeist/Dialect.h" #include "polygeist/Ops.h" #include "polygeist/Passes/Passes.h" @@ -57,16 +58,16 @@ struct PolygeistCanonicalizePass config.maxIterations = maxIterations; config.maxNumRewrites = maxNumRewrites; + // The polygeist dialect is marked as a dependency to this pass and that + // causes all of the custom canonicalizers (which are not neccessarily only + // for polygeist ops) to get imported + RewritePatternSet owningPatterns(context); for (auto *dialect : context->getLoadedDialects()) dialect->getCanonicalizationPatterns(owningPatterns); for (RegisteredOperationName op : context->getRegisteredOperations()) op.getCanonicalizationPatterns(owningPatterns, context); - // A hack to add custom canonicalization patterns for non-polygeist ops - polygeist::TypeAlignOp::getCanonicalizationPatterns(owningPatterns, - context); - patterns = std::make_shared( std::move(owningPatterns), disabledPatterns, enabledPatterns); return success();