-
Notifications
You must be signed in to change notification settings - Fork 639
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for external custom dispatch rewrite patterns
Currently custom dispatches must be written in as a part of downstream flows (targeting `hal.dispatch.extern` or Flow ops directly). This limits how lightweight deployment of custom dispatches with nearly stock IREE flows. This adds a mechanism to plugin and apply pdl patterns completely separate from other IR. This allows hermetically sealing the details of a custom dispatch as a separate IR module that is easy to reuse across models. The intent is to follow up to this patch with similar passes in the various input dialects to allow different potential rewrite points, as well as turn some of the rewriting utilities into transform operations as another way to represent the dispatch + rewrite logic without the need for C++ and a downstram project.
- Loading branch information
Showing
19 changed files
with
562 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
91 changes: 91 additions & 0 deletions
91
compiler/src/iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
// Copyright 2023 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#ifndef IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_ | ||
#define IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_ | ||
|
||
#include "iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.h" | ||
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" | ||
#include "iree/compiler/Dialect/HAL/IR/HALOps.h" | ||
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/IRMapping.h" | ||
#include "mlir/IR/Location.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Pass/PassManager.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "mlir/Transforms/RegionUtils.h" | ||
|
||
namespace mlir { | ||
namespace iree_compiler { | ||
namespace IREE { | ||
namespace HAL { | ||
|
||
namespace { | ||
|
||
/// Helper to build a hal.dispatch.extern op with the given arguments. This | ||
/// returns the block arguments of the workgroup count region and the extern | ||
/// op. Note that the workgroup count region will not include the terminator | ||
/// and that is left up to the user to properly populate. | ||
static FailureOr<Operation *> | ||
createDispatchExtern(PatternRewriter &rewriter, ValueRange workload, | ||
TypeRange resultTypes, ValueRange resultDims, | ||
ValueRange arguments, ValueRange argumentDims, | ||
DenseI64ArrayAttr tiedOperands, DictionaryAttr attrDict) { | ||
Location rootLoc = (*arguments.begin()).getLoc(); | ||
SmallVector<int64_t> tiedOperandsIntList(tiedOperands.asArrayRef()); | ||
SmallVector<NamedAttribute> namedAttributes(attrDict.begin(), attrDict.end()); | ||
Operation *externOp = rewriter.create<DispatchExternOp>( | ||
rootLoc, workload, resultTypes, resultDims, arguments, argumentDims, | ||
tiedOperandsIntList, namedAttributes); | ||
return externOp; | ||
} | ||
|
||
/// Helper to emplace a block on the given hal.dispatch.extern op. This returns | ||
/// the block arguments of the updated workgroup count region. Note that the | ||
/// workgroup count region will not include the terminator and that is left up | ||
/// to the user to properly populate. | ||
static FailureOr<ValueRange> | ||
emplaceExternWorkgroupCountRegion(PatternRewriter &rewriter, Operation *op) { | ||
Location rootLoc = op->getLoc(); | ||
auto externOp = dyn_cast<DispatchExternOp>(op); | ||
if (!externOp) { | ||
return failure(); | ||
} | ||
|
||
SmallVector<Type> countTypes({rewriter.getType<IREE::HAL::DeviceType>()}); | ||
SmallVector<Location> countLocs({rootLoc}); | ||
for (auto workloadIdx : externOp.getWorkload()) { | ||
countTypes.push_back(workloadIdx.getType()); | ||
countLocs.push_back(workloadIdx.getLoc()); | ||
} | ||
|
||
auto &entryBlock = externOp.getWorkgroupCount().emplaceBlock(); | ||
auto countArgs = entryBlock.addArguments(countTypes, countLocs); | ||
|
||
ArrayRef<BlockArgument> countArgsArray(countArgs.begin(), countArgs.end()); | ||
|
||
/// Update the insertion point to the beginning of the block to enable | ||
/// contructing the workgroup count region. | ||
rewriter.setInsertionPointToStart(&entryBlock); | ||
return ValueRange(countArgsArray); | ||
} | ||
} // namespace | ||
|
||
void registerExternDispatchRewriteFunction(PDLPatternModule &pdlPatterns) { | ||
pdlPatterns.registerRewriteFunction("create_dispatch_extern", | ||
createDispatchExtern); | ||
pdlPatterns.registerRewriteFunction("emplace_extern_workgroup_count", | ||
emplaceExternWorkgroupCountRegion); | ||
} | ||
|
||
} // namespace HAL | ||
} // namespace IREE | ||
} // namespace iree_compiler | ||
} // namespace mlir | ||
|
||
#endif // IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_ |
27 changes: 27 additions & 0 deletions
27
compiler/src/iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// Copyright 2023 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#ifndef IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_ | ||
#define IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_ | ||
|
||
#include "mlir/IR/Attributes.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/Operation.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
|
||
namespace mlir { | ||
namespace iree_compiler { | ||
namespace IREE { | ||
namespace HAL { | ||
|
||
void registerExternDispatchRewriteFunction(PDLPatternModule &pdlPatterns); | ||
|
||
} // namespace HAL | ||
} // namespace IREE | ||
} // namespace iree_compiler | ||
} // namespace mlir | ||
|
||
#endif // IREE_COMPILER_DIALECT_HAL_UTILS_EXTERN_BUILDING_UTILS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
compiler/src/iree/compiler/GlobalOptimization/MaterializeExternDispatches.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// Copyright 2023 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" | ||
#include "iree/compiler/Dialect/HAL/IR/HALOps.h" | ||
#include "iree/compiler/Dialect/HAL/Utils/ExternBuildingUtils.h" | ||
#include "iree/compiler/GlobalOptimization/PassDetail.h" | ||
#include "iree/compiler/Utils/CustomPatternApplicatorPassBase.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/PDL/IR/PDL.h" | ||
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" | ||
#include "mlir/Pass/Pass.h" | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
|
||
class MaterializeExternDispatchesPass | ||
: public iree_compiler::PatternApplicatorPassBase< | ||
MaterializeExternDispatchesPass, | ||
iree_compiler::GlobalOptimization:: | ||
MaterializeExternDispatchesPassBase> { | ||
public: | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<mlir::iree_compiler::IREE::HAL::HALDialect, pdl::PDLDialect, | ||
pdl_interp::PDLInterpDialect>(); | ||
} | ||
|
||
LogicalResult initializePatterns(MLIRContext *context, | ||
RewritePatternSet &tmpPatterns) { | ||
iree_compiler::IREE::HAL::registerExternDispatchRewriteFunction( | ||
tmpPatterns.getPDLPatterns()); | ||
return iree_compiler::detail::populatePDLModuleFromFileName( | ||
context, tmpPatterns, this->pdlModuleFileName); | ||
} | ||
|
||
MaterializeExternDispatchesPass(StringRef pdlModuleFileName = StringRef()) { | ||
this->pdlModuleFileName = pdlModuleFileName.str(); | ||
} | ||
MaterializeExternDispatchesPass(const MaterializeExternDispatchesPass &pass) = | ||
default; | ||
}; | ||
} // namespace | ||
|
||
namespace mlir { | ||
namespace iree_compiler { | ||
namespace GlobalOptimization { | ||
std::unique_ptr<Pass> | ||
createMaterializeExternDispatchesPass(std::string pdlModuleFileName) { | ||
return std::make_unique<MaterializeExternDispatchesPass>(pdlModuleFileName); | ||
} | ||
} // namespace GlobalOptimization | ||
} // namespace iree_compiler | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.