Skip to content

Commit

Permalink
Merge branch 'main' into noalias
Browse files Browse the repository at this point in the history
  • Loading branch information
newling authored Jan 10, 2025
2 parents 328d1ac + 75ea24b commit aa01f4b
Show file tree
Hide file tree
Showing 15 changed files with 394 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Transforms.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h"

#define DEBUG_TYPE "iree-amdaie-generate-control-overlay"

namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Initializes the channel generators for the shim tiles, excluding any
/// channels that are already in use by existing circuit flows.
LogicalResult initializeChannelsGenerators(
AMDAIE::WorkgroupOp workgroupOp, const AMDAIEDeviceModel &deviceModel,
const DenseSet<TileOp> &shimTileOps,
DenseMap<Value, ChannelGenerator> &shimTileToGeneratorMap) {
// Check the number of DMA channels available for the shim tile.
uint8_t numShimDmaChannels = deviceModel.getDmaProp<uint8_t>(
AMDAIETileType::SHIMNOC, AMDAIEDmaProp::NumChannels);
std::for_each(shimTileOps.begin(), shimTileOps.end(), [&](TileOp shimTileOp) {
shimTileToGeneratorMap[shimTileOp.getResult()] =
ChannelGenerator(numShimDmaChannels, numShimDmaChannels);
});
// Exclude those channels that are already used by a circuit flow.
workgroupOp->walk([&](AMDAIE::FlowOp flowOp) {
if (flowOp.getIsPacketFlow()) return WalkResult::advance();
SmallVector<AMDAIE::ChannelOp> sourceChannels;
for (Value source : flowOp.getSources()) {
if (auto channelOp =
dyn_cast<AMDAIE::ChannelOp>(source.getDefiningOp())) {
sourceChannels.push_back(channelOp);
}
}
for (AMDAIE::ChannelOp channelOp : sourceChannels) {
AMDAIE::TileOp tileOp = channelOp.getTileOp();
uint8_t channel = channelOp.getValue();
StrmSwPortType portType = channelOp.getPortType();
AMDAIE::DMAChannelDir direction = channelOp.getDirection();
if (shimTileOps.contains(tileOp) && portType == StrmSwPortType::DMA) {
// Assign to exclude.
if (direction == AMDAIE::DMAChannelDir::MM2S) {
shimTileToGeneratorMap[tileOp.getResult()].assignProducerDMAChannel(
channel);
} else if (direction == AMDAIE::DMAChannelDir::S2MM) {
shimTileToGeneratorMap[tileOp.getResult()].assignConsumerDMAChannel(
channel);
} else {
assert(false && "unexpected DMA channel direction");
}
}
}
return WalkResult::advance();
});
return success();
}

