Skip to content

Commit

Permalink
Add support for external custom dispatch rewrite patterns
Browse files Browse the repository at this point in the history
Currently custom dispatches must be written in as a part of downstream
flows (targeting `hal.dispatch.extern` or Flow ops directly). This
limits how lightweight deployment of custom dispatches with nearly stock
IREE flows. This adds a mechanism to plugin and apply pdl patterns
completely separate from other IR. This allows hermetically sealing the
details of a custom dispatch as a separate IR module that is easy to
reuse across models.

The intent is to follow up to this patch with similar passes in the
various input dialects to allow different potential rewrite points, as
well as turn some of the rewriting utilities into transform operations
as another way to represent the dispatch + rewrite logic without the
need for C++ and a downstram project.
  • Loading branch information
qedawkins committed Oct 19, 2023
1 parent acddd0c commit 5d9d65c
Show file tree
Hide file tree
Showing 19 changed files with 562 additions and 0 deletions.
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/Utils/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ package(

iree_compiler_cc_library(
name = "Utils",
srcs = [
"ExternBuildingUtils.cpp",
],
hdrs = [
"DeviceSwitchBuilder.h",
"ExternBuildingUtils.h",
],
deps = [
"//compiler/src/iree/compiler/Dialect/HAL/IR",
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ iree_cc_library(
Utils
HDRS
"DeviceSwitchBuilder.h"
"ExternBuildingUtils.h"
SRCS
"ExternBuildingUtils.cpp"
DEPS
LLVMSupport
MLIRFuncDialect
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_
#define IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_

#include "iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"

namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {

namespace {

/// Helper to build a hal.dispatch.extern op with the given arguments. This
/// returns the block arguments of the workgroup count region and the extern
/// op. Note that the workgroup count region will not include the terminator
/// and that is left up to the user to properly populate.
static FailureOr<Operation *>
createDispatchExtern(PatternRewriter &rewriter, ValueRange workload,
TypeRange resultTypes, ValueRange resultDims,
ValueRange arguments, ValueRange argumentDims,
DenseI64ArrayAttr tiedOperands, DictionaryAttr attrDict) {
Location rootLoc = (*arguments.begin()).getLoc();
SmallVector<int64_t> tiedOperandsIntList(tiedOperands.asArrayRef());
SmallVector<NamedAttribute> namedAttributes(attrDict.begin(), attrDict.end());
Operation *externOp = rewriter.create<DispatchExternOp>(
rootLoc, workload, resultTypes, resultDims, arguments, argumentDims,
tiedOperandsIntList, namedAttributes);
return externOp;
}

/// Helper to emplace a block on the given hal.dispatch.extern op. This returns
/// the block arguments of the updated workgroup count region. Note that the
/// workgroup count region will not include the terminator and that is left up
/// to the user to properly populate.
static FailureOr<ValueRange>
emplaceExternWorkgroupCountRegion(PatternRewriter &rewriter, Operation *op) {
Location rootLoc = op->getLoc();
auto externOp = dyn_cast<DispatchExternOp>(op);
if (!externOp) {
return failure();
}

SmallVector<Type> countTypes({rewriter.getType<IREE::HAL::DeviceType>()});
SmallVector<Location> countLocs({rootLoc});
for (auto workloadIdx : externOp.getWorkload()) {
countTypes.push_back(workloadIdx.getType());
countLocs.push_back(workloadIdx.getLoc());
}

auto &entryBlock = externOp.getWorkgroupCount().emplaceBlock();
auto countArgs = entryBlock.addArguments(countTypes, countLocs);

ArrayRef<BlockArgument> countArgsArray(countArgs.begin(), countArgs.end());

/// Update the insertion point to the beginning of the block to enable
/// contructing the workgroup count region.
rewriter.setInsertionPointToStart(&entryBlock);
return ValueRange(countArgsArray);
}
} // namespace

void registerExternDispatchRewriteFunction(PDLPatternModule &pdlPatterns) {
pdlPatterns.registerRewriteFunction("create_dispatch_extern",
createDispatchExtern);
pdlPatterns.registerRewriteFunction("emplace_extern_workgroup_count",
emplaceExternWorkgroupCountRegion);
}

} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir

#endif // IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_
27 changes: 27 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_
#define IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {

void registerExternDispatchRewriteFunction(PDLPatternModule &pdlPatterns);

} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir

#endif // IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ iree_compiler_cc_library(
name = "GlobalOptimization",
srcs = [
"MaterializeHomogeneousEncodings.cpp",
"MaterializeExternDispatches.cpp",
"Passes.cpp",
],
hdrs = [
Expand All @@ -57,6 +58,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Dialect/Flow/Transforms",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
"//compiler/src/iree/compiler/Dialect/HAL/Utils",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
"//compiler/src/iree/compiler/Pipelines:Options",
"//compiler/src/iree/compiler/Utils",
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ iree_cc_library(
HDRS
"Passes.h"
SRCS
"MaterializeExternDispatches.cpp"
"MaterializeHomogeneousEncodings.cpp"
"Passes.cpp"
DEPS
Expand All @@ -55,6 +56,7 @@ iree_cc_library(
iree::compiler::Dialect::Flow::Transforms
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::HAL::Utils
iree::compiler::Dialect::Util::Transforms
iree::compiler::Pipelines::Options
iree::compiler::Utils
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.h"
#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/Utils/CustomPatternApplicatorPassBase.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

class MaterializeExternDispatchesPass
: public iree_compiler::PatternApplicatorPassBase<
MaterializeExternDispatchesPass,
iree_compiler::GlobalOptimization::
MaterializeExternDispatchesPassBase> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mlir::iree_compiler::IREE::HAL::HALDialect, pdl::PDLDialect,
pdl_interp::PDLInterpDialect>();
}

LogicalResult initializePatterns(MLIRContext *context,
RewritePatternSet &tmpPatterns) {
iree_compiler::IREE::HAL::registerExternDispatchRewriteFunction(
tmpPatterns.getPDLPatterns());
return iree_compiler::detail::populatePDLModuleFromFileName(
context, tmpPatterns, this->pdlModuleFileName);
}

MaterializeExternDispatchesPass(StringRef pdlModuleFileName = StringRef()) {
this->pdlModuleFileName = pdlModuleFileName.str();
}
MaterializeExternDispatchesPass(const MaterializeExternDispatchesPass &pass) =
default;
};
} // namespace

namespace mlir {
namespace iree_compiler {
namespace GlobalOptimization {
std::unique_ptr<Pass>
createMaterializeExternDispatchesPass(std::string pdlModuleFileName) {
return std::make_unique<MaterializeExternDispatchesPass>(pdlModuleFileName);
}
} // namespace GlobalOptimization
} // namespace iree_compiler
} // namespace mlir
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ void buildGlobalOptimizationPassPipeline(
mainPassManager.addPass(IREE::Util::createDemoteI64ToI32Pass());
}

if (!transformOptions.options.customDispatchPatternModuleFileName.empty()) {
mainPassManager.addPass(createMaterializeExternDispatchesPass(
transformOptions.options.customDispatchPatternModuleFileName));
}

// Preprocessing passes to get the program into a canonical state.
FunctionLikeNest(mainPassManager)
.addPass(IREE::Flow::createRemoveZeroExtentTensorsPass)
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ struct TransformOptions : public PassPipelineOptions<TransformOptions> {
void buildGlobalOptimizationPassPipeline(
OpPassManager &mainPassManager, const TransformOptions &transformOptions);

// Materializes logical encodings to physical encodings if there is a single
// device target.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createMaterializeExternDispatchesPass(std::string pdlModuleFileName = "");

// Materializes logical encodings to physical encodings if there is a single
// device target.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
Expand Down
12 changes: 12 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@

include "mlir/Pass/PassBase.td"

def MaterializeExternDispatchesPass :
Pass<"iree-materialize-extern-dispatches", "mlir::ModuleOp"> {
let summary = "Pass to form custom external dispatches";
let constructor =
"mlir::iree_compiler::GlobalOptimization::createMaterializeExternDispatchesPass()";
let options = [
Option<"pdlModuleFileName", "pdl-module-file-name", "std::string",
/*default=*/"\"\"",
"Optional file name to load a pdl module from.">
];
}

def MaterializeHomogeneousEncodings :
Pass<"iree-global-opt-materialize-homogeneous-encodings", "mlir::ModuleOp"> {
let summary = "Materializes logical encodings to physical encodings if there is a single device target.";
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Pipelines/Options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) {
llvm::cl::desc("Strips debug assertions after any useful "
"information has been extracted."),
llvm::cl::cat(category));
binder.opt<std::string>(
"iree-opt-extern-dispatch-pattern-module",
customDispatchPatternModuleFileName,
llvm::cl::desc("File path to custom dispatch rewrite pattnern module."),
llvm::cl::cat(category));
}

void SchedulingOptions::bindOptions(OptionsBinder &binder) {
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Pipelines/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ struct GlobalOptimizationOptions {
// Strips debug assertions after any useful information has been extracted.
bool stripAssertions = false;

// File path to load custom dispatch rewrite patterns from.
std::string customDispatchPatternModuleFileName = "";

void bindOptions(OptionsBinder &binder);
using FromFlags = OptionsFromFlags<GlobalOptimizationOptions>;
};
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Utils/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ iree_compiler_cc_library(
name = "Utils",
srcs = [
"ConversionUtils.cpp",
"CustomPatternApplicatorPassBase.cpp",
"ElementPackingUtils.cpp",
"FlatbufferUtils.cpp",
"ModuleUtils.cpp",
Expand All @@ -29,6 +30,7 @@ iree_compiler_cc_library(
],
hdrs = [
"ConversionUtils.h",
"CustomPatternApplicatorPassBase.h",
"ElementPackingUtils.h",
"FlatbufferUtils.h",
"IndexSet.h",
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ iree_cc_library(
Utils
HDRS
"ConversionUtils.h"
"CustomPatternApplicatorPassBase.h"
"ElementPackingUtils.h"
"FlatbufferUtils.h"
"IndexSet.h"
Expand All @@ -27,6 +28,7 @@ iree_cc_library(
"TracingUtils.h"
SRCS
"ConversionUtils.cpp"
"CustomPatternApplicatorPassBase.cpp"
"ElementPackingUtils.cpp"
"FlatbufferUtils.cpp"
"ModuleUtils.cpp"
Expand Down
Loading

0 comments on commit 5d9d65c

Please sign in to comment.