Skip to content

Commit

Permalink
[Stream] Implement SpecializeEncodings pass (1/n) (#19502)
Browse files Browse the repository at this point in the history
There are three major changes in the revision:

- Introduce `AffinityAnalysisDialectInterface` Stream dialect interface.
It is used to fetch attributes that are defined by other dialects. In
the revision, HAL implements the dialect interface, and it can return
whatever attribute attached in HAL::ExecutableTarget attributes. The
main idea of the dialect interface is that Stream **does not** need to
depend on HAL to get the layout information.
- Add `cloneWithLayouts` method to the EncodingAttr. It is used in the
encoding specialization pass where it can resolve the layout
requirements and add it to the `layouts` field. The other optional
parameters are dropped because the layout is already resolved. It can be
a new Encoding dialect attribute because it is just describing the
layout. The stream tensor ops do not need to know the `op_type`,
`element_types` and `operand_index` parameters. It only needs the layout
information, and the attribute should implement the interface method.
- Partially implement the SpecializeEncodings pass. The responsibility
of the pass is large, so I decide to implement it incrementally. This
revision only implements the mechanism of updating stream tensor ops'
encoding, and only stream.tensor.sizeof op is supported. The rest of the
support for other stream tensor op can be added later on. The executable
duplication and the update of dispatch ops will be implemented in
subsequent PRs.

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
  • Loading branch information
hanhanW authored Jan 9, 2025
1 parent 74f8d3c commit 02d145e
Show file tree
Hide file tree
Showing 16 changed files with 303 additions and 0 deletions.
10 changes: 10 additions & 0 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LLVM.h"
Expand Down Expand Up @@ -113,6 +114,15 @@ EncodingAttr EncodingAttr::clone(AffineMap bcastMap) {
AffineMapAttr::get(bcastMap), getRoundDimsTo(), getLayouts());
}

EncodingAttr EncodingAttr::cloneWithLayouts(ArrayRef<Attribute> layouts) {
MLIRContext *ctx = getContext();
return get(ctx, getOperandIndex(), getOpType(), getElementTypes(),
/*user_indexing_maps=*/ArrayAttr(),
/*bcast_map=*/AffineMapAttr(),
/*round_dims_to=*/DenseI64ArrayAttr(),
ArrayAttr::get(ctx, layouts));
}

/// Returns the bit-width of the scalar type. If the type is complex, it returns
/// the type of individual elements * 2 (1 for real and 1 for complex).
static unsigned getTypeBitWidth(Type type) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ def EncodingAttr :

/// Clones an encoding with a new bcast_map
EncodingAttr clone(AffineMap bcastMap);

/// Clones an encoding with a new layout list and drops other optional
/// parameters (because they are resolved).
EncodingAttr cloneWithLayouts(ArrayRef<Attribute> layouts);
}];

let genVerifyDecl = 0;
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ iree_compiler_cc_library(
deps = [
":IR",
"//compiler/src/iree/compiler/Dialect/HAL:hal_imports",
"//compiler/src/iree/compiler/Dialect/HAL/Analysis",
"//compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/VM/Conversion",
"@llvm-project//llvm:Support",
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ iree_cc_library(
MLIRParser
MLIRSCFDialect
MLIRTransformUtils
iree::compiler::Dialect::HAL::Analysis
iree::compiler::Dialect::HAL::Conversion::HALToVM
iree::compiler::Dialect::HAL::hal_imports
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::VM::Conversion
PUBLIC
Expand Down
27 changes: 27 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@

#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"

#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/HAL/hal.imports.h"
#include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
Expand Down Expand Up @@ -115,6 +118,29 @@ class HALToVMConversionInterface : public VMConversionDialectInterface {
}
};

class HALAffinityAnalysisDialectInterface
: public IREE::Stream::AffinityAnalysisDialectInterface {
public:
using AffinityAnalysisDialectInterface::AffinityAnalysisDialectInterface;
IREE::Stream::ResolveLayoutAttrFn
makeLayoutAttrResolver(ModuleOp moduleOp) const {
return [=](IREE::Stream::AffinityAttr affinityAttr, Operation *op,
SetVector<Attribute> &layoutAttrs) -> LogicalResult {
// This needs to be in the lambda because the moduleOp could be modified..
IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
if (failed(deviceAnalysis.run())) {
return op->emitError("failed to run DeviceAnalysis");
}
SetVector<IREE::HAL::ExecutableTargetAttr> resultSet;
deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, op,
resultSet);
// TODO(hanchung): Populate the EncodingLayoutAttr when it is ready.
layoutAttrs.insert(resultSet.begin(), resultSet.end());
return success();
};
};
};

} // namespace

