[Stream] Implement SpecializeEncodings pass (1/n) (#19502)
There are three major changes in this revision:

- Introduce the `AffinityAnalysisDialectInterface` Stream dialect interface. It is used to fetch attributes that are defined by other dialects. In this revision, HAL implements the dialect interface, and it can return whatever attributes are attached to the HAL::ExecutableTarget attributes. The main idea of the dialect interface is that Stream **does not** need to depend on HAL to get the layout information.
- Add a `cloneWithLayouts` method to EncodingAttr. It is used in the encoding specialization pass, which can resolve the layout requirements and add them to the `layouts` field. The other optional parameters are dropped because the layout is already resolved. This could become a new Encoding dialect attribute because it just describes the layout. The stream tensor ops do not need to know the `op_type`, `element_types`, and `operand_index` parameters; they only need the layout information, and the attribute should implement the interface method.
- Partially implement the SpecializeEncodings pass. The responsibility of the pass is large, so I decided to implement it incrementally. This revision only implements the mechanism for updating the encodings of stream tensor ops, and only the stream.tensor.sizeof op is supported. Support for the other stream tensor ops can be added later. Executable duplication and the update of dispatch ops will be implemented in subsequent PRs.

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
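The `cloneWithLayouts` idea can be pictured with a small stand-alone model. This is only a sketch under stated assumptions: `ToyEncoding`, `ToyTensorType`, `optionalParams`, and the free function `getEncodingWithNewLayouts` are hypothetical stand-ins that use plain standard-library types instead of MLIR attributes, and which fields the real attribute keeps or drops is an assumption here, not something the commit message spells out.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an encoding attribute. The real EncodingAttr is
// an MLIR attribute; strings are used here to keep the sketch self-contained.
struct ToyEncoding {
  std::string opType;
  std::vector<std::string> optionalParams; // assumed "other optional parameters"
  std::vector<std::string> layouts;

  // Returns a copy that carries the resolved layouts. The optional
  // parameters are cleared because the layout is already resolved.
  ToyEncoding cloneWithLayouts(std::vector<std::string> resolved) const {
    ToyEncoding result;
    result.opType = opType;
    result.layouts = std::move(resolved);
    return result;
  }
};

// Hypothetical stand-in for a ranked tensor type with an optional encoding.
struct ToyTensorType {
  std::vector<int64_t> shape;
  std::string elementType;
  std::optional<ToyEncoding> encoding;
};

// Returns the updated encoding if the type has one, otherwise std::nullopt.
std::optional<ToyEncoding>
getEncodingWithNewLayouts(const ToyTensorType &type,
                          const std::vector<std::string> &layouts) {
  if (!type.encoding)
    return std::nullopt;
  return type.encoding->cloneWithLayouts(layouts);
}
```

Types without an encoding are left alone, which matches the described behavior of only touching stream tensor ops whose types carry an encoding.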
Showing 16 changed files with 303 additions and 0 deletions.
36 changes: 36 additions & 0 deletions
compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// Copyright 2025 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#ifndef IREE_COMPILER_DIALECT_STREAM_IR_STREAM_INTERFACES_H_ | ||
#define IREE_COMPILER_DIALECT_STREAM_IR_STREAM_INTERFACES_H_ | ||
|
||
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" | ||
#include "mlir/IR/Attributes.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/DialectInterface.h" | ||
#include "mlir/IR/Operation.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include <functional> | ||
|
||
namespace mlir::iree_compiler::IREE::Stream { | ||
|
||
using ResolveLayoutAttrFn = std::function<LogicalResult( | ||
AffinityAttr, Operation *, SetVector<Attribute> &)>; | ||
|
||
class AffinityAnalysisDialectInterface | ||
: public DialectInterface::Base<AffinityAnalysisDialectInterface> { | ||
public: | ||
AffinityAnalysisDialectInterface(Dialect *dialect) : Base(dialect) {} | ||
|
||
/// Returns a callable that resolves the layout attributes for an affinity. | ||
/// The `moduleOp` must remain live and unmodified for as long as the returned | ||
/// callable is in use. Otherwise, the results may be incorrect or the compiler | ||
/// may crash, especially when module-scope analysis is run. | ||
virtual ResolveLayoutAttrFn | ||
makeLayoutAttrResolver(ModuleOp moduleOp) const = 0; | ||
}; | ||
|
||
} // namespace mlir::iree_compiler::IREE::Stream | ||
|
||
#endif // IREE_COMPILER_DIALECT_STREAM_IR_STREAM_INTERFACES_H_ |
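
The decoupling that this header provides can be sketched without MLIR. The following is a minimal model, assuming plain strings in place of `AffinityAttr` and `Attribute`; every `Toy*` name is hypothetical, and the lookup table inside `ToyModule` is an invented placeholder for however the real provider finds its target attributes. It shows a consumer that only sees the abstract interface and a provider that returns a callable borrowing the module, which is why the module has to outlive the callable.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for AffinityAttr, Attribute, and ModuleOp.
using ToyAffinity = std::string;
using ToyLayout = std::string;
struct ToyModule {
  // Maps a device affinity to the layouts its targets advertise.
  std::map<ToyAffinity, std::vector<ToyLayout>> targetLayouts;
};

// Returns true on success and appends the resolved layouts to the output.
using ToyResolveLayoutFn =
    std::function<bool(const ToyAffinity &, std::vector<ToyLayout> &)>;

// The consumer (the role Stream plays) depends only on this interface.
class ToyAffinityAnalysisInterface {
public:
  virtual ~ToyAffinityAnalysisInterface() = default;
  // The returned callable borrows `module`; the caller must keep the module
  // alive and unmodified for as long as the callable is used.
  virtual ToyResolveLayoutFn
  makeLayoutAttrResolver(const ToyModule &module) const = 0;
};

// A provider (the role HAL plays) implements the interface.
class ToyHalInterface final : public ToyAffinityAnalysisInterface {
public:
  ToyResolveLayoutFn
  makeLayoutAttrResolver(const ToyModule &module) const override {
    // Capturing by reference is what creates the lifetime requirement.
    return [&module](const ToyAffinity &affinity,
                     std::vector<ToyLayout> &out) {
      auto it = module.targetLayouts.find(affinity);
      if (it == module.targetLayouts.end())
        return false;
      out.insert(out.end(), it->second.begin(), it->second.end());
      return true;
    };
  }
};
```

A caller would hold a `const ToyAffinityAnalysisInterface &`, ask it for a resolver once per module, and then invoke the resolver for each affinity it encounters.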
169 changes: 169 additions & 0 deletions
compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
// Copyright 2025 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" | ||
#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h" | ||
#include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.h" | ||
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" | ||
#include "iree/compiler/Dialect/Stream/IR/StreamTraits.h" | ||
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" | ||
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h" | ||
#include "llvm/ADT/TypeSwitch.h" | ||
#include "llvm/Support/Casting.h" | ||
#include "llvm/Support/Debug.h" | ||
#include "llvm/Support/LogicalResult.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/Interfaces/FunctionInterfaces.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Support/LLVM.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir::iree_compiler::IREE::Stream { | ||
|
||
#define DEBUG_TYPE "iree-stream-specialize-encodings" | ||
|
||
#define GEN_PASS_DEF_SPECIALIZEENCODINGSPASS | ||
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc" | ||
|
||
namespace { | ||
/// Returns a stably sorted list of dialect interfaces of T for all dialects | ||
/// used within the given module. | ||
template <typename T> | ||
SmallVector<const T *> gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) { | ||
SmallPtrSet<const T *, 4> resultSet; | ||
for (auto dialect : moduleOp.getContext()->getLoadedDialects()) { | ||
auto *dialectInterface = dialect->getRegisteredInterface<T>(); | ||
if (!dialectInterface) | ||
continue; | ||
resultSet.insert(dialectInterface); | ||
} | ||
|
||
// NOTE: to ensure deterministic output we sort the result so that the | ||
// interfaces are always returned in a consistent order. | ||
SmallVector<const T *> results = {resultSet.begin(), resultSet.end()}; | ||
llvm::sort( | ||
results, +[](const T *a, const T *b) { | ||
return a->getDialect()->getNamespace().compare( | ||
b->getDialect()->getNamespace()) < 0; | ||
}); | ||
return results; | ||
} | ||
|
||
// TODO(hanchung): Add "cloneWithEncoding" method to RankedTensorType. | ||
static RankedTensorType cloneWithEncoding(RankedTensorType type, | ||
Attribute encodingAttr) { | ||
return RankedTensorType::get(type.getShape(), type.getElementType(), | ||
encodingAttr); | ||
} | ||
|
||
static LogicalResult addLayoutsToTensorPhaseOps( | ||
ModuleOp moduleOp, FunctionOpInterface funcOp, | ||
IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) { | ||
SmallVector<IREE::Stream::AffinityOpInterface> candidates; | ||
funcOp.walk([&](IREE::Stream::AffinityOpInterface affinityOp) { | ||
// Only need to update encoding types for ops that have TensorPhaseOp trait. | ||
if (!affinityOp->hasTrait<OpTrait::IREE::Stream::TensorPhaseOp>()) { | ||
return; | ||
} | ||
|
||
// Bail out if the operation does not have an affinity attribute. | ||
auto affinityAttr = affinityOp.getAffinityAttr(); | ||
if (!affinityAttr) { | ||
return; | ||
} | ||
candidates.push_back(affinityOp); | ||
}); | ||
|
||
if (candidates.empty()) { | ||
return success(); | ||
} | ||
|
||
IRRewriter rewriter(funcOp.getContext()); | ||
for (auto affinityOp : candidates) { | ||
auto affinityAttr = affinityOp.getAffinityAttr(); | ||
SetVector<Attribute> layouts; | ||
if (failed(resolveLayoutAttr(affinityAttr, moduleOp, layouts))) { | ||
return affinityOp.emitError("failed on making layouts"); | ||
} | ||
|
||
// Returns an updated encoding attribute if an encoding attribute is present | ||
// in the type. Otherwise, returns std::nullopt. | ||
auto getEncodingWithNewLayouts = | ||
[=](Type type) -> std::optional<IREE::Encoding::EncodingAttr> { | ||
auto rankedTensorType = dyn_cast<RankedTensorType>(type); | ||
if (!rankedTensorType) { | ||
return std::nullopt; | ||
} | ||
auto encodingAttr = IREE::Encoding::getEncodingAttr(rankedTensorType); | ||
if (!encodingAttr) { | ||
return std::nullopt; | ||
} | ||
return encodingAttr.cloneWithLayouts(layouts.getArrayRef()); | ||
}; | ||
|
||
// TODO(hanchung): Update other Stream operations. | ||
LogicalResult result = | ||
TypeSwitch<Operation *, LogicalResult>(affinityOp) | ||
.Case<IREE::Stream::TensorSizeOfOp>([&](auto sizeOfOp) { | ||
auto encodingType = | ||
dyn_cast<RankedTensorType>(sizeOfOp.getEncoding()); | ||
if (!encodingType) { | ||
return success(); | ||
} | ||
std::optional<IREE::Encoding::EncodingAttr> encodingAttr = | ||
getEncodingWithNewLayouts(encodingType); | ||
if (!encodingAttr) { | ||
return success(); | ||
} | ||
rewriter.modifyOpInPlace(sizeOfOp, [&] { | ||
sizeOfOp.setEncoding( | ||
cloneWithEncoding(encodingType, encodingAttr.value())); | ||
}); | ||
return success(); | ||
}) | ||
.Default([](auto *op) { return failure(); }); | ||
|
||
if (failed(result)) { | ||
return failure(); | ||
} | ||
} | ||
return success(); | ||
} | ||
} // namespace | ||
|
||
struct SpecializeEncodingsPass | ||
: public impl::SpecializeEncodingsPassBase<SpecializeEncodingsPass> { | ||
void runOnOperation() override { | ||
ModuleOp moduleOp = getOperation(); | ||
auto usedDialects = gatherUsedDialectInterfaces< | ||
IREE::Stream::AffinityAnalysisDialectInterface>(moduleOp); | ||
if (usedDialects.size() != 1) { | ||
moduleOp.emitError("expected only one dialect implementing " | ||
"AffinityAnalysisDialectInterface"); | ||
return signalPassFailure(); | ||
} | ||
|
||
llvm::MapVector<StringRef, IREE::Stream::ExecutableOp> executableOps; | ||
for (auto executableOp : moduleOp.getOps<IREE::Stream::ExecutableOp>()) { | ||
executableOps[executableOp.getName()] = executableOp; | ||
} | ||
|
||
IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr = | ||
usedDialects[0]->makeLayoutAttrResolver(moduleOp); | ||
for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) { | ||
if (failed(addLayoutsToTensorPhaseOps(moduleOp, funcOp, | ||
resolveLayoutAttr))) { | ||
funcOp.emitError( | ||
"failed on adding layouts to Stream::TensorPhaseOp with encodings"); | ||
return signalPassFailure(); | ||
} | ||
|
||
// TODO(hanchung): Duplicate executables and update dispatch ops. | ||
} | ||
} | ||
}; | ||
|
||
} // namespace mlir::iree_compiler::IREE::Stream |
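
The way the pass selects its layout provider can be sketched in isolation. This is a simplified model, not the pass itself: `ToyDialect`, `hasInterface`, `gatherUsedInterfaces`, and `selectSingleProvider` are hypothetical names, and a boolean flag stands in for `getRegisteredInterface<T>()`. It illustrates the two steps visible in the diff: collect the providers and sort them by namespace so the order does not depend on pointer values, then require exactly one provider.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a loaded dialect that may provide the interface.
struct ToyDialect {
  std::string dialectNamespace;
  bool hasInterface;
};

// Collects the dialects that provide the interface and sorts them by
// namespace, so the result is independent of addresses and load order.
std::vector<const ToyDialect *>
gatherUsedInterfaces(const std::vector<ToyDialect> &loaded) {
  std::set<const ToyDialect *> unique;
  for (const ToyDialect &dialect : loaded) {
    if (dialect.hasInterface)
      unique.insert(&dialect);
  }
  std::vector<const ToyDialect *> results(unique.begin(), unique.end());
  std::sort(results.begin(), results.end(),
            [](const ToyDialect *a, const ToyDialect *b) {
              return a->dialectNamespace < b->dialectNamespace;
            });
  return results;
}

// Mirrors the pass's precondition: exactly one provider must be present.
// Returns std::nullopt where the pass would signal failure.
std::optional<std::string>
selectSingleProvider(const std::vector<ToyDialect> &loaded) {
  std::vector<const ToyDialect *> providers = gatherUsedInterfaces(loaded);
  if (providers.size() != 1)
    return std::nullopt;
  return providers.front()->dialectNamespace;
}
```

Rejecting both zero and multiple providers keeps the choice of layout resolver unambiguous; how several providers would be combined is left open in this revision.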