LogicalResult generateControlOverlay(AMDAIE::WorkgroupOp workgroupOp,
bool routeShimToTileCtrl,
bool routeShimCtrlToTct) {
// Get the device model.
std::optional<AMDAIEDevice> device = getConfigAMDAIEDevice(workgroupOp);
if (!device) {
return workgroupOp->emitOpError()
<< "could not find an AMDAIEDevice attribute";
}
AMDAIEDeviceModel deviceModel = AMDAIE::getDeviceModel(device.value());

IRRewriter rewriter(workgroupOp->getContext());
DenseSet<uint32_t> occupiedCols;
DenseMap<uint32_t, AMDAIE::TileOp> columnToShimTile;
workgroupOp->walk([&](AMDAIE::TileOp tileOp) {
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
occupiedCols.insert(col);
if (deviceModel.isShimNOCTile(col, row)) columnToShimTile[col] = tileOp;
});

// If the column is occupied, but the shim tile op is not present, then create
// one.
rewriter.setInsertionPoint(workgroupOp.getControlCode());
for (uint32_t col : occupiedCols) {
if (!columnToShimTile.count(col)) {
auto colIndex = rewriter.create<arith::ConstantIndexOp>(
rewriter.getUnknownLoc(), col);
auto rowIndex =
rewriter.create<arith::ConstantIndexOp>(rewriter.getUnknownLoc(), 0);
columnToShimTile[col] = rewriter.create<AMDAIE::TileOp>(
rewriter.getUnknownLoc(), colIndex, rowIndex);
}
}

// Create a packet flow from the shim DMA to the tile CTRL, for sending
// control packets.
if (routeShimToTileCtrl) {
DenseMap<Value, ChannelGenerator> shimTileToGeneratorMap;
DenseSet<TileOp> shimTileOps;
for (const auto &pair : columnToShimTile) shimTileOps.insert(pair.second);
if (failed(initializeChannelsGenerators(
workgroupOp, deviceModel, shimTileOps, shimTileToGeneratorMap))) {
return failure();
}
WalkResult res = workgroupOp->walk([&](AMDAIE::TileOp tileOp) {
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
TileOp shimTileOp = columnToShimTile[col];
// Get the available channel, but do not assign it. Allow it to be
// shared across multiple packet flows as needed.
std::optional<uint8_t> maybeChannel =
shimTileToGeneratorMap[shimTileOp.getResult()]
.getProducerDMAChannel();
if (!maybeChannel) {
shimTileOp.emitOpError() << "no producer DMA channel available";
return WalkResult::interrupt();
}
auto shimDmaChannelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), shimTileOp, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::MM2S);
auto tileCtrlChannelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tileOp, 0, StrmSwPortType::CTRL,
AMDAIE::DMAChannelDir::S2MM);
rewriter.create<AMDAIE::FlowOp>(
rewriter.getUnknownLoc(), ValueRange{shimDmaChannelOp},
ValueRange{tileCtrlChannelOp},
/*isPacketFlow*/ true, /*packetId*/ nullptr);
return WalkResult::advance();
});
if (res.wasInterrupted()) return failure();
}

// Create a circuit flow from the shim CTRL to the shim SOUTH 0, for sending
// Task Completion Tokens (TCTs).
if (routeShimCtrlToTct) {
for (auto [_, shimTileOp] : columnToShimTile) {
auto shimCtrlChannelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), shimTileOp, 0, StrmSwPortType::CTRL,
AMDAIE::DMAChannelDir::MM2S);
auto shimSouthChannelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), shimTileOp, 0, StrmSwPortType::SOUTH,
AMDAIE::DMAChannelDir::S2MM);
rewriter.create<AMDAIE::FlowOp>(
rewriter.getUnknownLoc(), ValueRange{shimCtrlChannelOp},
ValueRange{shimSouthChannelOp},
/*isPacketFlow*/ false, /*packetId*/ nullptr);
}
}

return success();
}

class AMDAIEGenerateControlOverlayPass
: public impl::AMDAIEGenerateControlOverlayBase<
AMDAIEGenerateControlOverlayPass> {
public:
AMDAIEGenerateControlOverlayPass(
const AMDAIEGenerateControlOverlayOptions &options)
: AMDAIEGenerateControlOverlayBase(options) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AMDAIEDialect>();
}

void runOnOperation() override;
};

void AMDAIEGenerateControlOverlayPass::runOnOperation() {
Operation *parentOp = getOperation();
WalkResult res = parentOp->walk([&](AMDAIE::WorkgroupOp workgroupOp) {
if (failed(generateControlOverlay(workgroupOp, routeShimToTileCtrl,
routeShimCtrlToTct))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});

if (res.wasInterrupted()) return signalPassFailure();
}

} // namespace