HALDialect::HALDialect(MLIRContext *context)
Expand All @@ -131,6 +157,7 @@ HALDialect::HALDialect(MLIRContext *context)
#include "iree/compiler/Dialect/HAL/IR/HALOps.cpp.inc"
>();
addInterfaces<HALInlinerInterface, HALOpAsmInterface,
HALAffinityAnalysisDialectInterface,
HALToVMConversionInterface>();
}

Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ iree_compiler_cc_library(
hdrs = [
"StreamDialect.h",
"StreamEnums.h.inc",
"StreamInterfaces.h",
"StreamOpInterfaces.h.inc",
"StreamOps.h",
"StreamOps.h.inc",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ iree_cc_library(
HDRS
"StreamDialect.h"
"StreamEnums.h.inc"
"StreamInterfaces.h"
"StreamOpInterfaces.h.inc"
"StreamOps.h"
"StreamOps.h.inc"
Expand Down
36 changes: 36 additions & 0 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_DIALECT_STREAM_IR_STREAMINTERACES_H_
#define IREE_COMPILER_DIALECT_STREAM_IR_STREAMINTERACES_H_

#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::iree_compiler::IREE::Stream {

using ResolveLayoutAttrFn = std::function<LogicalResult(
AffinityAttr, Operation *, SetVector<Attribute> &)>;

class AffinityAnalysisDialectInterface
: public DialectInterface::Base<AffinityAnalysisDialectInterface> {
public:
AffinityAnalysisDialectInterface(Dialect *dialect) : Base(dialect) {}

/// The `moduleOp` must remain live and unmodified for as long as the returned
/// capture is. Otherwise, it will likely be incorrect or crash if the module
/// op is mutated, especially when module scope analysis is run.
virtual ResolveLayoutAttrFn
makeLayoutAttrResolver(ModuleOp moduleOp) const = 0;
};

} // namespace mlir::iree_compiler::IREE::Stream

#endif // IREE_COMPILER_DIALECT_STREAM_IR_STREAM_INTERFACES_H_
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ iree_compiler_cc_library(
"ScheduleConcurrency.cpp",
"ScheduleExecution.cpp",
"SpecializeDispatches.cpp",
"SpecializeEncodings.cpp",
"VerifyAffinities.cpp",
"VerifyAsyncAccessRanges.cpp",
"VerifyLowerings.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ iree_cc_library(
"ScheduleConcurrency.cpp"
"ScheduleExecution.cpp"
"SpecializeDispatches.cpp"
"SpecializeEncodings.cpp"
"VerifyAffinities.cpp"
"VerifyAsyncAccessRanges.cpp"
"VerifyLowerings.cpp"
Expand Down
13 changes: 13 additions & 0 deletions compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ static llvm::cl::opt<bool> clAnnotateInputAffinities(
"the pipeline for debugging."),
llvm::cl::init(false));

// TODO(hanchung): Enable the pass by default once the implementation is done.
static llvm::cl::opt<bool> clSpecializeEncodings(
"iree-stream-experimental-specialize-encodings",
llvm::cl::desc(
"Enables SpecializeEncodingPass in Stream pass pipeline. This pass is "
"currently under development, so it is not enabled by default. It can "
"only handle limited cases at this moment."),
llvm::cl::init(false));

namespace mlir::iree_compiler::IREE::Stream {

using FunctionLikeNest =
Expand Down Expand Up @@ -140,6 +149,10 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
// Tensor lowering and resource management
//----------------------------------------------------------------------------

if (clSpecializeEncodings) {
passManager.addPass(IREE::Stream::createSpecializeEncodingsPass());
}

// Lower stream.tensor.* ops to stream.async.* ops based on
// affinity/configuration assigned during placement.
FunctionLikeNest(passManager)
Expand Down
10 changes: 10 additions & 0 deletions compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,16 @@ def SpecializeDispatchesPass :
];
}

def SpecializeEncodingsPass :
Pass<"iree-stream-specialize-encodings", "mlir::ModuleOp"> {
let summary = "Specializes data-tiling encodings based on device analysis.";
let description = [{
Attaches layouts to encodings and duplicates executables based on device
analysis.
TODO: Unpack the context. The pass is not fully implemented yet.
}];
}

def AnnotateDispatchArgumentsPass :
Pass<"iree-stream-annotate-dispatch-arguments", "mlir::ModuleOp"> {
let summary = "Annotates dispatch arguments with potential values derived from dispatch sites.";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
#include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTraits.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::Stream {

#define DEBUG_TYPE "iree-stream-specialize-encodings"

#define GEN_PASS_DEF_SPECIALIZEENCODINGSPASS
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc"

namespace {
/// Returns a stably sorted list of dialect interfaces of T for all dialects
/// used within the given module.
template <typename T>
SmallVector<const T *> gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) {
SmallPtrSet<const T *, 4> resultSet;
for (auto dialect : moduleOp.getContext()->getLoadedDialects()) {
auto *dialectInterface = dialect->getRegisteredInterface<T>();
if (!dialectInterface)
continue;
resultSet.insert(dialectInterface);
}

// NOTE: to ensure deterministic output we sort the result so that imports are
// always added in a consistent order.
SmallVector<const T *> results = {resultSet.begin(), resultSet.end()};
llvm::sort(
results, +[](const T *a, const T *b) {
return a->getDialect()->getNamespace().compare(
b->getDialect()->getNamespace()) < 0;
});
return results;
}

// TODO(hanchung): Add "cloneWithEncoding" method to RankedTensorType.
static RankedTensorType cloneWithEncoding(RankedTensorType type,
Attribute encodingAttr) {
return RankedTensorType::get(type.getShape(), type.getElementType(),
encodingAttr);
}

static LogicalResult addLayoutsToTensorPhaseOps(
ModuleOp moduleOp, FunctionOpInterface funcOp,
IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
SmallVector<IREE::Stream::AffinityOpInterface> candidates;
funcOp.walk([&](IREE::Stream::AffinityOpInterface affinityOp) {
// Only need to update encoding types for ops that have TensorPhaseOp trait.
if (!affinityOp->hasTrait<OpTrait::IREE::Stream::TensorPhaseOp>()) {
return;
}

// Bail out if the operation does not have an affinity attribute.
auto affinityAttr = affinityOp.getAffinityAttr();
if (!affinityAttr) {
return;
}
candidates.push_back(affinityOp);
});

if (candidates.empty()) {
return success();
}

IRRewriter rewriter(funcOp.getContext());
for (auto affinityOp : candidates) {
auto affinityAttr = affinityOp.getAffinityAttr();
SetVector<Attribute> layouts;
if (failed(resolveLayoutAttr(affinityAttr, moduleOp, layouts))) {
return affinityOp.emitError("failed on making layouts");
}

// Returns an updated encoding attribute if an encoding attribute is present
// in the type. Otherwise, returns std::nullopt.
auto getEncodingWithNewLayouts =
[=](Type type) -> std::optional<IREE::Encoding::EncodingAttr> {
auto rankedTensorType = dyn_cast<RankedTensorType>(type);
if (!rankedTensorType) {
return std::nullopt;
}
auto encodingAttr = IREE::Encoding::getEncodingAttr(rankedTensorType);
if (!encodingAttr) {
return std::nullopt;
}
return encodingAttr.cloneWithLayouts(layouts.getArrayRef());
};

// TODO(hanchung): Update other Stream operations.
LogicalResult result =
TypeSwitch<Operation *, LogicalResult>(affinityOp)
.Case<IREE::Stream::TensorSizeOfOp>([&](auto sizeOfOp) {
auto encodingType =
dyn_cast<RankedTensorType>(sizeOfOp.getEncoding());
if (!encodingType) {
return success();
}
std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
getEncodingWithNewLayouts(encodingType);
if (!encodingAttr) {
return success();
}
rewriter.modifyOpInPlace(sizeOfOp, [&] {
sizeOfOp.setEncoding(
cloneWithEncoding(encodingType, encodingAttr.value()));
});
return success();
})
.Default([](auto *op) { return failure(); });

if (failed(result)) {
return failure();
}
}
return success();
}
} // namespace

struct SpecializeEncodingsPass
: public impl::SpecializeEncodingsPassBase<SpecializeEncodingsPass> {
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
auto usedDialects = gatherUsedDialectInterfaces<
IREE::Stream::AffinityAnalysisDialectInterface>(moduleOp);
if (usedDialects.size() != 1) {
moduleOp.emitError("expected only one dialect implementing "
"AffinityAnalysisDialectInterface");
return signalPassFailure();
}

llvm::MapVector<StringRef, IREE::Stream::ExecutableOp> executableOps;
for (auto executableOp : moduleOp.getOps<IREE::Stream::ExecutableOp>()) {
executableOps[executableOp.getName()] = executableOp;
}

IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr =
usedDialects[0]->makeLayoutAttrResolver(moduleOp);
for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
if (failed(addLayoutsToTensorPhaseOps(moduleOp, funcOp,
resolveLayoutAttr))) {
funcOp.emitError(
"failed on adding layouts to Stream::TensorPhaseOp with encodings");
return signalPassFailure();
}

// TODO(hanchung): Duplicate executables and update dispatch ops.
}
}
};

} // namespace mlir::iree_compiler::IREE::Stream
Loading

0 comments on commit 02d145e

Please sign in to comment.