diff --git a/compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel index e7696df4e03bb..b58c9c260ca94 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel @@ -14,8 +14,12 @@ package( iree_compiler_cc_library( name = "Utils", + srcs = [ + "ExternBuildingUtils.cpp", + ], hdrs = [ "DeviceSwitchBuilder.h", + "ExternBuildingUtils.h", ], deps = [ "//compiler/src/iree/compiler/Dialect/HAL/IR", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt index 5509ab2ad92c4..0172f5eb35a19 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt @@ -15,6 +15,9 @@ iree_cc_library( Utils HDRS "DeviceSwitchBuilder.h" + "ExternBuildingUtils.h" + SRCS + "ExternBuildingUtils.cpp" DEPS LLVMSupport MLIRFuncDialect diff --git a/compiler/src/iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.cpp b/compiler/src/iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.cpp new file mode 100644 index 0000000000000..cae9c663ae2a9 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.cpp @@ -0,0 +1,91 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.h" +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Location.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace HAL { + +namespace { + +/// Helper to build a hal.dispatch.extern op with the given arguments and +/// return it as a generic operation. This helper does not populate the +/// workgroup count region; that is left to the caller, e.g. through the +/// "emplace_extern_workgroup_count" rewrite registered below. +static FailureOr +createDispatchExtern(PatternRewriter &rewriter, ValueRange workload, + TypeRange resultTypes, ValueRange resultDims, + ValueRange arguments, ValueRange argumentDims, + DenseI64ArrayAttr tiedOperands, DictionaryAttr attrDict) { + // Anchor the new op at the first argument's location. An empty argument + // range has no first element to dereference, so fall back to an unknown + // location in that case. + Location rootLoc = arguments.empty() ? rewriter.getUnknownLoc() + : arguments.front().getLoc(); + SmallVector tiedOperandsIntList(tiedOperands.asArrayRef()); + SmallVector namedAttributes(attrDict.begin(), attrDict.end()); + Operation *externOp = rewriter.create( + rootLoc, workload, resultTypes, resultDims, arguments, argumentDims, + tiedOperandsIntList, namedAttributes); + return externOp; +} + +/// Helper to emplace a block on the given hal.dispatch.extern op. This returns +/// the block arguments of the updated workgroup count region. Note that the +/// workgroup count region will not include the terminator and that is left up +/// to the user to properly populate. 
+static FailureOr +emplaceExternWorkgroupCountRegion(PatternRewriter &rewriter, Operation *op) { + Location rootLoc = op->getLoc(); + auto externOp = dyn_cast(op); + if (!externOp) { + return failure(); + } + + SmallVector countTypes({rewriter.getType()}); + SmallVector countLocs({rootLoc}); + for (auto workloadIdx : externOp.getWorkload()) { + countTypes.push_back(workloadIdx.getType()); + countLocs.push_back(workloadIdx.getLoc()); + } + + auto &entryBlock = externOp.getWorkgroupCount().emplaceBlock(); + auto countArgs = entryBlock.addArguments(countTypes, countLocs); + + ArrayRef countArgsArray(countArgs.begin(), countArgs.end()); + + /// Update the insertion point to the beginning of the block to enable + /// constructing the workgroup count region. + rewriter.setInsertionPointToStart(&entryBlock); + return ValueRange(countArgsArray); +} +} // namespace + +/// Registers the native PDL rewrites used to build hal.dispatch.extern ops. +void registerExternDispatchRewriteFunction(PDLPatternModule &pdlPatterns) { + pdlPatterns.registerRewriteFunction("create_dispatch_extern", + createDispatchExtern); + pdlPatterns.registerRewriteFunction("emplace_extern_workgroup_count", + emplaceExternWorkgroupCountRegion); +} + +} // namespace HAL +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.h b/compiler/src/iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.h new file mode 100644 index 0000000000000..ab4a0a1001b17 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.h @@ -0,0 +1,27 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_ +#define IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace HAL { + +/// Registers the "create_dispatch_extern" and "emplace_extern_workgroup_count" +/// native rewrite functions on `pdlPatterns`. +void registerExternDispatchRewriteFunction(PDLPatternModule &pdlPatterns); + +} // namespace HAL +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_ diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel index e751347ce2b69..8bd0cec4d859a 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel +++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel @@ -45,6 +45,7 @@ iree_compiler_cc_library( name = "GlobalOptimization", srcs = [ "MaterializeHomogeneousEncodings.cpp", + "MaterializeExternDispatches.cpp", "Passes.cpp", ], hdrs = [ @@ -57,6 +58,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Flow/Transforms", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", + "//compiler/src/iree/compiler/Dialect/HAL/Utils", "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Pipelines:Options", "//compiler/src/iree/compiler/Utils", diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt index 7921d385d6130..940882f0a7527 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt +++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt @@ -39,6 +39,7 @@ iree_cc_library( HDRS "Passes.h" SRCS + "MaterializeExternDispatches.cpp" "MaterializeHomogeneousEncodings.cpp" "Passes.cpp" DEPS 
@@ -55,6 +56,7 @@ iree_cc_library( iree::compiler::Dialect::Flow::Transforms iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect + iree::compiler::Dialect::HAL::Utils iree::compiler::Dialect::Util::Transforms iree::compiler::Pipelines::Options iree::compiler::Utils diff --git a/compiler/src/iree/compiler/GlobalOptimization/MaterializeExternDispatches.cpp b/compiler/src/iree/compiler/GlobalOptimization/MaterializeExternDispatches.cpp new file mode 100644 index 0000000000000..0e1e837490b25 --- /dev/null +++ b/compiler/src/iree/compiler/GlobalOptimization/MaterializeExternDispatches.cpp @@ -0,0 +1,57 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.h" +#include "iree/compiler/GlobalOptimization/PassDetail.h" +#include "iree/compiler/Utils/CustomPatternApplicatorPassBase.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +/// Pass that applies the rewrite patterns from the PDL module file named by the +/// `pdlModuleFileName` option, with the extern dispatch rewrites registered. +class MaterializeExternDispatchesPass + : public iree_compiler::PatternApplicatorPassBase< + MaterializeExternDispatchesPass, + iree_compiler::GlobalOptimization:: + MaterializeExternDispatchesPassBase> { +public: + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + + // Registers the extern dispatch native rewrites, then loads the pattern + // module from the configured file; fails if that load fails. + LogicalResult initializePatterns(MLIRContext *context, + RewritePatternSet &tmpPatterns) { + iree_compiler::IREE::HAL::registerExternDispatchRewriteFunction( + tmpPatterns.getPDLPatterns()); + return iree_compiler::detail::populatePDLModuleFromFileName( + context, tmpPatterns, this->pdlModuleFileName); + } + + 
MaterializeExternDispatchesPass(StringRef pdlModuleFileName = StringRef()) { + this->pdlModuleFileName = pdlModuleFileName.str(); + } + MaterializeExternDispatchesPass(const MaterializeExternDispatchesPass &pass) = + default; +}; +} // namespace + +namespace mlir { +namespace iree_compiler { +namespace GlobalOptimization { +std::unique_ptr +createMaterializeExternDispatchesPass(std::string pdlModuleFileName) { + return std::make_unique(pdlModuleFileName); +} +} // namespace GlobalOptimization +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index 472f7bd00e113..25fcfedfa06ec 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -43,6 +43,11 @@ void buildGlobalOptimizationPassPipeline( mainPassManager.addPass(IREE::Util::createDemoteI64ToI32Pass()); } + if (!transformOptions.options.customDispatchPatternModuleFileName.empty()) { + mainPassManager.addPass(createMaterializeExternDispatchesPass( + transformOptions.options.customDispatchPatternModuleFileName)); + } + // Preprocessing passes to get the program into a canonical state. FunctionLikeNest(mainPassManager) .addPass(IREE::Flow::createRemoveZeroExtentTensorsPass) diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h index 65552c50120b9..e16620c483380 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h @@ -37,6 +37,11 @@ struct TransformOptions : public PassPipelineOptions { void buildGlobalOptimizationPassPipeline( OpPassManager &mainPassManager, const TransformOptions &transformOptions); +// Forms custom external dispatches by applying the PDL rewrite patterns loaded +// from `pdlModuleFileName`. 
+std::unique_ptr> +createMaterializeExternDispatchesPass(std::string pdlModuleFileName = ""); + // Materializes logical encodings to physical encodings if there is a single // device target. std::unique_ptr> diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td index e07e4e83de73e..4227cbdaeebf0 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td @@ -9,6 +9,18 @@ include "mlir/Pass/PassBase.td" +def MaterializeExternDispatchesPass : + Pass<"iree-materialize-extern-dispatches", "mlir::ModuleOp"> { + let summary = "Pass to form custom external dispatches"; + let constructor = + "mlir::iree_compiler::GlobalOptimization::createMaterializeExternDispatchesPass()"; + let options = [ + Option<"pdlModuleFileName", "pdl-module-file-name", "std::string", + /*default=*/"\"\"", + "Optional file name to load a pdl module from."> + ]; +} + def MaterializeHomogeneousEncodings : Pass<"iree-global-opt-materialize-homogeneous-encodings", "mlir::ModuleOp"> { let summary = "Materializes logical encodings to physical encodings if there is a single device target."; diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp index 5d4f8dccf603c..349e24e77f6d8 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.cpp +++ b/compiler/src/iree/compiler/Pipelines/Options.cpp @@ -152,6 +152,11 @@ void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) { llvm::cl::desc("Strips debug assertions after any useful " "information has been extracted."), llvm::cl::cat(category)); + // Path of the pattern module file; bound to + // customDispatchPatternModuleFileName. + binder.opt( + "iree-opt-extern-dispatch-pattern-module", + customDispatchPatternModuleFileName, + llvm::cl::desc("File path to custom dispatch rewrite pattern module."), + llvm::cl::cat(category)); } void SchedulingOptions::bindOptions(OptionsBinder &binder) { diff --git 
a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h index f1c4830887f84..2b65ce29b996a 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.h +++ b/compiler/src/iree/compiler/Pipelines/Options.h @@ -95,6 +95,9 @@ struct GlobalOptimizationOptions { // Strips debug assertions after any useful information has been extracted. bool stripAssertions = false; + // File path to load custom dispatch rewrite patterns from. + std::string customDispatchPatternModuleFileName = ""; + void bindOptions(OptionsBinder &binder); using FromFlags = OptionsFromFlags; }; diff --git a/compiler/src/iree/compiler/Utils/BUILD.bazel b/compiler/src/iree/compiler/Utils/BUILD.bazel index aaa6ca5ecccc8..5e7215f41aef1 100644 --- a/compiler/src/iree/compiler/Utils/BUILD.bazel +++ b/compiler/src/iree/compiler/Utils/BUILD.bazel @@ -18,6 +18,7 @@ iree_compiler_cc_library( name = "Utils", srcs = [ "ConversionUtils.cpp", + "CustomPatternApplicatorPassBase.cpp", "ElementPackingUtils.cpp", "FlatbufferUtils.cpp", "ModuleUtils.cpp", @@ -29,6 +30,7 @@ iree_compiler_cc_library( ], hdrs = [ "ConversionUtils.h", + "CustomPatternApplicatorPassBase.h", "ElementPackingUtils.h", "FlatbufferUtils.h", "IndexSet.h", diff --git a/compiler/src/iree/compiler/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Utils/CMakeLists.txt index d7ef3b56ddc20..77e990481e539 100644 --- a/compiler/src/iree/compiler/Utils/CMakeLists.txt +++ b/compiler/src/iree/compiler/Utils/CMakeLists.txt @@ -15,6 +15,7 @@ iree_cc_library( Utils HDRS "ConversionUtils.h" + "CustomPatternApplicatorPassBase.h" "ElementPackingUtils.h" "FlatbufferUtils.h" "IndexSet.h" @@ -27,6 +28,7 @@ iree_cc_library( "TracingUtils.h" SRCS "ConversionUtils.cpp" + "CustomPatternApplicatorPassBase.cpp" "ElementPackingUtils.cpp" "FlatbufferUtils.cpp" "ModuleUtils.cpp" diff --git a/compiler/src/iree/compiler/Utils/CustomPatternApplicatorPassBase.cpp b/compiler/src/iree/compiler/Utils/CustomPatternApplicatorPassBase.cpp new 
file mode 100644 index 0000000000000..f1e687eb5a33e --- /dev/null +++ b/compiler/src/iree/compiler/Utils/CustomPatternApplicatorPassBase.cpp @@ -0,0 +1,93 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Utils/CustomPatternApplicatorPassBase.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { +namespace detail { +/// Helper to get a flat list of sizes from the given values. Values whose type +/// is not a RankedTensorType contribute nothing. +/// NOTE(review): the returned ValueRange views function-local storage (here and +/// in getI32TensorSizes below); confirm the PDL result handling copies it +/// before that storage is released. 
+static FailureOr getValueRangeTensorSizes(PatternRewriter &rewriter, + ValueRange vals) { + SmallVector flatTensorSizes; + for (auto val : vals) { + if (auto tensorType = dyn_cast(val.getType())) { + for (int64_t i = 0; i < tensorType.getRank(); ++i) { + flatTensorSizes.push_back( + rewriter.create(val.getLoc(), val, i)); + } + } + } + return ValueRange(flatTensorSizes); +} + +/// Helper to create, for each value in `vals`, an op that takes it as operand +/// and has a 32-bit integer result type. +static FailureOr getI32TensorSizes(PatternRewriter &rewriter, + ValueRange vals) { + SmallVector flatI32TensorSizes; + for (auto val : vals) { + flatI32TensorSizes.push_back(rewriter.create( + val.getLoc(), rewriter.getIntegerType(32), val)); + } + return ValueRange(flatI32TensorSizes); +} + +/// Helper to pick the value at position `index` out of `vals`. Fails when the +/// index is negative or past the end of the range. +static FailureOr extractValueFromRange(PatternRewriter &rewriter, + ValueRange vals, + IntegerAttr index) { + int64_t position = index.getInt(); + if (position < 0 || position >= static_cast<int64_t>(vals.size())) + return failure(); + return vals[position]; +} + +/// Registers the conversion PDL functions and the helpers above on `patterns`. +void populateCommonNativeRewriteHelpers(RewritePatternSet &patterns) { + mlir::registerConversionPDLFunctions(patterns); + patterns.getPDLPatterns().registerRewriteFunction("extract_value", + extractValueFromRange); + patterns.getPDLPatterns().registerRewriteFunction("get_tensor_sizes", + getValueRangeTensorSizes); + patterns.getPDLPatterns().registerRewriteFunction("convert_index_to_i32", + getI32TensorSizes); +} + +/// Parses the file at `pdlModuleFileName` and adds it to `patterns` as a PDL +/// pattern module. Fails if the file cannot be opened or parsed. +LogicalResult populatePDLModuleFromFileName(MLIRContext *context, + RewritePatternSet &patterns, + llvm::StringRef pdlModuleFileName) { + std::string errorMessage; + auto memoryBuffer = mlir::openInputFile(pdlModuleFileName, &errorMessage); + if (!memoryBuffer) { + return emitError(FileLineColLoc::get( + StringAttr::get(context, pdlModuleFileName), 0, 0)) + << "failed to open pattern module file: " << errorMessage; + } + // Tell sourceMgr about this buffer, the parser will pick it up. 
+ llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc()); + // A failed parse yields a null module; bail out here instead of wrapping a + // null module in a PDLPatternModule. + auto parsedModule = OwningOpRef(parseSourceFile(sourceMgr, context)); + if (!parsedModule) { + return failure(); + } + PDLPatternModule pdlModule(std::move(parsedModule)); + patterns.insert(std::move(pdlModule)); + return success(); +} + +} // namespace detail +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Utils/CustomPatternApplicatorPassBase.h b/compiler/src/iree/compiler/Utils/CustomPatternApplicatorPassBase.h new file mode 100644 index 0000000000000..e2024e61d24bf --- /dev/null +++ b/compiler/src/iree/compiler/Utils/CustomPatternApplicatorPassBase.h @@ -0,0 +1,95 @@ +//===- PatternApplicatorPassBase.h ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Base class with shared implementation for custom pattern applicator +// passes. 
+// +//===----------------------------------------------------------------------===// + +#ifndef IREE_COMPILER_UTILS_CUSTOMPATTERNAPPLICATORPASSBASE_H_ +#define IREE_COMPILER_UTILS_CUSTOMPATTERNAPPLICATORPASSBASE_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { + +using transform::TransformOptions; + +namespace iree_compiler { + +namespace detail { +void populateCommonNativeRewriteHelpers(RewritePatternSet &patterns); + +LogicalResult populatePDLModuleFromFileName(MLIRContext *context, + RewritePatternSet &patterns, + llvm::StringRef pdlModuleFileName); +} // namespace detail + +template typename GeneratedBase> +class PatternApplicatorPassBase : public GeneratedBase { +public: + explicit PatternApplicatorPassBase( + const TransformOptions &options = TransformOptions()) + : options(options) {} + + PatternApplicatorPassBase(const PatternApplicatorPassBase &pass) { + options = pass.options; + } + + LogicalResult initialize(MLIRContext *context) override { + RewritePatternSet tmpPatterns(context); + detail::populateCommonNativeRewriteHelpers(tmpPatterns); + if (failed(static_cast(this)->initializePatterns( + context, tmpPatterns))) { + return failure(); + } + patterns = std::move(tmpPatterns); + return success(); + } + + // Hook for populating necessary library constraints/rewrites for the pattern + // applicator at initialization time, as well as setting up the type + // converter. 
+ LogicalResult initializePatterns(MLIRContext *context, + RewritePatternSet &tmpPatterns) { + return success(); + } + + void runOnOperation() override { + auto *pass = static_cast(this); + Operation *op = pass->getOperation(); + + /// If there are no patterns nothing to do. + if (!patterns.getPDLByteCode()) { + return; + } + if (failed(applyPatternsAndFoldGreedily(op, patterns))) + return pass->signalPassFailure(); + } + +private: + /// Pattern applicator options. + TransformOptions options; + + FrozenRewritePatternSet patterns; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_UTILS_CUSTOMPATTERNAPPLICATORPASSBASE_H_ diff --git a/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt b/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt index ce15498a25ba0..360c35bcc6b23 100644 --- a/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt +++ b/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt @@ -50,8 +50,10 @@ iree_lit_test_suite( SRCS "example.mlir" "example_inline.mlir" + "example_apply.mlir" DATA ${_SPV_TARGET} + "pattern_module.mlir" TOOLS FileCheck iree-compile diff --git a/samples/custom_dispatch/vulkan/shaders/example_apply.mlir b/samples/custom_dispatch/vulkan/shaders/example_apply.mlir new file mode 100644 index 0000000000000..f749427947700 --- /dev/null +++ b/samples/custom_dispatch/vulkan/shaders/example_apply.mlir @@ -0,0 +1,55 @@ +// RUN: iree-compile %s \ +// RUN: --iree-hal-executable-object-search-path=$IREE_BINARY_DIR \ +// RUN: --iree-opt-extern-dispatch-pattern-module=%p/pattern_module.mlir | \ +// RUN: iree-run-module \ +// RUN: --device=vulkan \ +// RUN: --module=- \ +// RUN: --function=mixed_invocation \ +// RUN: --input=8xf32=2 \ +// RUN: --input=8xf32=4 | \ +// RUN: FileCheck %s + +// The configuration used for executable compilation. 
+// This lets the compiler and runtime know the format and requirements of the +// executable binaries produced and multiple variants with differing formats +// and compilation options (architectures, etc) can be embedded for runtime +// selection. +#spirv_target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env< + #spirv.vce, + #spirv.resource_limits + > +}> + +// The target devices that the program will run on. +// These can come from compiler flags and multiple targets can be supported +// It's possible, for example, to support targeting multiple devices in the same +// compiled binary. +#vulkan_target = #hal.device.target<"vulkan", { + executable_targets = [#spirv_target], + // HACK: Vulkan target currently uses the legacy synchronous execution model. + legacy_sync +}> + +module @example attributes {hal.device.targets = [#vulkan_target]} { + + // Function demonstrating replacing a kernel with a hand-written implementation. + // Invoke with: + // --device=vulkan + // --function=mixed_invocation + // --input=8xf32=2 + // --input=8xf32=4 + // CHECK-LABEL: EXEC @mixed_invocation + func.func @mixed_invocation(%arg0: tensor, %arg1: tensor) -> tensor { + // Target to match and replace with a hand written dispatch + %0 = arith.mulf %arg0, %arg1 : tensor + + // Code gen some other ops - these will interleave with the hand-authored + // ones but naturally won't be able to fuse with them. + %1 = arith.addf %0, %arg1 : tensor + + // CHECK: 8xf32=12 12 12 12 12 12 12 12 + return %1 : tensor + } + +} // module diff --git a/samples/custom_dispatch/vulkan/shaders/pattern_module.mlir b/samples/custom_dispatch/vulkan/shaders/pattern_module.mlir new file mode 100644 index 0000000000000..aac14590e8286 --- /dev/null +++ b/samples/custom_dispatch/vulkan/shaders/pattern_module.mlir @@ -0,0 +1,97 @@ +// RUN: iree-opt %s + +// The required configuration for the custom dispatch. 
This tells the compiler +// the requisite target information needed to support the associated custom +// shader. +#spirv_target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env< + #spirv.vce, + #spirv.resource_limits + > +}> + +#layout = #hal.pipeline.layout, + <1, storage_buffer, ReadOnly>, + <2, storage_buffer> + ]> +]> +#bindings = [ + #hal.interface.binding<0, 0>, + #hal.interface.binding<0, 1>, + #hal.interface.binding<0, 2> +] +#objects = #hal.executable.objects<{ + #spirv_target = [ + #hal.executable.object<{ + path = "samples/custom_dispatch/vulkan/shaders/simple_mul.spv" + }> + ] +}> + +#attrdict = { + export = "main", + layout = #layout, + bindings = #bindings, + objects = #objects +} + +module { + pdl.pattern : benefit(1) { + %lhs_type = pdl.type : tensor + %rhs_type = pdl.type : tensor + %out_type = pdl.type : tensor + %lhs = pdl.operand : %lhs_type + %rhs = pdl.operand : %rhs_type + + // Match the target operation(s) for rewriting and the original arguments and result types. + %mul = pdl.operation "arith.mulf" (%lhs, %rhs : !pdl.value, !pdl.value) -> (%out_type : !pdl.type) + + pdl.rewrite %mul { + // Constant attributes for rewriting. `attrdict` contains the layout/binding/object + // required for the custom dispatch. + %attrdict = pdl.attribute = #attrdict + %tied_operands = pdl.attribute = array + %c1_idx = pdl.attribute = 1 : index + %apply_map = pdl.attribute = affine_map<()[s0] -> (s0 ceildiv 64)> + + %range = pdl.range %lhs : !pdl.value + %res_type_range = pdl.range %out_type : !pdl.type + %workload = pdl.apply_native_rewrite "get_tensor_sizes"(%range : !pdl.range) : !pdl.range + %dims = pdl.apply_native_rewrite "convert_index_to_i32"(%workload : !pdl.range) : !pdl.range + %arg_range = pdl.range %dims, %lhs, %rhs : !pdl.range, !pdl.value, !pdl.value + %arg_dims = pdl.apply_native_rewrite "get_tensor_sizes"(%arg_range : !pdl.range) : !pdl.range + + // Create the extern dispatch op. 
The workgroup count region has not been constructed at + // this point. + %extern = pdl.apply_native_rewrite "create_dispatch_extern"( + %workload, %res_type_range, %workload, + %arg_range, %arg_dims, %tied_operands, + %attrdict : !pdl.range, !pdl.range, !pdl.range, + !pdl.range, !pdl.range, !pdl.attribute, + !pdl.attribute) : !pdl.operation + + // Emplace the workgroup count region based on the workload, return handles to the new block + // arguments, and set the insertion point to the new block. + %wkg_args = pdl.apply_native_rewrite "emplace_extern_workgroup_count"(%extern : !pdl.operation) : !pdl.range + + // Workgroup count of 1 along y and z. + %c1_op = pdl.operation "arith.constant" {"value" = %c1_idx} + %c1 = pdl.result 0 of %c1_op + + // Extract the workload argument; pdl_interp provides an extract op, however it isn't + // allowed to be used inside a pdl.pattern region, hence this hack. + %workload_arg = pdl.apply_native_rewrite "extract_value"(%wkg_args, %c1_idx : !pdl.range, !pdl.attribute) : !pdl.value + + // Compute x and create the terminator of the workgroup count region. + %index_type = pdl.type : index + %x_op = pdl.operation "affine.apply"(%workload_arg : !pdl.value) {"map" = %apply_map} -> (%index_type : !pdl.type) + %x = pdl.result 0 of %x_op + %res = pdl.operation "hal.return"(%x, %c1, %c1 : !pdl.value, !pdl.value, !pdl.value) + + %new_result = pdl.result 0 of %extern + pdl.replace %mul with (%new_result : !pdl.value) + } + } +}