Skip to content

Commit 7e5f809

Browse files
committed
Add support for external custom dispatch rewrite patterns
Currently custom dispatches must be written in as a part of downstream flows (targeting `hal.dispatch.extern` or Flow ops directly). This limits how lightweight deployment of custom dispatches with nearly stock IREE flows. This adds a mechanism to plugin and apply pdl patterns completely separate from other IR. This allows hermetically sealing the details of a custom dispatch as a separate IR module that is easy to reuse across models. The intent is to follow up to this patch with similar passes in the various input dialects to allow different potential rewrite points, as well as turn some of the rewriting utilities into transform operations as another way to represent the dispatch + rewrite logic without the need for C++ and a downstram project.
1 parent 479f4ed commit 7e5f809

19 files changed

+563
-0
lines changed

compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@ package(
1414

1515
iree_compiler_cc_library(
1616
name = "Utils",
17+
srcs = [
18+
"ExternBuildingUtils.cpp",
19+
],
1720
hdrs = [
1821
"DeviceSwitchBuilder.h",
22+
"ExternBuildingUtils.h",
1923
],
2024
deps = [
2125
"//compiler/src/iree/compiler/Dialect/HAL/IR",

compiler/src/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ iree_cc_library(
1515
Utils
1616
HDRS
1717
"DeviceSwitchBuilder.h"
18+
"ExternBuildingUtils.h"
19+
SRCS
20+
"ExternBuildingUtils.cpp"
1821
DEPS
1922
LLVMSupport
2023
MLIRFuncDialect
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Copyright 2023 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#ifndef IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_
8+
#define IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_
9+
10+
#include "iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.h"
11+
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
12+
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
13+
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
14+
#include "mlir/IR/Builders.h"
15+
#include "mlir/IR/IRMapping.h"
16+
#include "mlir/IR/Location.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include "mlir/Pass/PassManager.h"
19+
#include "mlir/Transforms/DialectConversion.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21+
#include "mlir/Transforms/RegionUtils.h"
22+
23+
namespace mlir {
24+
namespace iree_compiler {
25+
namespace IREE {
26+
namespace HAL {
27+
28+
namespace {
29+
30+
/// Helper to build a hal.dispatch.extern op with the given arguments. This
31+
/// returns the block arguments of the workgroup count region and the extern
32+
/// op. Note that the workgroup count region will not include the terminator
33+
/// and that is left up to the user to properly populate.
34+
static FailureOr<Operation *>
35+
createDispatchExtern(PatternRewriter &rewriter, ValueRange workload,
36+
TypeRange resultTypes, ValueRange resultDims,
37+
ValueRange arguments, ValueRange argumentDims,
38+
DenseI64ArrayAttr tiedOperands, DictionaryAttr attrDict) {
39+
Location rootLoc = (*arguments.begin()).getLoc();
40+
SmallVector<int64_t> tiedOperandsIntList(tiedOperands.asArrayRef());
41+
SmallVector<NamedAttribute> namedAttributes(attrDict.begin(), attrDict.end());
42+
Operation *externOp = rewriter.create<DispatchExternOp>(
43+
rootLoc, workload, resultTypes, resultDims, arguments, argumentDims,
44+
tiedOperandsIntList, namedAttributes);
45+
return externOp;
46+
}
47+
48+
/// Helper to emplace a block on the given hal.dispatch.extern op. This returns
49+
/// the block arguments of the updated workgroup count region. Note that the
50+
/// workgroup count region will not include the terminator and that is left up
51+
/// to the user to properly populate.
52+
static FailureOr<ValueRange>
53+
emplaceExternWorkgroupCountRegion(PatternRewriter &rewriter, Operation *op) {
54+
Location rootLoc = op->getLoc();
55+
auto externOp = dyn_cast<DispatchExternOp>(op);
56+
if (!externOp) {
57+
return failure();
58+
}
59+
60+
SmallVector<Type> countTypes({rewriter.getType<IREE::HAL::DeviceType>()});
61+
SmallVector<Location> countLocs({rootLoc});
62+
for (auto workloadIdx : externOp.getWorkload()) {
63+
countTypes.push_back(workloadIdx.getType());
64+
countLocs.push_back(workloadIdx.getLoc());
65+
}
66+
67+
auto &entryBlock = externOp.getWorkgroupCount().emplaceBlock();
68+
auto countArgs = entryBlock.addArguments(countTypes, countLocs);
69+
70+
ArrayRef<BlockArgument> countArgsArray(countArgs.begin(), countArgs.end());
71+
72+
/// Update the insertion point to the beginning of the block to enable
73+
/// contructing the workgroup count region.
74+
rewriter.setInsertionPointToStart(&entryBlock);
75+
return ValueRange(countArgsArray);
76+
}
77+
} // namespace
78+
79+
void registerExternDispatchRewriteFunction(PDLPatternModule &pdlPatterns) {
80+
pdlPatterns.registerRewriteFunction("create_dispatch_extern",
81+
createDispatchExtern);
82+
pdlPatterns.registerRewriteFunction("emplace_extern_workgroup_count",
83+
emplaceExternWorkgroupCountRegion);
84+
}
85+
86+
} // namespace HAL
87+
} // namespace IREE
88+
} // namespace iree_compiler
89+
} // namespace mlir
90+
91+
#endif // IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright 2023 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#ifndef IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_
8+
#define IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_
9+
10+
#include "mlir/IR/Attributes.h"
11+
#include "mlir/IR/Builders.h"
12+
#include "mlir/IR/Operation.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
15+
namespace mlir {
16+
namespace iree_compiler {
17+
namespace IREE {
18+
namespace HAL {
19+
20+
void registerExternDispatchRewriteFunction(PDLPatternModule &pdlPatterns);
21+
22+
} // namespace HAL
23+
} // namespace IREE
24+
} // namespace iree_compiler
25+
} // namespace mlir
26+
27+
#endif // IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_

compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ iree_compiler_cc_library(
4848
"DetachElementwiseFromNamedOps.cpp",
4949
"EraseUnusedLinalgOperands.cpp",
5050
"MaterializeHomogeneousEncodings.cpp",
51+
"MaterializeExternDispatches.cpp",
5152
"Passes.cpp",
5253
"RemoveZeroExtentTensors.cpp",
5354
"SetEncoding.cpp",
@@ -62,6 +63,7 @@ iree_compiler_cc_library(
6263
"//compiler/src/iree/compiler/Dialect/Flow/Transforms",
6364
"//compiler/src/iree/compiler/Dialect/HAL/IR",
6465
"//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
66+
"//compiler/src/iree/compiler/Dialect/HAL/Utils",
6567
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
6668
"//compiler/src/iree/compiler/Pipelines:Options",
6769
"//compiler/src/iree/compiler/Utils",

compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ iree_cc_library(
4242
"Convert1X1FilterConv2DToMatmul.cpp"
4343
"DetachElementwiseFromNamedOps.cpp"
4444
"EraseUnusedLinalgOperands.cpp"
45+
"MaterializeExternDispatches.cpp"
4546
"MaterializeHomogeneousEncodings.cpp"
4647
"Passes.cpp"
4748
"RemoveZeroExtentTensors.cpp"
@@ -73,6 +74,7 @@ iree_cc_library(
7374
iree::compiler::Dialect::Flow::Transforms
7475
iree::compiler::Dialect::HAL::IR
7576
iree::compiler::Dialect::HAL::IR::HALDialect
77+
iree::compiler::Dialect::HAL::Utils
7678
iree::compiler::Dialect::Util::Transforms
7779
iree::compiler::Pipelines::Options
7880
iree::compiler::Utils
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// Copyright 2023 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
8+
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
9+
#include "iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.h"
10+
#include "iree/compiler/GlobalOptimization/PassDetail.h"
11+
#include "iree/compiler/Utils/CustomPatternApplicatorPassBase.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
13+
#include "mlir/Dialect/PDL/IR/PDL.h"
14+
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15+
#include "mlir/Pass/Pass.h"
16+
17+
using namespace mlir;
18+
19+
namespace {
20+
21+
class MaterializeExternDispatchesPass
22+
: public iree_compiler::PatternApplicatorPassBase<
23+
MaterializeExternDispatchesPass,
24+
iree_compiler::GlobalOptimization::
25+
MaterializeExternDispatchesPassBase> {
26+
public:
27+
void getDependentDialects(DialectRegistry &registry) const override {
28+
registry.insert<mlir::iree_compiler::IREE::HAL::HALDialect, pdl::PDLDialect,
29+
pdl_interp::PDLInterpDialect>();
30+
}
31+
32+
LogicalResult initializePatterns(MLIRContext *context,
33+
RewritePatternSet &tmpPatterns) {
34+
iree_compiler::IREE::HAL::registerExternDispatchRewriteFunction(
35+
tmpPatterns.getPDLPatterns());
36+
return iree_compiler::detail::populatePDLModuleFromFileName(
37+
context, tmpPatterns, this->pdlModuleFileName);
38+
}
39+
40+
MaterializeExternDispatchesPass(StringRef pdlModuleFileName = StringRef()) {
41+
this->pdlModuleFileName = pdlModuleFileName.str();
42+
}
43+
MaterializeExternDispatchesPass(const MaterializeExternDispatchesPass &pass) =
44+
default;
45+
};
46+
} // namespace
47+
48+
namespace mlir {
49+
namespace iree_compiler {
50+
namespace GlobalOptimization {
51+
std::unique_ptr<Pass>
52+
createMaterializeExternDispatchesPass(std::string pdlModuleFileName) {
53+
return std::make_unique<MaterializeExternDispatchesPass>(pdlModuleFileName);
54+
}
55+
} // namespace GlobalOptimization
56+
} // namespace iree_compiler
57+
} // namespace mlir

compiler/src/iree/compiler/GlobalOptimization/Passes.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ void buildGlobalOptimizationPassPipeline(
4343
mainPassManager.addPass(IREE::Util::createDemoteI64ToI32Pass());
4444
}
4545

46+
if (!transformOptions.options.customDispatchPatternModuleFileName.empty()) {
47+
mainPassManager.addPass(createMaterializeExternDispatchesPass(
48+
transformOptions.options.customDispatchPatternModuleFileName));
49+
}
50+
4651
// Preprocessing passes to get the program into a canonical state.
4752
FunctionLikeNest(mainPassManager)
4853
.addPass(createRemoveZeroExtentTensorsPass)

compiler/src/iree/compiler/GlobalOptimization/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ std::unique_ptr<Pass> createDetachElementwiseFromNamedOpsPass();
5454
std::unique_ptr<OperationPass<mlir::ModuleOp>>
5555
createEraseUnusedLinalgOperands();
5656

57+
// Materializes logical encodings to physical encodings if there is a single
58+
// device target.
59+
std::unique_ptr<OperationPass<mlir::ModuleOp>>
60+
createMaterializeExternDispatchesPass(std::string pdlModuleFileName = "");
61+
5762
// Materializes logical encodings to physical encodings if there is a single
5863
// device target.
5964
std::unique_ptr<OperationPass<mlir::ModuleOp>>

compiler/src/iree/compiler/GlobalOptimization/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
include "mlir/Pass/PassBase.td"
1111

12+
<<<<<<< HEAD
1213
def Convert1X1FilterConv2DToMatmul:
1314
Pass<"iree-global-opt-convert-1x1-filter-conv2d-to-matmul", ""> {
1415
let summary = "Convert linalg convolution ops with 1x1 kernels into linalg matrix multiplication ops.";
@@ -27,6 +28,18 @@ def EraseUnusedLinalgOperands :
2728
let constructor = "mlir::iree_compiler::GlobalOptimization::createEraseUnusedLinalgOperands()";
2829
}
2930

31+
def MaterializeExternDispatchesPass :
32+
Pass<"iree-materialize-extern-dispatches", "mlir::ModuleOp"> {
33+
let summary = "Pass to form custom external dispatches";
34+
let constructor =
35+
"mlir::iree_compiler::GlobalOptimization::createMaterializeExternDispatchesPass()";
36+
let options = [
37+
Option<"pdlModuleFileName", "pdl-module-file-name", "std::string",
38+
/*default=*/"\"\"",
39+
"Optional file name to load a pdl module from.">
40+
];
41+
}
42+
3043
def MaterializeHomogeneousEncodings :
3144
Pass<"iree-global-opt-materialize-homogeneous-encodings", "mlir::ModuleOp"> {
3245
let summary = "Materializes logical encodings to physical encodings if there is a single device target.";

compiler/src/iree/compiler/Pipelines/Options.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) {
152152
llvm::cl::desc("Strips debug assertions after any useful "
153153
"information has been extracted."),
154154
llvm::cl::cat(category));
155+
binder.opt<std::string>(
156+
"iree-opt-extern-dispatch-pattern-module",
157+
customDispatchPatternModuleFileName,
158+
llvm::cl::desc("File path to custom dispatch rewrite pattnern module."),
159+
llvm::cl::cat(category));
155160
}
156161

157162
void SchedulingOptions::bindOptions(OptionsBinder &binder) {

compiler/src/iree/compiler/Pipelines/Options.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ struct GlobalOptimizationOptions {
9999
// allow hoisting. The threshold is 1MB by default.
100100
int64_t constExprMaxSizeIncreaseThreshold = 1024 * 1024;
101101

102+
// File path to load custom dispatch rewrite patterns from.
103+
std::string customDispatchPatternModuleFileName = "";
104+
102105
void bindOptions(OptionsBinder &binder);
103106
using FromFlags = OptionsFromFlags<GlobalOptimizationOptions>;
104107
};

compiler/src/iree/compiler/Utils/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ iree_compiler_cc_library(
1818
name = "Utils",
1919
srcs = [
2020
"ConversionUtils.cpp",
21+
"CustomPatternApplicatorPassBase.cpp",
2122
"ElementPackingUtils.cpp",
2223
"FlatbufferUtils.cpp",
2324
"ModuleUtils.cpp",
@@ -29,6 +30,7 @@ iree_compiler_cc_library(
2930
],
3031
hdrs = [
3132
"ConversionUtils.h",
33+
"CustomPatternApplicatorPassBase.h",
3234
"ElementPackingUtils.h",
3335
"FlatbufferUtils.h",
3436
"IndexSet.h",

compiler/src/iree/compiler/Utils/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ iree_cc_library(
1515
Utils
1616
HDRS
1717
"ConversionUtils.h"
18+
"CustomPatternApplicatorPassBase.h"
1819
"ElementPackingUtils.h"
1920
"FlatbufferUtils.h"
2021
"IndexSet.h"
@@ -27,6 +28,7 @@ iree_cc_library(
2728
"TracingUtils.h"
2829
SRCS
2930
"ConversionUtils.cpp"
31+
"CustomPatternApplicatorPassBase.cpp"
3032
"ElementPackingUtils.cpp"
3133
"FlatbufferUtils.cpp"
3234
"ModuleUtils.cpp"

0 commit comments

Comments
 (0)