diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 018cc77051ad..52c6b8fa13c4 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -522,6 +522,16 @@ SmallVector<int64_t> TensorBarrierOp::getTiedResultOperandIndices() {
 // hal.dispatch.extern
 //===----------------------------------------------------------------------===//
 
+void DispatchExternOp::build(OpBuilder &builder, OperationState &state,
+                             ValueRange workload, TypeRange resultTypes,
+                             ValueRange resultDims, ValueRange arguments,
+                             ValueRange argumentDims,
+                             DenseI64ArrayAttr tiedOperands,
+                             DictionaryAttr attributes) {
+  build(builder, state, workload, resultTypes, resultDims, arguments,
+        argumentDims, tiedOperands.asArrayRef(), attributes.getValue());
+}
+
 void DispatchExternOp::build(OpBuilder &builder, OperationState &state,
                              ValueRange workload, TypeRange resultTypes,
                              ValueRange resultDims, ValueRange arguments,
@@ -550,6 +560,24 @@ void DispatchExternOp::build(OpBuilder &builder, OperationState &state,
   state.addRegion();
 }
 
+/// Helper to emplace a block on the given hal.dispatch.extern op. This returns
+/// the entry block of the updated workgroup count region and sets up the
+/// block arguments as a !hal.device + the types inferred from the workload.
+/// The workgroup count region will not include the terminator and that is left
+/// up to the user to properly populate.
+Block *DispatchExternOp::emplaceWorkgroupCountRegion(OpBuilder &builder) {
+  SmallVector<Type> countTypes({builder.getType<IREE::HAL::DeviceType>()});
+  SmallVector<Location> countLocs({getLoc()});
+  for (auto workloadIdx : getWorkload()) {
+    countTypes.push_back(workloadIdx.getType());
+    countLocs.push_back(workloadIdx.getLoc());
+  }
+
+  auto &entryBlock = getWorkgroupCount().emplaceBlock();
+  entryBlock.addArguments(countTypes, countLocs);
+  return &entryBlock;
+}
+
 // Verifies that |dynamicDims| contains the appropriate number of dims for all
 // of the dynamic dimensions in |values|.
 static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 29f94572dbae..8384024cd085 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -403,6 +403,12 @@ def HAL_DispatchExternOp : HAL_PureOp<"dispatch.extern", [
 
   let skipDefaultBuilders = 1;
   let builders = [
+    OpBuilder<(ins
+      "ValueRange":$workload,
+      "TypeRange":$resultTypes, "ValueRange":$resultDims,
+      "ValueRange":$arguments, "ValueRange":$argumentDims,
+      "DenseI64ArrayAttr":$tiedOperands,
+      CArg<"DictionaryAttr", "nullptr">:$attributes)>,
     OpBuilder<(ins
       "ValueRange":$workload,
       "TypeRange":$resultTypes, "ValueRange":$resultDims,
@@ -420,6 +426,8 @@ def HAL_DispatchExternOp : HAL_PureOp<"dispatch.extern", [
         getResultTypes());
     }
 
+    Block *emplaceWorkgroupCountRegion(OpBuilder &builder);
+
     /// Returns the index of the args() operand in the Operation operands list.
     unsigned mapArgOperandToOpOperand(unsigned i) { return i + getWorkload().size(); };
diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
index ac282d8cca16..08d3506ef446 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
@@ -48,6 +48,7 @@ iree_compiler_cc_library(
         "DetachElementwiseFromNamedOps.cpp",
        "EraseUnusedLinalgOperands.cpp",
         "ExpandVectors.cpp",
+        "MaterializeExternDispatches.cpp",
         "MaterializeHomogeneousEncodings.cpp",
         "Passes.cpp",
         "RemoveZeroExtentTensors.cpp",
@@ -82,6 +83,8 @@ iree_compiler_cc_library(
         "@llvm-project//mlir:LinalgUtils",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:MemRefTransforms",
+        "@llvm-project//mlir:PDLDialect",
+        "@llvm-project//mlir:PDLInterpDialect",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TensorTransforms",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
index f67edf23b1c8..625d2ba9157e 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
@@ -43,6 +43,7 @@ iree_cc_library(
     "DetachElementwiseFromNamedOps.cpp"
     "EraseUnusedLinalgOperands.cpp"
     "ExpandVectors.cpp"
+    "MaterializeExternDispatches.cpp"
     "MaterializeHomogeneousEncodings.cpp"
     "Passes.cpp"
     "RemoveZeroExtentTensors.cpp"
@@ -65,6 +66,8 @@ iree_cc_library(
     MLIRLinalgUtils
     MLIRMemRefDialect
     MLIRMemRefTransforms
+    MLIRPDLDialect
+    MLIRPDLInterpDialect
     MLIRPass
     MLIRTensorDialect
     MLIRTensorTransforms
diff --git a/compiler/src/iree/compiler/GlobalOptimization/MaterializeExternDispatches.cpp b/compiler/src/iree/compiler/GlobalOptimization/MaterializeExternDispatches.cpp
new file mode 100644
index 000000000000..3bc9c479385a
--- /dev/null
+++ b/compiler/src/iree/compiler/GlobalOptimization/MaterializeExternDispatches.cpp
@@ -0,0 +1,101 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/GlobalOptimization/PassDetail.h"
+#include "iree/compiler/Utils/CustomPatternApplicatorPassBase.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace GlobalOptimization {
+
+namespace {
+
+/// Wrapper to build a hal.dispatch.extern op with the given arguments.
+static Operation *
+createDispatchExtern(PatternRewriter &rewriter, Operation *target,
+                     ValueRange workload, TypeRange resultTypes,
+                     ValueRange resultDims, ValueRange arguments,
+                     ValueRange argumentDims, DenseI64ArrayAttr tiedOperands,
+                     DictionaryAttr attrDict) {
+  return rewriter.create<IREE::HAL::DispatchExternOp>(
+      target->getLoc(), workload, resultTypes, resultDims, arguments,
+      argumentDims, tiedOperands, attrDict);
+}
+
+/// Helper to emplace a block on the given hal.dispatch.extern op. This returns
+/// the block arguments of the updated workgroup count region. Note that the
+/// workgroup count region will not include the terminator and that is left up
+/// to the user to properly populate.
+static FailureOr<ValueRange>
+emplaceExternWorkgroupCountRegion(PatternRewriter &rewriter, Operation *op) {
+  auto externOp = dyn_cast<IREE::HAL::DispatchExternOp>(op);
+  if (!externOp) {
+    return failure();
+  }
+
+  Block *entryBlock = externOp.emplaceWorkgroupCountRegion(rewriter);
+
+  /// Update the insertion point to the beginning of the block to enable
+  /// constructing the workgroup count region.
+  rewriter.setInsertionPointToStart(entryBlock);
+  return ValueRange(entryBlock->getArguments());
+}
+
+static void
+registerExternDispatchRewriteFunction(PDLPatternModule &pdlPatterns) {
+  pdlPatterns.registerRewriteFunction("create_dispatch_extern",
+                                      createDispatchExtern);
+  pdlPatterns.registerRewriteFunction("emplace_extern_workgroup_count",
+                                      emplaceExternWorkgroupCountRegion);
+}
+
+} // namespace
+
+class MaterializeExternDispatchesPass
+    : public iree_compiler::PatternApplicatorPassBase<
+          MaterializeExternDispatchesPass,
+          iree_compiler::GlobalOptimization::
+              MaterializeExternDispatchesPassBase> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREE::HAL::HALDialect, arith::ArithDialect,
+                    func::FuncDialect, pdl::PDLDialect,
+                    pdl_interp::PDLInterpDialect>();
+  }
+
+  LogicalResult initializePatterns(MLIRContext *context,
+                                   RewritePatternSet &tmpPatterns) {
+    registerExternDispatchRewriteFunction(tmpPatterns.getPDLPatterns());
+    for (auto fileName : this->pdlModuleFileNames) {
+      if (failed(iree_compiler::detail::populatePDLModuleFromFileName(
+              context, tmpPatterns, fileName))) {
+        return failure();
+      }
+    }
+    return success();
+  }
+
+  MaterializeExternDispatchesPass(ArrayRef<std::string> pdlModuleFileNames) {
+    this->pdlModuleFileNames = pdlModuleFileNames;
+  }
+  MaterializeExternDispatchesPass(const MaterializeExternDispatchesPass &pass) =
+      default;
+};
+
+std::unique_ptr<Pass> createMaterializeExternDispatchesPass(
+    ArrayRef<std::string> pdlModuleFileNames) {
+  return std::make_unique<MaterializeExternDispatchesPass>(pdlModuleFileNames);
+}
+} // namespace GlobalOptimization
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index 082740819f70..d9de2a12f1c7 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -43,6 +43,11 @@ void buildGlobalOptimizationPassPipeline(
     mainPassManager.addPass(IREE::Util::createDemoteI64ToI32Pass());
   }
 
+  if (!transformOptions.options.customDispatchPatternModuleFileNames.empty()) {
+    mainPassManager.addPass(createMaterializeExternDispatchesPass(
+        transformOptions.options.customDispatchPatternModuleFileNames));
+  }
+
   // Preprocessing passes to get the program into a canonical state.
   FunctionLikeNest(mainPassManager)
       .addPass(createRemoveZeroExtentTensorsPass)
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
index f92b120858a2..35e6f4c31bf2 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
@@ -58,6 +58,12 @@ createEraseUnusedLinalgOperands();
 // forms.
 std::unique_ptr<Pass> createExpandVectorsPass();
 
+// Materializes hal.dispatch.extern ops from user-provided PDL pattern
+// modules.
+std::unique_ptr<Pass>
+createMaterializeExternDispatchesPass(
+    ArrayRef<std::string> pdlModuleFileName = {});
+
 // Materializes logical encodings to physical encodings if there is a single
 // device target.
 std::unique_ptr<OperationPass<mlir::ModuleOp>>
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
index c9a5d4bdc4d5..5ed2a42f35a0 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
@@ -32,6 +32,18 @@ def ExpandVectors : Pass<"iree-global-opt-expand-vectors", ""> {
   let constructor = "mlir::iree_compiler::GlobalOptimization::createExpandVectorsPass()";
 }
 
+def MaterializeExternDispatchesPass :
+    Pass<"iree-materialize-extern-dispatches", ""> {
+  let summary = "Pass to form custom external dispatches";
+  let constructor =
+      "mlir::iree_compiler::GlobalOptimization::createMaterializeExternDispatchesPass()";
+  let options = [
+    ListOption<"pdlModuleFileNames", "pdl-module-file-names", "std::string",
+               "Optional files to load a pdl module from.",
+               "llvm::cl::ZeroOrMore">
+  ];
+}
+
 def MaterializeHomogeneousEncodings :
   Pass<"iree-global-opt-materialize-homogeneous-encodings", "mlir::ModuleOp"> {
   let summary = "Materializes logical encodings to physical encodings if there is a single device target.";
diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp
index 5d4f8dccf603..98bb644546de 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Options.cpp
@@ -152,6 +152,11 @@ void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) {
       llvm::cl::desc("Strips debug assertions after any useful "
                      "information has been extracted."),
       llvm::cl::cat(category));
+  binder.list<std::string>(
+      "iree-opt-extern-dispatch-pattern-module",
+      customDispatchPatternModuleFileNames,
+      llvm::cl::desc("File path to custom dispatch rewrite pattern module."),
+      llvm::cl::ZeroOrMore, llvm::cl::cat(category));
 }
 
 void SchedulingOptions::bindOptions(OptionsBinder &binder) {
diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h
index 0909ab8bb82e..4a2b3c9053f7 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.h
+++ b/compiler/src/iree/compiler/Pipelines/Options.h
@@ -99,6 +99,9 @@ struct GlobalOptimizationOptions {
   // allow hoisting. The threshold is 1MB by default.
   int64_t constExprMaxSizeIncreaseThreshold = 1024 * 1024;
 
+  // File paths to load custom dispatch rewrite patterns from.
+  std::vector<std::string> customDispatchPatternModuleFileNames = {};
+
   void bindOptions(OptionsBinder &binder);
   using FromFlags = OptionsFromFlags<GlobalOptimizationOptions>;
 };
diff --git a/compiler/src/iree/compiler/Utils/BUILD.bazel b/compiler/src/iree/compiler/Utils/BUILD.bazel
index aaa6ca5ecccc..2d6e2e48512c 100644
--- a/compiler/src/iree/compiler/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Utils/BUILD.bazel
@@ -18,6 +18,7 @@ iree_compiler_cc_library(
     name = "Utils",
     srcs = [
         "ConversionUtils.cpp",
+        "CustomPatternApplicatorPassBase.cpp",
         "ElementPackingUtils.cpp",
         "FlatbufferUtils.cpp",
         "ModuleUtils.cpp",
@@ -29,6 +30,7 @@ iree_compiler_cc_library(
     ],
     hdrs = [
         "ConversionUtils.h",
+        "CustomPatternApplicatorPassBase.h",
         "ElementPackingUtils.h",
         "FlatbufferUtils.h",
         "IndexSet.h",
@@ -50,9 +52,11 @@ iree_compiler_cc_library(
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:PDLDialect",
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:TransformUtils",
         "@llvm-project//mlir:Transforms",
     ],
diff --git a/compiler/src/iree/compiler/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Utils/CMakeLists.txt
index d7ef3b56ddc2..2923a2e8bb11 100644
--- a/compiler/src/iree/compiler/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Utils/CMakeLists.txt
@@ -15,6 +15,7 @@ iree_cc_library(
     Utils
   HDRS
     "ConversionUtils.h"
+    "CustomPatternApplicatorPassBase.h"
    "ElementPackingUtils.h"
     "FlatbufferUtils.h"
     "IndexSet.h"
@@ -27,6 +28,7 @@ iree_cc_library(
     "TracingUtils.h"
   SRCS
     "ConversionUtils.cpp"
+    "CustomPatternApplicatorPassBase.cpp"
    "ElementPackingUtils.cpp"
     "FlatbufferUtils.cpp"
     "ModuleUtils.cpp"
@@ -40,9 +42,11 @@ iree_cc_library(
     MLIRArithDialect
     MLIRFuncDialect
     MLIRIR
+    MLIRPDLDialect
     MLIRParser
     MLIRPass
     MLIRSupport
+    MLIRTensorDialect
     MLIRTransformUtils
     MLIRTransforms
     iree::base
diff --git a/compiler/src/iree/compiler/Utils/CustomPatternApplicatorPassBase.cpp b/compiler/src/iree/compiler/Utils/CustomPatternApplicatorPassBase.cpp
new file mode 100644
index 000000000000..5029d0683a92
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/CustomPatternApplicatorPassBase.cpp
@@ -0,0 +1,99 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_UTILS_CUSTOMPATTERNAPPLICATORPASSBASE_H_
+#define IREE_COMPILER_UTILS_CUSTOMPATTERNAPPLICATORPASSBASE_H_
+
+#include "iree/compiler/Utils/CustomPatternApplicatorPassBase.h"
+#include "llvm/Support/SourceMgr.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace detail {
+
+/// Helper to get a list of sizes from the given RankedTensorType values.
+static SmallVector<Value> getValueRangeTensorSizes(PatternRewriter &rewriter,
+                                                   ValueRange vals) {
+  SmallVector<Value> flatTensorSizes;
+  for (auto val : vals) {
+    if (auto tensorType = dyn_cast<RankedTensorType>(val.getType())) {
+      for (int64_t i = 0; i < tensorType.getRank(); ++i) {
+        flatTensorSizes.push_back(
+            rewriter.create<tensor::DimOp>(val.getLoc(), val, i).getResult());
+      }
+    }
+  }
+  return flatTensorSizes;
+}
+
+static SmallVector<Value> getI32TensorSizes(PatternRewriter &rewriter,
+                                            ValueRange vals) {
+  SmallVector<Value> flatI32TensorSizes;
+  for (auto val : vals) {
+    if (isa<IndexType>(val.getType())) {
+      flatI32TensorSizes.push_back(
+          rewriter
+              .create<arith::IndexCastOp>(val.getLoc(),
+                                          rewriter.getIntegerType(32), val)
+              .getResult());
+    }
+  }
+  return flatI32TensorSizes;
+}
+
+static FailureOr<Value> extractValueFromRange(PatternRewriter &rewriter,
+                                              ValueRange vals, Attribute attr) {
+  IntegerAttr index = dyn_cast<IntegerAttr>(attr);
+  if (!index || index.getInt() >= vals.size())
+    return failure();
+  return vals[index.getInt()];
+}
+
+void populateCommonNativeRewriteHelpers(RewritePatternSet &patterns) {
+  mlir::registerConversionPDLFunctions(patterns);
+  patterns.getPDLPatterns().registerRewriteFunction("extract_value",
+                                                    extractValueFromRange);
+  patterns.getPDLPatterns().registerRewriteFunction("get_tensor_sizes",
+                                                    getValueRangeTensorSizes);
+  patterns.getPDLPatterns().registerRewriteFunction("convert_index_to_i32",
+                                                    getI32TensorSizes);
+}
+
+LogicalResult populatePDLModuleFromFileName(MLIRContext *context,
+                                            RewritePatternSet &patterns,
+                                            llvm::StringRef pdlModuleFileName) {
+  std::string errorMessage;
+  auto memoryBuffer = mlir::openInputFile(pdlModuleFileName, &errorMessage);
+  if (!memoryBuffer) {
+    return emitError(FileLineColLoc::get(
+               StringAttr::get(context, pdlModuleFileName), 0, 0))
+           << "failed to open pattern module file: " << errorMessage;
+  }
+  // Tell sourceMgr about this buffer, the parser will pick it up.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+  PDLPatternModule pdlModule =
+      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
+  patterns.insert(std::move(pdlModule));
+  return success();
+}
+
+} // namespace detail
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_UTILS_CUSTOMPATTERNAPPLICATORPASSBASE_H_
diff --git a/compiler/src/iree/compiler/Utils/CustomPatternApplicatorPassBase.h b/compiler/src/iree/compiler/Utils/CustomPatternApplicatorPassBase.h
new file mode 100644
index 000000000000..e2024e61d24b
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/CustomPatternApplicatorPassBase.h
@@ -0,0 +1,95 @@
+//===- PatternApplicatorPassBase.h ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Base class with shared implementation for custom pattern applicator
+// passes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_COMPILER_UTILS_CUSTOMPATTERNAPPLICATORPASSBASE_H_
+#define IREE_COMPILER_UTILS_CUSTOMPATTERNAPPLICATORPASSBASE_H_
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+
+using transform::TransformOptions;
+
+namespace iree_compiler {
+
+namespace detail {
+void populateCommonNativeRewriteHelpers(RewritePatternSet &patterns);
+
+LogicalResult populatePDLModuleFromFileName(MLIRContext *context,
+                                            RewritePatternSet &patterns,
+                                            llvm::StringRef pdlModuleFileName);
+} // namespace detail
+
+template <typename Derived, template <typename> typename GeneratedBase>
+class PatternApplicatorPassBase : public GeneratedBase<Derived> {
+public:
+  explicit PatternApplicatorPassBase(
+      const TransformOptions &options = TransformOptions())
+      : options(options) {}
+
+  PatternApplicatorPassBase(const PatternApplicatorPassBase &pass) {
+    options = pass.options;
+  }
+
+  LogicalResult initialize(MLIRContext *context) override {
+    RewritePatternSet tmpPatterns(context);
+    detail::populateCommonNativeRewriteHelpers(tmpPatterns);
+    if (failed(static_cast<Derived *>(this)->initializePatterns(
+            context, tmpPatterns))) {
+      return failure();
+    }
+    patterns = std::move(tmpPatterns);
+    return success();
+  }
+
+  // Hook for populating necessary library constraints/rewrites for the pattern
+  // applicator at initialization time, as well as setting up the type
+  // converter.
+  LogicalResult initializePatterns(MLIRContext *context,
+                                   RewritePatternSet &tmpPatterns) {
+    return success();
+  }
+
+  void runOnOperation() override {
+    auto *pass = static_cast<Derived *>(this);
+    Operation *op = pass->getOperation();
+
+    /// If there are no patterns there is nothing to do.
+    if (!patterns.getPDLByteCode()) {
+      return;
+    }
+    if (failed(applyPatternsAndFoldGreedily(op, patterns)))
+      return pass->signalPassFailure();
+  }
+
+private:
+  /// Pattern applicator options.
+  TransformOptions options;
+
+  FrozenRewritePatternSet patterns;
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_UTILS_CUSTOMPATTERNAPPLICATORPASSBASE_H_
diff --git a/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt b/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt
index ce15498a25ba..6aca03859c87 100644
--- a/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt
+++ b/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt
@@ -50,8 +50,10 @@ iree_lit_test_suite(
   SRCS
     "example.mlir"
     "example_inline.mlir"
+    "example_pattern_module.mlir"
   DATA
     ${_SPV_TARGET}
+    "example_patterns.mlir"
  TOOLS
     FileCheck
     iree-compile
diff --git a/samples/custom_dispatch/vulkan/shaders/example_pattern_module.mlir b/samples/custom_dispatch/vulkan/shaders/example_pattern_module.mlir
new file mode 100644
index 000000000000..92c0c732522d
--- /dev/null
+++ b/samples/custom_dispatch/vulkan/shaders/example_pattern_module.mlir
@@ -0,0 +1,55 @@
+// RUN: iree-compile %s \
+// RUN:     --iree-hal-executable-object-search-path=$IREE_BINARY_DIR \
+// RUN:     --iree-opt-extern-dispatch-pattern-module=%p/example_patterns.mlir | \
+// RUN: iree-run-module \
+// RUN:     --device=vulkan \
+// RUN:     --module=- \
+// RUN:     --function=mixed_invocation \
+// RUN:     --input=8xf32=2 \
+// RUN:     --input=8xf32=4 | \
+// RUN: FileCheck %s
+
+// The configuration used for executable compilation.
+// This lets the compiler and runtime know the format and requirements of the
+// executable binaries produced and multiple variants with differing formats
+// and compilation options (architectures, etc) can be embedded for runtime
+// selection.
+#spirv_target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.3, [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+    #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>
+  >
+}>
+
+// The target devices that the program will run on.
+// These can come from compiler flags and multiple targets can be supported.
+// It's possible, for example, to support targeting multiple devices in the same
+// compiled binary.
+#vulkan_target = #hal.device.target<"vulkan", {
+  executable_targets = [#spirv_target],
+  // HACK: Vulkan target currently uses the legacy synchronous execution model.
+  legacy_sync
+}>
+
+module @example attributes {hal.device.targets = [#vulkan_target]} {
+
+  // Function demonstrating replacing a kernel with a hand-written implementation.
+  // Invoke with:
+  //  --device=vulkan
+  //  --function=mixed_invocation
+  //  --input=8xf32=2
+  //  --input=8xf32=4
+  // CHECK-LABEL: EXEC @mixed_invocation
+  func.func @mixed_invocation(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+    // Target to match and replace with a hand-written dispatch.
+    %0 = arith.mulf %arg0, %arg1 : tensor<?xf32>
+
+    // Code gen some other ops - these will interleave with the hand-authored
+    // ones but naturally won't be able to fuse with them.
+    %1 = arith.addf %0, %arg1 : tensor<?xf32>
+
+    // CHECK: 8xf32=12 12 12 12 12 12 12 12
+    return %1 : tensor<?xf32>
+  }
+
+} // module
diff --git a/samples/custom_dispatch/vulkan/shaders/example_patterns.mlir b/samples/custom_dispatch/vulkan/shaders/example_patterns.mlir
new file mode 100644
index 000000000000..3bc5b3829b89
--- /dev/null
+++ b/samples/custom_dispatch/vulkan/shaders/example_patterns.mlir
@@ -0,0 +1,97 @@
+// RUN: iree-opt %s
+
+// The required configuration for the custom dispatch. This tells the compiler
+// the requisite target information needed to support the associated custom
+// shader.
+#spirv_target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.3, [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+    #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>
+  >
+}>
+
+#layout = #hal.pipeline.layout<push_constants = 1, sets = [
+  <0, bindings = [
+      <0, storage_buffer, ReadOnly>,
+      <1, storage_buffer, ReadOnly>,
+      <2, storage_buffer>
+  ]>
+]>
+#bindings = [
+  #hal.interface.binding<0, 0>,
+  #hal.interface.binding<0, 1>,
+  #hal.interface.binding<0, 2>
+]
+#objects = #hal.executable.objects<{
+  #spirv_target = [
+    #hal.executable.object<{
+      path = "samples/custom_dispatch/vulkan/shaders/simple_mul.spv"
+    }>
+  ]
+}>
+
+#attrdict = {
+  export = "main",
+  layout = #layout,
+  bindings = #bindings,
+  objects = #objects
+}
+
+module {
+  pdl.pattern : benefit(1) {
+    %lhs_type = pdl.type : tensor<?xf32>
+    %rhs_type = pdl.type : tensor<?xf32>
+    %out_type = pdl.type : tensor<?xf32>
+    %lhs = pdl.operand : %lhs_type
+    %rhs = pdl.operand : %rhs_type
+
+    // Match the target operation(s) for rewriting and the original arguments
+    // and result types.
+    %mul = pdl.operation "arith.mulf" (%lhs, %rhs : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+
+    pdl.rewrite %mul {
+      // Constant attributes for rewriting. `attrdict` contains the
+      // layout/binding/object required for the custom dispatch.
+      %attrdict = pdl.attribute = #attrdict
+      %tied_operands = pdl.attribute = array<i64: -1>
+      %c1_idx = pdl.attribute = 1 : index
+      %apply_map = pdl.attribute = affine_map<()[s0] -> (s0 ceildiv 64)>
+
+      %range = pdl.range %lhs : !pdl.value
+      %res_type_range = pdl.range %out_type : !pdl.type
+      %workload = pdl.apply_native_rewrite "get_tensor_sizes"(%range : !pdl.range<value>) : !pdl.range<value>
+      %new_dims = pdl.apply_native_rewrite "convert_index_to_i32"(%workload : !pdl.range<value>) : !pdl.range<value>
+      %arg_range = pdl.range %new_dims, %lhs, %rhs : !pdl.range<value>, !pdl.value, !pdl.value
+      %arg_dims = pdl.apply_native_rewrite "get_tensor_sizes"(%arg_range : !pdl.range<value>) : !pdl.range<value>
+
+      // Create the extern dispatch op. The workgroup count region has not been
+      // constructed at this point.
+      %extern = pdl.apply_native_rewrite "create_dispatch_extern"(
+                    %mul, %workload, %res_type_range, %workload,
+                    %arg_range, %arg_dims, %tied_operands,
+                    %attrdict : !pdl.operation, !pdl.range<value>, !pdl.range<type>, !pdl.range<value>,
+                                !pdl.range<value>, !pdl.range<value>, !pdl.attribute,
+                                !pdl.attribute) : !pdl.operation
+
+      // Emplace the workgroup count region based on the workload, return handles
+      // to the new block arguments, and set the insertion point to the new block.
+      %wkg_args = pdl.apply_native_rewrite "emplace_extern_workgroup_count"(%extern : !pdl.operation) : !pdl.range<value>
+
+      // Workgroup count of 1 along y and z.
+      %c1_op = pdl.operation "arith.constant" {"value" = %c1_idx}
+      %c1 = pdl.result 0 of %c1_op
+
+      // Extract the workload argument; pdl_interp provides an extract op,
+      // however it isn't allowed to be used inside a pdl.pattern region, hence
+      // this hack.
+      %workload_arg = pdl.apply_native_rewrite "extract_value"(%wkg_args, %c1_idx : !pdl.range<value>, !pdl.attribute) : !pdl.value
+
+      // Compute x and create the terminator of the workgroup count region.
+      %index_type = pdl.type : index
+      %x_op = pdl.operation "affine.apply"(%workload_arg : !pdl.value) {"map" = %apply_map} -> (%index_type : !pdl.type)
+      %x = pdl.result 0 of %x_op
+      %res = pdl.operation "hal.return"(%x, %c1, %c1 : !pdl.value, !pdl.value, !pdl.value)
+
+      %new_result = pdl.result 0 of %extern
+      pdl.replace %mul with (%new_result : !pdl.value)
+    }
+  }
+}