Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for external custom dispatch rewrite patterns #15235

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,16 @@ SmallVector<int64_t> TensorBarrierOp::getTiedResultOperandIndices() {
// hal.dispatch.extern
//===----------------------------------------------------------------------===//

/// Convenience builder that accepts the tied operand indices and the extra
/// attributes in attribute form and forwards to the array-based overload.
///
/// |attributes| may be null: the ODS builder declaration defaults it to
/// nullptr, in which case no additional attributes are forwarded.
void DispatchExternOp::build(OpBuilder &builder, OperationState &state,
                             ValueRange workload, TypeRange resultTypes,
                             ValueRange resultDims, ValueRange arguments,
                             ValueRange argumentDims,
                             DenseI64ArrayAttr tiedOperands,
                             DictionaryAttr attributes) {
  // Calling getValue() on a null DictionaryAttr dereferences a null storage
  // pointer, so only unwrap the dictionary when one was actually provided.
  ArrayRef<NamedAttribute> namedAttributes;
  if (attributes) {
    namedAttributes = attributes.getValue();
  }
  build(builder, state, workload, resultTypes, resultDims, arguments,
        argumentDims, tiedOperands.asArrayRef(), namedAttributes);
}

void DispatchExternOp::build(OpBuilder &builder, OperationState &state,
ValueRange workload, TypeRange resultTypes,
ValueRange resultDims, ValueRange arguments,
Expand Down Expand Up @@ -550,6 +560,24 @@ void DispatchExternOp::build(OpBuilder &builder, OperationState &state,
state.addRegion();
}

/// Emplaces a new block in the workgroup count region of this
/// hal.dispatch.extern op and returns a pointer to it.
///
/// The block's arguments are a leading !hal.device value (located at this op)
/// followed by one argument per workload operand, each reusing the type and
/// location of that operand. No terminator is created; populating the block
/// body is left to the caller.
Block *DispatchExternOp::emplaceWorkgroupCountRegion(OpBuilder &builder) {
  SmallVector<Type> argTypes;
  SmallVector<Location> argLocs;

  // Leading device argument.
  argTypes.push_back(builder.getType<IREE::HAL::DeviceType>());
  argLocs.push_back(getLoc());

  // One block argument for every workload operand.
  for (Value workloadValue : getWorkload()) {
    argTypes.push_back(workloadValue.getType());
    argLocs.push_back(workloadValue.getLoc());
  }

  Block &countBlock = getWorkgroupCount().emplaceBlock();
  countBlock.addArguments(argTypes, argLocs);
  return &countBlock;
}

// Verifies that |dynamicDims| contains the appropriate number of dims for all
// of the dynamic dimensions in |values|.
static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values,
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,12 @@ def HAL_DispatchExternOp : HAL_PureOp<"dispatch.extern", [

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins
"ValueRange":$workload,
"TypeRange":$resultTypes, "ValueRange":$resultDims,
"ValueRange":$arguments, "ValueRange":$argumentDims,
"DenseI64ArrayAttr":$tiedOperands,
CArg<"DictionaryAttr", "nullptr">:$attributes)>,
OpBuilder<(ins
"ValueRange":$workload,
"TypeRange":$resultTypes, "ValueRange":$resultDims,
Expand All @@ -420,6 +426,8 @@ def HAL_DispatchExternOp : HAL_PureOp<"dispatch.extern", [
getResultTypes());
}

Block *emplaceWorkgroupCountRegion(OpBuilder &builder);

/// Returns the index of the args() operand in the Operation operands list.
unsigned mapArgOperandToOpOperand(unsigned i) { return i + getWorkload().size(); };

Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ iree_compiler_cc_library(
"DetachElementwiseFromNamedOps.cpp",
"EraseUnusedLinalgOperands.cpp",
"ExpandVectors.cpp",
"MaterializeExternDispatches.cpp",
"MaterializeHomogeneousEncodings.cpp",
"Passes.cpp",
"RemoveZeroExtentTensors.cpp",
Expand Down Expand Up @@ -82,6 +83,8 @@ iree_compiler_cc_library(
"@llvm-project//mlir:LinalgUtils",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:PDLDialect",
"@llvm-project//mlir:PDLInterpDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TensorTransforms",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ iree_cc_library(
"DetachElementwiseFromNamedOps.cpp"
"EraseUnusedLinalgOperands.cpp"
"ExpandVectors.cpp"
"MaterializeExternDispatches.cpp"
"MaterializeHomogeneousEncodings.cpp"
"Passes.cpp"
"RemoveZeroExtentTensors.cpp"
Expand All @@ -65,6 +66,8 @@ iree_cc_library(
MLIRLinalgUtils
MLIRMemRefDialect
MLIRMemRefTransforms
MLIRPDLDialect
MLIRPDLInterpDialect
MLIRPass
MLIRTensorDialect
MLIRTensorTransforms
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/Utils/CustomPatternApplicatorPassBase.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace iree_compiler {
namespace GlobalOptimization {

namespace {

/// Native rewrite function that creates a hal.dispatch.extern op from the
/// given workload, result types/dims, arguments/dims, tied operands and
/// attribute dictionary. |target| is only consulted for the location attached
/// to the new op.
static Operation *
createDispatchExtern(PatternRewriter &rewriter, Operation *target,
                     ValueRange workload, TypeRange resultTypes,
                     ValueRange resultDims, ValueRange arguments,
                     ValueRange argumentDims, DenseI64ArrayAttr tiedOperands,
                     DictionaryAttr attrDict) {
  Location loc = target->getLoc();
  auto externOp = rewriter.create<IREE::HAL::DispatchExternOp>(
      loc, workload, resultTypes, resultDims, arguments, argumentDims,
      tiedOperands, attrDict);
  return externOp.getOperation();
}

/// Native rewrite function that emplaces a block in the workgroup count region
/// of |op| and returns that block's arguments. Fails if |op| is not a
/// hal.dispatch.extern op.
///
/// On success the rewriter's insertion point is moved to the start of the new
/// block so that the caller can construct the workgroup count computation
/// there. The block is left without a terminator; adding one is up to the
/// caller.
static FailureOr<ValueRange>
emplaceExternWorkgroupCountRegion(PatternRewriter &rewriter, Operation *op) {
  if (auto externOp = dyn_cast<IREE::HAL::DispatchExternOp>(op)) {
    Block *countBlock = externOp.emplaceWorkgroupCountRegion(rewriter);
    rewriter.setInsertionPointToStart(countBlock);
    return ValueRange(countBlock->getArguments());
  }
  return failure();
}

/// Registers the native rewrite functions defined in this file with
/// |pdlPatterns| under the names "create_dispatch_extern" and
/// "emplace_extern_workgroup_count".
static void
registerExternDispatchRewriteFunction(PDLPatternModule &pdlPatterns) {
  // Emplaces and returns the arguments of the workgroup count region.
  pdlPatterns.registerRewriteFunction("emplace_extern_workgroup_count",
                                      emplaceExternWorkgroupCountRegion);
  // Builds the hal.dispatch.extern op itself.
  pdlPatterns.registerRewriteFunction("create_dispatch_extern",
                                      createDispatchExtern);
}

} // namespace

/// Pass that loads PDL pattern modules from the files listed in the
/// `pdlModuleFileNames` option and makes the extern-dispatch native rewrite
/// functions available to them.
class MaterializeExternDispatchesPass
    : public iree_compiler::PatternApplicatorPassBase<
          MaterializeExternDispatchesPass,
          iree_compiler::GlobalOptimization::
              MaterializeExternDispatchesPassBase> {
public:
  /// Declares the dialects this pass registers as dependencies: HAL, arith,
  /// PDL and the PDL interpreter dialect.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<mlir::iree_compiler::IREE::HAL::HALDialect, arith::ArithDialect,
                pdl::PDLDialect, pdl_interp::PDLInterpDialect>();
  }

  /// Populates |tmpPatterns| with the native rewrite functions and then with
  /// the PDL module from every configured file. Returns failure as soon as one
  /// file cannot be loaded.
  LogicalResult initializePatterns(MLIRContext *context,
                                   RewritePatternSet &tmpPatterns) {
    registerExternDispatchRewriteFunction(tmpPatterns.getPDLPatterns());
    // Bind by const reference: the option holds std::string values and a
    // by-value loop variable would copy each file name on every iteration.
    for (const auto &fileName : this->pdlModuleFileNames) {
      if (failed(iree_compiler::detail::populatePDLModuleFromFileName(
              context, tmpPatterns, fileName))) {
        return failure();
      }
    }
    return success();
  }

  /// Constructs the pass with the given list of PDL module file names.
  MaterializeExternDispatchesPass(ArrayRef<std::string> pdlModuleFileNames) {
    this->pdlModuleFileNames = pdlModuleFileNames;
  }
  MaterializeExternDispatchesPass(const MaterializeExternDispatchesPass &pass) =
      default;
};

/// Creates a MaterializeExternDispatchesPass configured with the given PDL
/// module file names.
std::unique_ptr<Pass> createMaterializeExternDispatchesPass(
    ArrayRef<std::string> pdlModuleFileNames) {
  auto pass =
      std::make_unique<MaterializeExternDispatchesPass>(pdlModuleFileNames);
  return pass;
}
} // namespace GlobalOptimization
} // namespace iree_compiler
} // namespace mlir
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ void buildGlobalOptimizationPassPipeline(
mainPassManager.addPass(IREE::Util::createDemoteI64ToI32Pass());
}

if (!transformOptions.options.customDispatchPatternModuleFileNames.empty()) {
mainPassManager.addPass(createMaterializeExternDispatchesPass(
transformOptions.options.customDispatchPatternModuleFileNames));
}

// Preprocessing passes to get the program into a canonical state.
FunctionLikeNest(mainPassManager)
.addPass(createRemoveZeroExtentTensorsPass)
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ createEraseUnusedLinalgOperands();
// forms.
std::unique_ptr<Pass> createExpandVectorsPass();

// Materializes logical encodings to physical encodings if there is a single
// device target.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createMaterializeExternDispatchesPass(
ArrayRef<std::string> pdlModuleFileName = {});

// Materializes logical encodings to physical encodings if there is a single
// device target.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
Expand Down
12 changes: 12 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ def ExpandVectors : Pass<"iree-global-opt-expand-vectors", ""> {
let constructor = "mlir::iree_compiler::GlobalOptimization::createExpandVectorsPass()";
}

// Declares the `iree-materialize-extern-dispatches` pass. The second Pass<>
// argument is left empty, so no specific op type is named here. The
// `pdl-module-file-names` list option may be given zero or more times; each
// entry is a file to load a PDL module from.
def MaterializeExternDispatchesPass :
Pass<"iree-materialize-extern-dispatches", ""> {
let summary = "Pass to form custom external dispatches";
let constructor =
"mlir::iree_compiler::GlobalOptimization::createMaterializeExternDispatchesPass()";
let options = [
ListOption<"pdlModuleFileNames", "pdl-module-file-names", "std::string",
"Optional files to load a pdl module from.",
"llvm::cl::ZeroOrMore">
];
}

def MaterializeHomogeneousEncodings :
Pass<"iree-global-opt-materialize-homogeneous-encodings", "mlir::ModuleOp"> {
let summary = "Materializes logical encodings to physical encodings if there is a single device target.";
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Pipelines/Options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) {
llvm::cl::desc("Strips debug assertions after any useful "
"information has been extracted."),
llvm::cl::cat(category));
binder.list<std::string>(
"iree-opt-extern-dispatch-pattern-module",
customDispatchPatternModuleFileNames,
llvm::cl::desc("File path to custom dispatch rewrite pattern module."),
llvm::cl::ZeroOrMore, llvm::cl::cat(category));
}

void SchedulingOptions::bindOptions(OptionsBinder &binder) {
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Pipelines/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ struct GlobalOptimizationOptions {
// allow hoisting. The threshold is 1MB by default.
int64_t constExprMaxSizeIncreaseThreshold = 1024 * 1024;

// File paths to load custom dispatch rewrite patterns from.
std::vector<std::string> customDispatchPatternModuleFileNames = {};

void bindOptions(OptionsBinder &binder);
using FromFlags = OptionsFromFlags<GlobalOptimizationOptions>;
};
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Utils/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ iree_compiler_cc_library(
name = "Utils",
srcs = [
"ConversionUtils.cpp",
"CustomPatternApplicatorPassBase.cpp",
"ElementPackingUtils.cpp",
"FlatbufferUtils.cpp",
"ModuleUtils.cpp",
Expand All @@ -29,6 +30,7 @@ iree_compiler_cc_library(
],
hdrs = [
"ConversionUtils.h",
"CustomPatternApplicatorPassBase.h",
"ElementPackingUtils.h",
"FlatbufferUtils.h",
"IndexSet.h",
Expand All @@ -50,9 +52,11 @@ iree_compiler_cc_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:PDLDialect",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ iree_cc_library(
Utils
HDRS
"ConversionUtils.h"
"CustomPatternApplicatorPassBase.h"
"ElementPackingUtils.h"
"FlatbufferUtils.h"
"IndexSet.h"
Expand All @@ -27,6 +28,7 @@ iree_cc_library(
"TracingUtils.h"
SRCS
"ConversionUtils.cpp"
"CustomPatternApplicatorPassBase.cpp"
"ElementPackingUtils.cpp"
"FlatbufferUtils.cpp"
"ModuleUtils.cpp"
Expand All @@ -40,9 +42,11 @@ iree_cc_library(
MLIRArithDialect
MLIRFuncDialect
MLIRIR
MLIRPDLDialect
MLIRParser
MLIRPass
MLIRSupport
MLIRTensorDialect
MLIRTransformUtils
MLIRTransforms
iree::base
Expand Down
Loading
Loading