Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a new pass to generate column-wise control overlay #1012

Merged
merged 7 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Transforms.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h"

#define DEBUG_TYPE "iree-amdaie-generate-control-overlay"

namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Initializes the channel generators for the shim tiles, excluding any
/// channels that are already in use by existing circuit flows.
///
/// \param workgroupOp            Workgroup whose flow ops are inspected.
/// \param deviceModel            Used to query the number of shim DMA channels.
/// \param shimTileOps            Shim tiles that need a channel generator.
/// \param shimTileToGeneratorMap Output map, keyed by the shim tile's result
///                               value; an entry is (re)initialized for every
///                               tile in `shimTileOps`.
/// \returns success(); there is currently no failure path, the `LogicalResult`
///          return type is kept for the callers.
LogicalResult initializeChannelsGenerators(
    AMDAIE::WorkgroupOp workgroupOp, const AMDAIEDeviceModel &deviceModel,
    const DenseSet<TileOp> &shimTileOps,
    DenseMap<Value, ChannelGenerator> &shimTileToGeneratorMap) {
  // Check the number of DMA channels available for the shim tile.
  uint8_t numShimDmaChannels = deviceModel.getDmaProp<uint8_t>(
      AMDAIETileType::SHIMNOC, AMDAIEDmaProp::NumChannels);
  for (TileOp shimTileOp : shimTileOps) {
    shimTileToGeneratorMap[shimTileOp.getResult()] =
        ChannelGenerator(numShimDmaChannels, numShimDmaChannels);
  }
  // Exclude those channels that are already used by a circuit flow. Packet
  // flows are skipped and therefore do not reserve a channel here.
  workgroupOp->walk([&](AMDAIE::FlowOp flowOp) {
    if (flowOp.getIsPacketFlow()) return WalkResult::advance();
    for (Value source : flowOp.getSources()) {
      // `getDefiningOp<OpTy>()` tolerates values without a defining op (block
      // arguments), whereas `dyn_cast` on the raw, possibly null, operation
      // pointer would assert.
      auto channelOp = source.getDefiningOp<AMDAIE::ChannelOp>();
      if (!channelOp) continue;
      AMDAIE::TileOp tileOp = channelOp.getTileOp();
      uint8_t channel = channelOp.getValue();
      StrmSwPortType portType = channelOp.getPortType();
      AMDAIE::DMAChannelDir direction = channelOp.getDirection();
      // Only DMA channels on the tracked shim tiles are of interest.
      if (!shimTileOps.contains(tileOp) || portType != StrmSwPortType::DMA)
        continue;
      // Assign to exclude.
      if (direction == AMDAIE::DMAChannelDir::MM2S) {
        shimTileToGeneratorMap[tileOp.getResult()].assignProducerDMAChannel(
            channel);
      } else if (direction == AMDAIE::DMAChannelDir::S2MM) {
        shimTileToGeneratorMap[tileOp.getResult()].assignConsumerDMAChannel(
            channel);
      } else {
        assert(false && "unexpected DMA channel direction");
      }
    }
    return WalkResult::advance();
  });
  return success();
}

