diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index cc43b0918bad..33b366fd8f35 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -93,6 +93,7 @@ iree_compiler_cc_library( "BufferizeCopyOnlyDispatchesPass.cpp", "CleanupBufferAllocViewPass.cpp", "ConcretizePadResultShape.cpp", + "ConfigTrackingCanonicalizer.cpp", "ConvertBf16ArithToF32.cpp", "ConvertBf16ToUInt16Buffers.cpp", "ConvertToDestinationPassingStylePass.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index c87a67d3358b..53feab310b2e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -85,6 +85,7 @@ iree_cc_library( "BufferizeCopyOnlyDispatchesPass.cpp" "CleanupBufferAllocViewPass.cpp" "ConcretizePadResultShape.cpp" + "ConfigTrackingCanonicalizer.cpp" "ConvertBf16ArithToF32.cpp" "ConvertBf16ToUInt16Buffers.cpp" "ConvertToDestinationPassingStylePass.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/ConcretizePadResultShape.cpp b/compiler/src/iree/compiler/Codegen/Common/ConcretizePadResultShape.cpp index 22f08cff359c..89727e482dd9 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConcretizePadResultShape.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConcretizePadResultShape.cpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Common/Transforms.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -138,7 +139,11 @@ class ConcretizePadResultShapePass final { RewritePatternSet patterns(context); populateConcretizePadResultShapePatterns(patterns); - if 
(failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + GreedyRewriteConfig config; + auto listener = ConfigTrackingListener(); + config.listener = &listener; + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/compiler/src/iree/compiler/Codegen/Common/ConfigTrackingCanonicalizer.cpp b/compiler/src/iree/compiler/Codegen/Common/ConfigTrackingCanonicalizer.cpp new file mode 100644 index 000000000000..0294d9b43208 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/ConfigTrackingCanonicalizer.cpp @@ -0,0 +1,111 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Common/Transforms.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-codegen-config-tracking-canonicalizer" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_CONFIGTRACKINGCANONICALIZERPASS +#include "iree/compiler/Codegen/Common/Passes.h.inc" + +/// Follows the source of each cast producer of v and returns the first +/// defining op that is not such a cast, or nullptr if there is none. +static Operation *skipCastsDefiningOp(Value v) { + Operation *producer = v.getDefiningOp(); + // getDefiningOp() may return null, so the cast must tolerate null. + while (auto castProducer = dyn_cast_if_present(producer)) { + producer = castProducer.getSource().getDefiningOp(); + } + return producer; +} + +void ConfigTrackingListener::notifyOperationReplaced(Operation *op, + ValueRange replacement) { + // We have no way to track replacements without a producer. + if (replacement.empty()) { + return; + } + + IREE::Codegen::LoweringConfigAttrInterface loweringConfig = + getLoweringConfig(op); + if (!loweringConfig) { + return; + } + + // Must have a producer of the same type to track the lowering config. 
+ auto producer = skipCastsDefiningOp(replacement.front()); + if (!producer || producer->getName() != op->getName()) { + return; + } + + for (auto v : replacement.drop_front()) { + // Conservatively require that all replacements are produced by the same + // operation. + if (skipCastsDefiningOp(v) != producer) { + return; + } + } + + // No need to add the lowering config if it's already present. + if (getLoweringConfig(producer)) { + return; + } + + setLoweringConfig(producer, loweringConfig); +} + +namespace { + +/// Canonicalizer variant that attaches a ConfigTrackingListener to the greedy +/// rewrite driver so that lowering configs on replaced operations can be +/// transferred to the operations producing their replacements. +struct ConfigTrackingCanonicalizerPass final + : impl::ConfigTrackingCanonicalizerPassBase< + ConfigTrackingCanonicalizerPass> { +public: + using impl::ConfigTrackingCanonicalizerPassBase< + ConfigTrackingCanonicalizerPass>::ConfigTrackingCanonicalizerPassBase; + /// Initialize the canonicalizer by building the set of patterns used during + /// execution. + LogicalResult initialize(MLIRContext *context) override { + // Inherit the same config defaults from the upstream canonicalizer pass. + config.useTopDownTraversal = true; + config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Normal; + + RewritePatternSet owningPatterns(context); + for (auto *dialect : context->getLoadedDialects()) + dialect->getCanonicalizationPatterns(owningPatterns); + for (RegisteredOperationName op : context->getRegisteredOperations()) + op.getCanonicalizationPatterns(owningPatterns, context); + + patterns = + std::make_shared(std::move(owningPatterns)); + return success(); + } + + void runOnOperation() override { + // Canonicalization is best-effort. Non-convergence is not a pass failure. 
+ auto listener = ConfigTrackingListener(); + config.listener = &listener; + LogicalResult didConverge = + applyPatternsAndFoldGreedily(getOperation(), *patterns, config); + if (this->testConvergence && failed(didConverge)) { + getOperation()->emitError("Canonicalizer failed to converge"); + return signalPassFailure(); + } + } + GreedyRewriteConfig config; + std::shared_ptr patterns; +}; + +} // namespace +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 493afa843f81..7695f795f755 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -44,6 +44,15 @@ def BufferizeCopyOnlyDispatchesPass : }]; } +def ConfigTrackingCanonicalizerPass : + Pass<"iree-codegen-config-tracking-canonicalize", ""> { + let summary = "Codegen specific canonicalization pass that tracks lowering configs"; + let options = [ + Option<"testConvergence", "test-convergence", "bool", + /*default=*/"false", "Fails if the patterns fail to converge"> + ]; +} + def CleanupBufferAllocViewPass : InterfacePass<"iree-codegen-cleanup-buffer-alloc-view", "mlir::FunctionOpInterface"> { let summary = diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h index 0a000348e22e..98c7478d7e95 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h +++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h @@ -18,6 +18,14 @@ struct OneShotBufferizationOptions; namespace mlir::iree_compiler { +/// Common helper class for tracking lowering configs through pattern +/// applications. 
+class ConfigTrackingListener : public RewriterBase::Listener { +public: + ConfigTrackingListener() = default; + void notifyOperationReplaced(Operation *op, ValueRange replacement) override; +}; + using IGEMMConfigFn = std::function; using IGEMMControlFn = std::function; diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index 16c47d2184cb..db321d132c61 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -121,7 +121,7 @@ static void addTileAndDistributePasses(OpPassManager &funcPassManager) { funcPassManager.addPass(createConvertToDestinationPassingStylePass()); funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass()); } - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); funcPassManager.addPass(createFuseTensorPadWithConsumerPass()); funcPassManager.addPass(createConcretizePadResultShapePass()); @@ -425,7 +425,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass(createTensorToVectorVectorizePadPass()); if (pipelineOpt.decomposePackUnPackOps) { funcPassManager.addPass(createDecomposePackUnPackOpsPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 53f67e4b1266..33f3ae90373e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -197,7 +197,7 @@ static void tileAndDistributeToWorkgroup( // TODO(#16421): Disable decomposition due to failure in bufferization. 
// funcPassManager.addPass( // IREE::LinalgExt::createTileAndDecomposeAttentionPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); } @@ -238,13 +238,13 @@ static void addGPUVectorizationPasses(OpPassManager &funcPassManager, void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) { tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); - funcPassManager.addPass(createCanonicalizerPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); // Distribute linalg onto threads within the workgroup. funcPassManager.addPass(createGPUTensorTilePass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); // Linalg -> vector @@ -365,7 +365,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, GPUApplyTilingLevelPassOptions options; options.tilingLevel = IREE::GPU::TilingLevel::Reduction; funcPassManager.addPass(createGPUApplyTilingLevelPass(options)); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); } @@ -384,7 +384,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, } funcPassManager.addPass(createPropagateReshapesByExpansionPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); // Step 4. Tile and fuse tileable ops to subgroups/threads. 
@@ -392,7 +392,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, GPUApplyTilingLevelPassOptions options; options.tilingLevel = IREE::GPU::TilingLevel::Thread; funcPassManager.addPass(createGPUApplyTilingLevelPass(options)); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); } { @@ -406,7 +406,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass(iree_compiler::createNormalizeLoopBoundsPass( NormalizeLoopBoundsPassOptions{/*normalizeFor=*/false, /*normalizeForall=*/true})); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); // TODO: This LICM instance is load bearing due to brittleness of the @@ -489,13 +489,13 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) { tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); - funcPassManager.addPass(createCanonicalizerPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); // Distribute linalg onto threads within the workgroup. funcPassManager.addPass(createGPUTilePass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); funcPassManager.addPass( IREE::LinalgExt::createDecomposeWinogradTransformPass()); @@ -512,7 +512,7 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) { // Post bufferization optimizations. 
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass()); funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); funcPassManager.addPass(createOptimizeVectorTransferPass()); funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass()); @@ -526,8 +526,8 @@ void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager, const GPUPipelineOptions &options) { tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); - funcPassManager.addPass(createCanonicalizerPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); funcPassManager.addPass(createGPUTensorTileToSerialLoopsPass()); @@ -727,8 +727,8 @@ void addGPUTransposePassPipeline(OpPassManager &funcPassManager, const GPUPipelineOptions &options) { tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); - funcPassManager.addPass(createCanonicalizerPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); funcPassManager.addPass( @@ -844,7 +844,7 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass( IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); funcPassManager.addPass(createGPUPromoteMatmulOperandsPass()); @@ -855,12 +855,12 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, options.allowZeroSlices = 
true; funcPassManager.addPass(createGPUApplyTilingLevelPass(options)); funcPassManager.addPass(affine::createLoopCoalescingPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); } funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); // Set anchors at tensor level for vector distribution later and hoist out @@ -927,9 +927,9 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) { tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); funcPassManager.addPass(createRematerializeParallelOpsPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createGPUTileReductionPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); // Linalg -> vector @@ -970,11 +970,11 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) { void addGPUPackUnPackPasses(OpPassManager &funcPassManager) { tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); funcPassManager.addPass(createGPUTensorTilePass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); funcPassManager.addPass(createDecomposePackUnPackOpsPass( @@ -1165,7 +1165,7 @@ static void 
buildLLVMGPUCodegenConfigurationPassPipelineImpl( addCommonTargetExecutablePreprocessingPasses(funcPassManager); addEncodingToNopPasses(funcPassManager); funcPassManager.addPass(createBlockDynamicDimensionsPass); - funcPassManager.addPass(createCanonicalizerPass); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass); funcPassManager.addPass(createCSEPass); } modulePassManager.addPass(createMaterializeUserConfigsPass()); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp index 30868fffbec6..6501c1080aaa 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp @@ -119,7 +119,7 @@ static void addTileAndDistributeToWorkgroupsPasses( } funcPassManager.addPass(createConvertToDestinationPassingStylePass( useWARForCooperativeMatrixCodegen)); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); } @@ -305,7 +305,7 @@ void addSPIRVBaseVectorizePassPipeline(OpPassManager &funcPassManager) { funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass()); funcPassManager.addPass(memref::createResolveShapedTypeResultDimsPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); // Tile to GPU invocations and vectorize. 
@@ -341,18 +341,18 @@ void addSPIRVWinogradVectorizePassPipeline(OpPassManager &funcPassManager) { funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass()); funcPassManager.addPass(memref::createResolveShapedTypeResultDimsPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); funcPassManager.addPass(createGPUTilePass()); funcPassManager.addPass( IREE::LinalgExt::createDecomposeWinogradTransformPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); // Tile to GPU invocations and vectorize. funcPassManager.addPass(createSPIRVAnnotateWinogradLoopsPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); { GenericVectorizationPassOptions options; @@ -392,7 +392,7 @@ void addSPIRVCooperativeMatrixVectorizePassPipeline( funcPassManager.addPass(createRemoveSingleIterationLoopPass()); // Run canonicalization patterns to propagate constant shape sizes after // removing trip-one loops. - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); // Tile and distribute to GPU subgroups. @@ -550,7 +550,7 @@ void addSPIRVSubgroupReducePassPipeline(OpPassManager &funcPassManager) { // Fuse input parallel ops into the reduction op so that we don't need to // create temporary allocations during bufferization. 
funcPassManager.addPass(createRematerializeParallelOpsPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createGPUTileReductionPass()); funcPassManager.addPass(createCanonicalizerPass()); diff --git a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp index 78e3ea546b95..952b5e6dc42f 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp @@ -39,7 +39,7 @@ static void addTileAndDistributePasses(OpPassManager &funcPassManager) { funcPassManager.addPass(createCSEPass()); funcPassManager.addPass(createConvertToDestinationPassingStylePass()); funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass()); - funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); funcPassManager.addPass(createFuseTensorPadWithConsumerPass()); funcPassManager.addPass(createConcretizePadResultShapePass());