std::unique_ptr<Pass> createAMDAIEGenerateControlOverlayPass(
AMDAIEGenerateControlOverlayOptions options) {
return std::make_unique<AMDAIEGenerateControlOverlayPass>(options);
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,9 @@ SmallVector<Operation *> AIEDeviceBuilder::createFlowOps(
for (AMDAIE::ChannelOp consumerChannel : consumerChannels) {
Value aieConsumerTile = mapper.lookup(consumerChannel.getTile());
AIE::FlowOp flowOp = rewriter.create<AIE::FlowOp>(
rewriter.getUnknownLoc(), aieProducerTile, AIE::WireBundle::DMA,
producerChannel.getValue(), aieConsumerTile, AIE::WireBundle::DMA,
rewriter.getUnknownLoc(), aieProducerTile,
producerChannel.getPortType(), producerChannel.getValue(),
aieConsumerTile, consumerChannel.getPortType(),
consumerChannel.getValue());
flowOps.push_back(flowOp.getOperation());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ iree_cc_library(
"AMDAIEFuseConsumerIntoLoop.cpp"
"AMDAIEFuseFillIntoForall.cpp"
"AMDAIEFusePackIntoLoop.cpp"
"AMDAIEGenerateControlOverlay.cpp"
"AMDAIEHoistForAffineApply.cpp"
"AMDAIEHoistLogicalObjFifo.cpp"
"AMDAIEInsertCores.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,23 @@ FailureOr<ParameterSetting> ParameterSetting::create(
}
} // namespace

/// Utility to set the packing inner permutation for A/LHS so that is packed as
/// [? ? m k] in case of matmul and [? ? ? m k] in case of batch_matmul.
static SmallVector<int64_t> setInnerPermA(bool isMatmulTransposeA) {
SmallVector<int64_t> innerPerm;
if (isMatmulTransposeA) {
innerPerm = {1, 0};
} else {
innerPerm = {0, 1};
}
return innerPerm;
}

/// Utility to set the packing inner permutation for B/RHS so that is packed as
/// - [? ? k n] in case of matmul
/// - [? ? ? k n] in case of batch_matmul
/// - [? ? n k] in case of matmul_transpose_b
/// - [? ? ? n k] in case of batch_matmul_transpose_b.
static SmallVector<int64_t> setInnerPermB(bool isMatmulTransposeB) {
SmallVector<int64_t> innerPerm;
if (isMatmulTransposeB) {
Expand All @@ -326,14 +343,34 @@ static SmallVector<int64_t> setInnerPermB(bool isMatmulTransposeB) {
return innerPerm;
}

static SmallVector<int64_t> setInnerPermA(bool isMatmulTransposeA) {
SmallVector<int64_t> innerPerm;
/// Utility to set the packing outer permutation for A/LHS so that is packed as
/// [M K ? ?] in case of matmul and [Batch M K ? ?] in case of batch_matmul.
static SmallVector<int64_t> setOuterPermA(bool isMatmulTransposeA,
bool isBatchMatmul) {
SmallVector<int64_t> outerPerm;
if (isMatmulTransposeA) {
innerPerm = {1, 0};
outerPerm = isBatchMatmul ? SmallVector<int64_t>{0, 2, 1}
: SmallVector<int64_t>{1, 0};
} else {
innerPerm = {0, 1};
outerPerm = isBatchMatmul ? SmallVector<int64_t>{0, 1, 2}
: SmallVector<int64_t>{0, 1};
}
return innerPerm;
return outerPerm;
}

/// Utility to set the packing outer permutation for B/RHS so that is packed as
/// [N K ? ?] in case of matmul and [Batch N K ? ?] in case of batch_matmul.
static SmallVector<int64_t> setOuterPermB(bool isMatmulTransposeB,
bool isBatchMatmul) {
SmallVector<int64_t> outerPerm;
if (isMatmulTransposeB) {
outerPerm = isBatchMatmul ? SmallVector<int64_t>{0, 1, 2}
: SmallVector<int64_t>{0, 1};
} else {
outerPerm = isBatchMatmul ? SmallVector<int64_t>{0, 2, 1}
: SmallVector<int64_t>{1, 0};
}
return outerPerm;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -362,7 +399,7 @@ static LogicalResult setRootConfigForPackPeelPipeline(
packedSizesL0.insert(packedSizesL0.begin(), 0);
}

// For matmul, transpose B matrix from [K N n k] to [K N k n]
// For matmul, transpose B matrix from [K N n k] to [N K k n]
// For matmul_transpose_b, we don't have to transpose the B matrix,
// since it is already [N K n k]
SmallVector<int64_t> transposePackIndices = {0, 1};
Expand All @@ -372,11 +409,12 @@ static LogicalResult setRootConfigForPackPeelPipeline(
SmallVector<int64_t> innerPermA = setInnerPermA(isMatmulTransposeA(linalgOp));
SmallVector<int64_t> innerPermB = setInnerPermB(isMatmulTransposeB(linalgOp));
SmallVector<SmallVector<int64_t>> innerPerm = {innerPermA, innerPermB};
SmallVector<int64_t> outerPermVec = {0, 1};
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
outerPermVec.push_back(2);
}
SmallVector<SmallVector<int64_t>> outerPerm = {outerPermVec, outerPermVec};
bool isBatchMatmul = isa<linalg::BatchMatmulOp>(linalgOp);
SmallVector<int64_t> outerPermA =
setOuterPermA(isMatmulTransposeA(linalgOp), isBatchMatmul);
SmallVector<int64_t> outerPermB =
setOuterPermB(isMatmulTransposeB(linalgOp), isBatchMatmul);
SmallVector<SmallVector<int64_t>> outerPerm = {outerPermA, outerPermB};
if (isObjectFifo) {
// Add outer permutation for unpack. NOTE: This currently fails for some
// tests in the AIR pipeline.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIEFUSECONSUMERINTOLOOP
#define GEN_PASS_DEF_AMDAIEFUSEFILLINTOFORALL
#define GEN_PASS_DEF_AMDAIEFUSEPACKINTOLOOP
#define GEN_PASS_DEF_AMDAIEGENERATECONTROLOVERLAY
#define GEN_PASS_DEF_AMDAIEHOISTFORLOOPAFFINEAPPLY
#define GEN_PASS_DEF_AMDAIEHOISTLOGICALOBJFIFO
#define GEN_PASS_DEF_AMDAIEINSERTAIEWORKGROUP
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,11 @@ void addAMDAIEObjectFifoLoweringPasses(
passManager.addPass(createAMDAIEObjFifoBufferizationPass());
passManager.addPass(createAMDAIETemporaryAllocBufferizationPass());
passManager.addPass(createAMDAIEConnectionToFlowPass());

passManager.addPass(createAMDAIEGenerateControlOverlayPass());
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());

passManager.addPass(createAMDAIEAssignPacketIdsPass());

passManager.addPass(createAMDAIENpuDmaToHalfDmaCpyNdPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,11 @@ std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass(
/// Create a pass to fuse the linalg.fill into the forall loops.
std::unique_ptr<Pass> createAMDAIEFuseFillIntoForallPass();

/// Create pass to generate packet-flow routings for control packets entering or
/// leaving each tile.
std::unique_ptr<Pass> createAMDAIEGenerateControlOverlayPass(
AMDAIEGenerateControlOverlayOptions options = {});

/// Hoist an affine.apply op on a scf.for op's induction variable.
std::unique_ptr<Pass> createAMDAIEHoistForLoopAffineApplyPass();

Expand Down
11 changes: 11 additions & 0 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,17 @@ def AMDAIEFusePackIntoLoop :
];
}

def AMDAIEGenerateControlOverlay : Pass<"iree-amdaie-generate-control-overlay"> {
let summary = "Spawn a streaming interconnect network for CTRL ports.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEGenerateControlOverlayPass()";
let options = [
Option<"routeShimCtrlToTct", "route-shim-to-tct", "bool", /*default=*/"true",
"Flag to generate TCT routing between tile CTRL and shim SOUTH ports.">,
Option<"routeShimToTileCtrl", "route-shim-to-tile-ctrl", "bool", /*default=*/"false",
"Flag to generate routing between shim dma DMA and tile CTRL ports, for configuration.">
];
}

def AMDAIEHoistForLoopAffineApply : Pass<"iree-amdaie-hoist-for-affine-apply"> {
let summary = "Hoist an affine apply op on a scf.for op's induction variable.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEHoistForLoopAffineApplyPass()";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ iree_lit_test_suite(
"fuse_consumer_into_loop.mlir"
"fuse_fill_into_forall.mlir"
"fuse_pack_into_loop.mlir"
"generate_control_overlay.mlir"
"hoist_for_affine_apply.mlir"
"hoist_logical_obj_fifo.mlir"
"insert_cores.mlir"
Expand Down
Loading

0 comments on commit aa01f4b

Please sign in to comment.