/// Generates the control overlay for one workgroup.
///
/// For every column that contains at least one tile, a shim tile op is looked
/// up (or created at row 0 when absent). Depending on the flags, the function
/// then adds:
///  - `routeShimToTileCtrl`: a packet flow from a shim DMA (MM2S) channel to
///    the CTRL port of every tile in the workgroup;
///  - `routeShimCtrlToTct`: a circuit flow from each shim tile's CTRL port to
///    its SOUTH 0 port.
/// All new ops are inserted right before the workgroup's control code.
///
/// Fails if the device attribute is missing or if a shim tile has no producer
/// DMA channel left.
LogicalResult generateControlOverlay(AMDAIE::WorkgroupOp workgroupOp,
                                     bool routeShimToTileCtrl,
                                     bool routeShimCtrlToTct) {
  // The device model is required both to recognise shim NOC tiles and to query
  // the DMA channel count.
  std::optional<AMDAIEDevice> maybeDevice = getConfigAMDAIEDevice(workgroupOp);
  if (!maybeDevice.has_value()) {
    return workgroupOp->emitOpError()
           << "could not find an AMDAIEDevice attribute";
  }
  AMDAIEDeviceModel deviceModel = AMDAIE::getDeviceModel(*maybeDevice);

  IRRewriter rewriter(workgroupOp->getContext());
  auto loc = rewriter.getUnknownLoc();

  // Record which columns are in use and which of them already have a shim
  // NOC tile op.
  DenseSet<uint32_t> usedColumns;
  DenseMap<uint32_t, AMDAIE::TileOp> shimTileByColumn;
  workgroupOp->walk([&](AMDAIE::TileOp tileOp) {
    uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
    uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
    usedColumns.insert(col);
    if (deviceModel.isShimNOCTile(col, row)) shimTileByColumn[col] = tileOp;
  });

  // Materialize a shim tile op for every used column that lacks one.
  rewriter.setInsertionPoint(workgroupOp.getControlCode());
  for (uint32_t col : usedColumns) {
    if (shimTileByColumn.count(col)) continue;
    auto colIndex = rewriter.create<arith::ConstantIndexOp>(loc, col);
    auto rowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    shimTileByColumn[col] =
        rewriter.create<AMDAIE::TileOp>(loc, colIndex, rowIndex);
  }

  // Packet flows: shim DMA -> tile CTRL, used for sending control packets.
  if (routeShimToTileCtrl) {
    DenseMap<Value, ChannelGenerator> generatorByShimTile;
    DenseSet<TileOp> shimTileOps;
    for (const auto &columnAndTile : shimTileByColumn)
      shimTileOps.insert(columnAndTile.second);
    if (failed(initializeChannelsGenerators(workgroupOp, deviceModel,
                                            shimTileOps, generatorByShimTile)))
      return failure();

    auto connectTileCtrl = [&](AMDAIE::TileOp tileOp) -> WalkResult {
      uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
      TileOp shimTileOp = shimTileByColumn[col];
      // Query an available channel without assigning it, so that the same
      // channel may be shared by several packet flows.
      std::optional<uint8_t> maybeChannel =
          generatorByShimTile[shimTileOp.getResult()].getProducerDMAChannel();
      if (!maybeChannel) {
        shimTileOp.emitOpError() << "no producer DMA channel available";
        return WalkResult::interrupt();
      }
      auto shimDmaChannelOp = rewriter.create<AMDAIE::ChannelOp>(
          loc, shimTileOp, maybeChannel.value(), StrmSwPortType::DMA,
          AMDAIE::DMAChannelDir::MM2S);
      auto tileCtrlChannelOp = rewriter.create<AMDAIE::ChannelOp>(
          loc, tileOp, 0, StrmSwPortType::CTRL, AMDAIE::DMAChannelDir::S2MM);
      rewriter.create<AMDAIE::FlowOp>(loc, ValueRange{shimDmaChannelOp},
                                      ValueRange{tileCtrlChannelOp},
                                      /*isPacketFlow*/ true,
                                      /*packetId*/ nullptr);
      return WalkResult::advance();
    };
    if (workgroupOp->walk(connectTileCtrl).wasInterrupted()) return failure();
  }

  if (!routeShimCtrlToTct) return success();

  // Circuit flows: shim CTRL -> shim SOUTH 0, used for sending Task Completion
  // Tokens (TCTs).
  for (const auto &columnAndTile : shimTileByColumn) {
    AMDAIE::TileOp shimTileOp = columnAndTile.second;
    auto shimCtrlChannelOp = rewriter.create<AMDAIE::ChannelOp>(
        loc, shimTileOp, 0, StrmSwPortType::CTRL, AMDAIE::DMAChannelDir::MM2S);
    auto shimSouthChannelOp = rewriter.create<AMDAIE::ChannelOp>(
        loc, shimTileOp, 0, StrmSwPortType::SOUTH, AMDAIE::DMAChannelDir::S2MM);
    rewriter.create<AMDAIE::FlowOp>(loc, ValueRange{shimCtrlChannelOp},
                                    ValueRange{shimSouthChannelOp},
                                    /*isPacketFlow*/ false,
                                    /*packetId*/ nullptr);
  }
  return success();
}

/// Pass that generates the control overlay for every `amdaie.workgroup` found
/// under the operation it runs on. The two routing flags come from the
/// generated pass options.
class AMDAIEGenerateControlOverlayPass
    : public impl::AMDAIEGenerateControlOverlayBase<
          AMDAIEGenerateControlOverlayPass> {
 public:
  AMDAIEGenerateControlOverlayPass(
      const AMDAIEGenerateControlOverlayOptions &options)
      : AMDAIEGenerateControlOverlayBase(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    // Besides AMDAIE ops, the pass creates `arith.constant` index ops when it
    // has to materialize a missing shim tile, so the arith dialect must be
    // declared as a dependency as well.
    registry.insert<AMDAIEDialect, arith::ArithDialect>();
  }

  void runOnOperation() override;
};

void AMDAIEGenerateControlOverlayPass::runOnOperation() {
Operation *parentOp = getOperation();
WalkResult res = parentOp->walk([&](AMDAIE::WorkgroupOp workgroupOp) {
if (failed(generateControlOverlay(workgroupOp, routeShimToTileCtrl,
routeShimCtrlToTct))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});

if (res.wasInterrupted()) return signalPassFailure();
}

} // namespace

/// Factory for the control overlay pass, configured with `options`.
std::unique_ptr<Pass> createAMDAIEGenerateControlOverlayPass(
    AMDAIEGenerateControlOverlayOptions options) {
  std::unique_ptr<Pass> pass =
      std::make_unique<AMDAIEGenerateControlOverlayPass>(options);
  return pass;
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,9 @@ SmallVector<Operation *> AIEDeviceBuilder::createFlowOps(
for (AMDAIE::ChannelOp consumerChannel : consumerChannels) {
Value aieConsumerTile = mapper.lookup(consumerChannel.getTile());
AIE::FlowOp flowOp = rewriter.create<AIE::FlowOp>(
rewriter.getUnknownLoc(), aieProducerTile, AIE::WireBundle::DMA,
producerChannel.getValue(), aieConsumerTile, AIE::WireBundle::DMA,
rewriter.getUnknownLoc(), aieProducerTile,
producerChannel.getPortType(), producerChannel.getValue(),
aieConsumerTile, consumerChannel.getPortType(),
consumerChannel.getValue());
flowOps.push_back(flowOp.getOperation());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ iree_cc_library(
"AMDAIEFuseConsumerIntoLoop.cpp"
"AMDAIEFuseFillIntoForall.cpp"
"AMDAIEFusePackIntoLoop.cpp"
"AMDAIEGenerateControlOverlay.cpp"
"AMDAIEHoistForAffineApply.cpp"
"AMDAIEHoistLogicalObjFifo.cpp"
"AMDAIEInsertCores.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIEFUSECONSUMERINTOLOOP
#define GEN_PASS_DEF_AMDAIEFUSEFILLINTOFORALL
#define GEN_PASS_DEF_AMDAIEFUSEPACKINTOLOOP
#define GEN_PASS_DEF_AMDAIEGENERATECONTROLOVERLAY
#define GEN_PASS_DEF_AMDAIEHOISTFORLOOPAFFINEAPPLY
#define GEN_PASS_DEF_AMDAIEHOISTLOGICALOBJFIFO
#define GEN_PASS_DEF_AMDAIEINSERTAIEWORKGROUP
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,11 @@ void addAMDAIEObjectFifoLoweringPasses(
passManager.addPass(createAMDAIEObjFifoBufferizationPass());
passManager.addPass(createAMDAIETemporaryAllocBufferizationPass());
passManager.addPass(createAMDAIEConnectionToFlowPass());

passManager.addPass(createAMDAIEGenerateControlOverlayPass());
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());

passManager.addPass(createAMDAIEAssignPacketIdsPass());

passManager.addPass(createAMDAIENpuDmaToHalfDmaCpyNdPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,11 @@ std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass(
/// Create a pass to fuse the linalg.fill into the forall loops.
std::unique_ptr<Pass> createAMDAIEFuseFillIntoForallPass();

/// Create a pass that generates the column-wise control overlay. Depending on
/// the options, it adds packet flows from a shim DMA channel to each tile's
/// CTRL port (for control packets) and/or circuit flows from each shim tile's
/// CTRL port to its SOUTH port (for task completion tokens).
std::unique_ptr<Pass> createAMDAIEGenerateControlOverlayPass(
AMDAIEGenerateControlOverlayOptions options = {});

/// Hoist an affine.apply op on a scf.for op's induction variable.
std::unique_ptr<Pass> createAMDAIEHoistForLoopAffineApplyPass();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,17 @@ def AMDAIEFusePackIntoLoop :
];
}

// Generates the column-wise control overlay: optional packet flows from the
// shim DMA to the tiles' CTRL ports, and optional circuit flows from the shim
// tile's CTRL port to its SOUTH port for task completion tokens (TCTs).
def AMDAIEGenerateControlOverlay : Pass<"iree-amdaie-generate-control-overlay"> {
  let summary = "Spawn a streaming interconnect network for CTRL ports.";
  let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEGenerateControlOverlayPass()";
  let options = [
    Option<"routeShimCtrlToTct", "route-shim-to-tct", "bool", /*default=*/"true",
      "Flag to generate TCT routing between shim tile CTRL and shim tile SOUTH ports.">,
    Option<"routeShimToTileCtrl", "route-shim-to-tile-ctrl", "bool", /*default=*/"false",
      "Flag to generate routing between shim DMA and tile CTRL ports, for configuration.">
  ];
}

def AMDAIEHoistForLoopAffineApply : Pass<"iree-amdaie-hoist-for-affine-apply"> {
let summary = "Hoist an affine apply op on a scf.for op's induction variable.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEHoistForLoopAffineApplyPass()";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ iree_lit_test_suite(
"fuse_consumer_into_loop.mlir"
"fuse_fill_into_forall.mlir"
"fuse_pack_into_loop.mlir"
"generate_control_overlay.mlir"
"hoist_for_affine_apply.mlir"
"hoist_logical_obj_fifo.mlir"
"insert_cores.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-amdaie-generate-control-overlay{route-shim-to-tct=true route-shim-to-tile-ctrl=true},canonicalize,cse))" --split-input-file --verify-diagnostics %s | FileCheck %s

// A device attribute is required to build the device model; without it the
// pass fails regardless of which routing options are enabled.
module {
func.func @no_amdaie_device() {
// expected-error @+1 {{could not find an AMDAIEDevice attribute}}
amdaie.workgroup {
amdaie.controlcode {
amdaie.end
}
}
return
}
}

// -----

// Shim tile (0, 0) has two producer (MM2S) channels,
// both of which are already utilized by existing circuit flows.
// No producer DMA channel is available for route-shim-to-tile-ctrl.
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @no_available_channel() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
amdaie.workgroup {
// expected-error @+1 {{no producer DMA channel available}}
%tile_0_0 = amdaie.tile(%c0, %c0)
%tile_0_1 = amdaie.tile(%c0, %c1)
%channel_0 = amdaie.channel(%tile_0_0, 0, port_type = DMA, direction = MM2S)
%channel_1 = amdaie.channel(%tile_0_1, 0, port_type = DMA, direction = S2MM)
%flow_0 = amdaie.flow({%channel_0} -> {%channel_1}) {is_packet_flow = false}
%channel_2 = amdaie.channel(%tile_0_0, 1, port_type = DMA, direction = MM2S)
%channel_3 = amdaie.channel(%tile_0_1, 1, port_type = DMA, direction = S2MM)
%flow_1 = amdaie.flow({%channel_2} -> {%channel_3}) {is_packet_flow = false}
amdaie.controlcode {
amdaie.end
}
}
return
}
}


// -----

// Successfully inserted six packet flows from shim DMA channels to tile CTRL channels,
// and one circuit flow from shim CTRL to shim SOUTH 0.
// CHECK-LABEL: @column_control_overlay
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C5:.*]] = arith.constant 5 : index
// CHECK: amdaie.workgroup {
// CHECK: %[[TILE_0_0:.*]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: %[[TILE_0_1:.*]] = amdaie.tile(%[[C0]], %[[C1]])
// CHECK: %[[TILE_0_2:.*]] = amdaie.tile(%[[C0]], %[[C2]])
// CHECK: %[[TILE_0_3:.*]] = amdaie.tile(%[[C0]], %[[C3]])
// CHECK: %[[TILE_0_4:.*]] = amdaie.tile(%[[C0]], %[[C4]])
// CHECK: %[[TILE_0_5:.*]] = amdaie.tile(%[[C0]], %[[C5]])
// CHECK: %[[CHANNEL_0:.*]] = amdaie.channel(%[[TILE_0_0]], 0, port_type = DMA, direction = MM2S)
// CHECK: %[[CHANNEL_1:.*]] = amdaie.channel(%[[TILE_0_0]], 0, port_type = CTRL, direction = S2MM)
// CHECK: %[[FLOW_0:.*]] = amdaie.flow({%[[CHANNEL_0]]} -> {%[[CHANNEL_1]]}) {is_packet_flow = true}
// CHECK: %[[CHANNEL_2:.*]] = amdaie.channel(%[[TILE_0_0]], 1, port_type = DMA, direction = MM2S)
// CHECK: %[[CHANNEL_3:.*]] = amdaie.channel(%[[TILE_0_1]], 0, port_type = CTRL, direction = S2MM)
// CHECK: %[[FLOW_1:.*]] = amdaie.flow({%[[CHANNEL_2]]} -> {%[[CHANNEL_3]]}) {is_packet_flow = true}
// CHECK: %[[CHANNEL_4:.*]] = amdaie.channel(%[[TILE_0_2]], 0, port_type = CTRL, direction = S2MM)
// CHECK: %[[FLOW_2:.*]] = amdaie.flow({%[[CHANNEL_0]]} -> {%[[CHANNEL_4]]}) {is_packet_flow = true}
// CHECK: %[[CHANNEL_5:.*]] = amdaie.channel(%[[TILE_0_3]], 0, port_type = CTRL, direction = S2MM)
// CHECK: %[[FLOW_3:.*]] = amdaie.flow({%[[CHANNEL_2]]} -> {%[[CHANNEL_5]]}) {is_packet_flow = true}
// CHECK: %[[CHANNEL_6:.*]] = amdaie.channel(%[[TILE_0_4]], 0, port_type = CTRL, direction = S2MM)
// CHECK: %[[FLOW_4:.*]] = amdaie.flow({%[[CHANNEL_0]]} -> {%[[CHANNEL_6]]}) {is_packet_flow = true}
// CHECK: %[[CHANNEL_7:.*]] = amdaie.channel(%[[TILE_0_5]], 0, port_type = CTRL, direction = S2MM)
// CHECK: %[[FLOW_5:.*]] = amdaie.flow({%[[CHANNEL_2]]} -> {%[[CHANNEL_7]]}) {is_packet_flow = true}
// CHECK: %[[CHANNEL_8:.*]] = amdaie.channel(%[[TILE_0_0]], 0, port_type = CTRL, direction = MM2S)
// CHECK: %[[CHANNEL_9:.*]] = amdaie.channel(%[[TILE_0_0]], 0, port_type = SOUTH, direction = S2MM)
// CHECK: %[[FLOW_6:.*]] = amdaie.flow({%[[CHANNEL_8]]} -> {%[[CHANNEL_9]]}) {is_packet_flow = false}
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @column_control_overlay() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c5 = arith.constant 5 : index
amdaie.workgroup {
%tile_0_0 = amdaie.tile(%c0, %c0)
%tile_0_1 = amdaie.tile(%c0, %c1)
%tile_0_2 = amdaie.tile(%c0, %c2)
%tile_0_3 = amdaie.tile(%c0, %c3)
%tile_0_4 = amdaie.tile(%c0, %c4)
%tile_0_5 = amdaie.tile(%c0, %c5)
amdaie.controlcode {
amdaie.end
}
}
return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ LogicalResult configureStreamSwitch(const AMDAIEDeviceModel &deviceModel,
const TileLoc &tileLoc,
const std::vector<Connect> &connects) {
auto devInst = const_cast<XAie_DevInst *>(&deviceModel.devInst);
// FIXME hack for TCT routing
// mlir-air legacy, hack for TCT routing
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, since the new pass operates on AMDAIE::TileOp and FlowOp, this hack cannot be removed for tests using AIR.

// TODO copy-pasted: Support both channels
// TODO(max): find a way to keep track so that multiple calls don't
// rewrite/overwrite with same data.
Expand Down
Loading