Skip to content

Commit

Permalink
Add a new pass to generate column-wise control overlay (#1012)
Browse files Browse the repository at this point in the history
Adapted from Xilinx/mlir-aie#1705. 

Introduces a new pass to automatically insert the following two groups
of flows:

1. `route-shim-to-tct`, circuit flows from shim `CTRL` to shim `SOUTH`
ports, for sending TCTs.
2. `route-shim-to-tile-ctrl`, packet flows between shim `DMA` to
shim/mem/compute tile `CTRL` ports, for sending control packets.
  • Loading branch information
Yu-Zhewen authored Jan 9, 2025
1 parent 53b96d5 commit 75ea24b
Show file tree
Hide file tree
Showing 10 changed files with 322 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Transforms.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h"

#define DEBUG_TYPE "iree-amdaie-generate-control-overlay"

namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Initializes the channel generators for the shim tiles, excluding any
/// channels that are already in use by existing circuit flows.
LogicalResult initializeChannelsGenerators(
AMDAIE::WorkgroupOp workgroupOp, const AMDAIEDeviceModel &deviceModel,
const DenseSet<TileOp> &shimTileOps,
DenseMap<Value, ChannelGenerator> &shimTileToGeneratorMap) {
// Check the number of DMA channels available for the shim tile.
uint8_t numShimDmaChannels = deviceModel.getDmaProp<uint8_t>(
AMDAIETileType::SHIMNOC, AMDAIEDmaProp::NumChannels);
std::for_each(shimTileOps.begin(), shimTileOps.end(), [&](TileOp shimTileOp) {
shimTileToGeneratorMap[shimTileOp.getResult()] =
ChannelGenerator(numShimDmaChannels, numShimDmaChannels);
});
// Exclude those channels that are already used by a circuit flow.
workgroupOp->walk([&](AMDAIE::FlowOp flowOp) {
if (flowOp.getIsPacketFlow()) return WalkResult::advance();
SmallVector<AMDAIE::ChannelOp> sourceChannels;
for (Value source : flowOp.getSources()) {
if (auto channelOp =
dyn_cast<AMDAIE::ChannelOp>(source.getDefiningOp())) {
sourceChannels.push_back(channelOp);
}
}
for (AMDAIE::ChannelOp channelOp : sourceChannels) {
AMDAIE::TileOp tileOp = channelOp.getTileOp();
uint8_t channel = channelOp.getValue();
StrmSwPortType portType = channelOp.getPortType();
AMDAIE::DMAChannelDir direction = channelOp.getDirection();
if (shimTileOps.contains(tileOp) && portType == StrmSwPortType::DMA) {
// Assign to exclude.
if (direction == AMDAIE::DMAChannelDir::MM2S) {
shimTileToGeneratorMap[tileOp.getResult()].assignProducerDMAChannel(
channel);
} else if (direction == AMDAIE::DMAChannelDir::S2MM) {
shimTileToGeneratorMap[tileOp.getResult()].assignConsumerDMAChannel(
channel);
} else {
assert(false && "unexpected DMA channel direction");
}
}
}
return WalkResult::advance();
});
return success();
}

LogicalResult generateControlOverlay(AMDAIE::WorkgroupOp workgroupOp,
bool routeShimToTileCtrl,
bool routeShimCtrlToTct) {
// Get the device model.
std::optional<AMDAIEDevice> device = getConfigAMDAIEDevice(workgroupOp);
if (!device) {
return workgroupOp->emitOpError()
<< "could not find an AMDAIEDevice attribute";
}
AMDAIEDeviceModel deviceModel = AMDAIE::getDeviceModel(device.value());

IRRewriter rewriter(workgroupOp->getContext());
DenseSet<uint32_t> occupiedCols;
DenseMap<uint32_t, AMDAIE::TileOp> columnToShimTile;
workgroupOp->walk([&](AMDAIE::TileOp tileOp) {
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
occupiedCols.insert(col);
if (deviceModel.isShimNOCTile(col, row)) columnToShimTile[col] = tileOp;
});

// If the column is occupied, but the shim tile op is not present, then create
// one.
rewriter.setInsertionPoint(workgroupOp.getControlCode());
for (uint32_t col : occupiedCols) {
if (!columnToShimTile.count(col)) {
auto colIndex = rewriter.create<arith::ConstantIndexOp>(
rewriter.getUnknownLoc(), col);
auto rowIndex =
rewriter.create<arith::ConstantIndexOp>(rewriter.getUnknownLoc(), 0);
columnToShimTile[col] = rewriter.create<AMDAIE::TileOp>(
rewriter.getUnknownLoc(), colIndex, rowIndex);
}
}

// Create a packet flow from the shim DMA to the tile CTRL, for sending
// control packets.
if (routeShimToTileCtrl) {
DenseMap<Value, ChannelGenerator> shimTileToGeneratorMap;
DenseSet<TileOp> shimTileOps;
for (const auto &pair : columnToShimTile) shimTileOps.insert(pair.second);
if (failed(initializeChannelsGenerators(
workgroupOp, deviceModel, shimTileOps, shimTileToGeneratorMap))) {
return failure();
}
WalkResult res = workgroupOp->walk([&](AMDAIE::TileOp tileOp) {
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
TileOp shimTileOp = columnToShimTile[col];
// Get the available channel, but do not assign it. Allow it to be
// shared across multiple packet flows as needed.
std::optional<uint8_t> maybeChannel =
shimTileToGeneratorMap[shimTileOp.getResult()]
.getProducerDMAChannel();
if (!maybeChannel) {
shimTileOp.emitOpError() << "no producer DMA channel available";
return WalkResult::interrupt();
}
auto shimDmaChannelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), shimTileOp, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::MM2S);
auto tileCtrlChannelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tileOp, 0, StrmSwPortType::CTRL,
AMDAIE::DMAChannelDir::S2MM);
rewriter.create<AMDAIE::FlowOp>(
rewriter.getUnknownLoc(), ValueRange{shimDmaChannelOp},
ValueRange{tileCtrlChannelOp},
/*isPacketFlow*/ true, /*packetId*/ nullptr);
return WalkResult::advance();
});
if (res.wasInterrupted()) return failure();
}

// Create a circuit flow from the shim CTRL to the shim SOUTH 0, for sending
// Task Completion Tokens (TCTs).
if (routeShimCtrlToTct) {
for (auto [_, shimTileOp] : columnToShimTile) {
auto shimCtrlChannelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), shimTileOp, 0, StrmSwPortType::CTRL,
AMDAIE::DMAChannelDir::MM2S);
auto shimSouthChannelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), shimTileOp, 0, StrmSwPortType::SOUTH,
AMDAIE::DMAChannelDir::S2MM);
rewriter.create<AMDAIE::FlowOp>(
rewriter.getUnknownLoc(), ValueRange{shimCtrlChannelOp},
ValueRange{shimSouthChannelOp},
/*isPacketFlow*/ false, /*packetId*/ nullptr);
}
}

return success();
}

class AMDAIEGenerateControlOverlayPass
: public impl::AMDAIEGenerateControlOverlayBase<
AMDAIEGenerateControlOverlayPass> {
public:
AMDAIEGenerateControlOverlayPass(
const AMDAIEGenerateControlOverlayOptions &options)
: AMDAIEGenerateControlOverlayBase(options) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AMDAIEDialect>();
}

void runOnOperation() override;
};

void AMDAIEGenerateControlOverlayPass::runOnOperation() {
Operation *parentOp = getOperation();
WalkResult res = parentOp->walk([&](AMDAIE::WorkgroupOp workgroupOp) {
if (failed(generateControlOverlay(workgroupOp, routeShimToTileCtrl,
routeShimCtrlToTct))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});

if (res.wasInterrupted()) return signalPassFailure();
}

} // namespace

std::unique_ptr<Pass> createAMDAIEGenerateControlOverlayPass(
AMDAIEGenerateControlOverlayOptions options) {
return std::make_unique<AMDAIEGenerateControlOverlayPass>(options);
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,9 @@ SmallVector<Operation *> AIEDeviceBuilder::createFlowOps(
for (AMDAIE::ChannelOp consumerChannel : consumerChannels) {
Value aieConsumerTile = mapper.lookup(consumerChannel.getTile());
AIE::FlowOp flowOp = rewriter.create<AIE::FlowOp>(
rewriter.getUnknownLoc(), aieProducerTile, AIE::WireBundle::DMA,
producerChannel.getValue(), aieConsumerTile, AIE::WireBundle::DMA,
rewriter.getUnknownLoc(), aieProducerTile,
producerChannel.getPortType(), producerChannel.getValue(),
aieConsumerTile, consumerChannel.getPortType(),
consumerChannel.getValue());
flowOps.push_back(flowOp.getOperation());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ iree_cc_library(
"AMDAIEFuseConsumerIntoLoop.cpp"
"AMDAIEFuseFillIntoForall.cpp"
"AMDAIEFusePackIntoLoop.cpp"
"AMDAIEGenerateControlOverlay.cpp"
"AMDAIEHoistForAffineApply.cpp"
"AMDAIEHoistLogicalObjFifo.cpp"
"AMDAIEInsertCores.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIEFUSECONSUMERINTOLOOP
#define GEN_PASS_DEF_AMDAIEFUSEFILLINTOFORALL
#define GEN_PASS_DEF_AMDAIEFUSEPACKINTOLOOP
#define GEN_PASS_DEF_AMDAIEGENERATECONTROLOVERLAY
#define GEN_PASS_DEF_AMDAIEHOISTFORLOOPAFFINEAPPLY
#define GEN_PASS_DEF_AMDAIEHOISTLOGICALOBJFIFO
#define GEN_PASS_DEF_AMDAIEINSERTAIEWORKGROUP
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,11 @@ void addAMDAIEObjectFifoLoweringPasses(
passManager.addPass(createAMDAIEObjFifoBufferizationPass());
passManager.addPass(createAMDAIETemporaryAllocBufferizationPass());
passManager.addPass(createAMDAIEConnectionToFlowPass());

passManager.addPass(createAMDAIEGenerateControlOverlayPass());
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());

passManager.addPass(createAMDAIEAssignPacketIdsPass());

passManager.addPass(createAMDAIENpuDmaToHalfDmaCpyNdPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,11 @@ std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass(
/// Create a pass to fuse the linalg.fill into the forall loops.
std::unique_ptr<Pass> createAMDAIEFuseFillIntoForallPass();

/// Create pass to generate packet-flow routings for control packets entering or
/// leaving each tile.
std::unique_ptr<Pass> createAMDAIEGenerateControlOverlayPass(
AMDAIEGenerateControlOverlayOptions options = {});

/// Hoist an affine.apply op on a scf.for op's induction variable.
std::unique_ptr<Pass> createAMDAIEHoistForLoopAffineApplyPass();

Expand Down
11 changes: 11 additions & 0 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,17 @@ def AMDAIEFusePackIntoLoop :
];
}

def AMDAIEGenerateControlOverlay : Pass<"iree-amdaie-generate-control-overlay"> {
let summary = "Spawn a streaming interconnect network for CTRL ports.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEGenerateControlOverlayPass()";
let options = [
Option<"routeShimCtrlToTct", "route-shim-to-tct", "bool", /*default=*/"true",
"Flag to generate TCT routing between tile CTRL and shim SOUTH ports.">,
Option<"routeShimToTileCtrl", "route-shim-to-tile-ctrl", "bool", /*default=*/"false",
"Flag to generate routing between shim dma DMA and tile CTRL ports, for configuration.">
];
}

def AMDAIEHoistForLoopAffineApply : Pass<"iree-amdaie-hoist-for-affine-apply"> {
let summary = "Hoist an affine apply op on a scf.for op's induction variable.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEHoistForLoopAffineApplyPass()";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ iree_lit_test_suite(
"fuse_consumer_into_loop.mlir"
"fuse_fill_into_forall.mlir"
"fuse_pack_into_loop.mlir"
"generate_control_overlay.mlir"
"hoist_for_affine_apply.mlir"
"hoist_logical_obj_fifo.mlir"
"insert_cores.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-amdaie-generate-control-overlay{route-shim-to-tct=true route-shim-to-tile-ctrl=true},canonicalize,cse))" --split-input-file --verify-diagnostics %s | FileCheck %s

// Device attribute is required for route-shim-to-tile-ctrl.
module {
func.func @no_amdaie_device() {
// expected-error @+1 {{could not find an AMDAIEDevice attribute}}
amdaie.workgroup {
amdaie.controlcode {
amdaie.end
}
}
return
}
}

// -----

// Shim tile (0, 0) has two producer (MM2S) channels,
// both of which are already utilized by existing circuit flows.
// No producer DMA channel is available for route-shim-to-tile-ctrl.
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @no_available_channel() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
amdaie.workgroup {
// expected-error @+1 {{no producer DMA channel available}}
%tile_0_0 = amdaie.tile(%c0, %c0)
%tile_0_1 = amdaie.tile(%c0, %c1)
%channel_0 = amdaie.channel(%tile_0_0, 0, port_type = DMA, direction = MM2S)
%channel_1 = amdaie.channel(%tile_0_1, 0, port_type = DMA, direction = S2MM)
%flow_0 = amdaie.flow({%channel_0} -> {%channel_1}) {is_packet_flow = false}
%channel_2 = amdaie.channel(%tile_0_0, 1, port_type = DMA, direction = MM2S)
%channel_3 = amdaie.channel(%tile_0_1, 1, port_type = DMA, direction = S2MM)
%flow_1 = amdaie.flow({%channel_2} -> {%channel_3}) {is_packet_flow = false}
amdaie.controlcode {
amdaie.end
}
}
return
}
}


// -----

// Successfully inserted six packet flows from shim DMA channels to tile CTRL channels,
// and one circuit flow from shim CTRL to shim SOUTH 0.
// CHECK-LABEL: @column_control_overlay
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C5:.*]] = arith.constant 5 : index
// CHECK: amdaie.workgroup {
// CHECK: %[[TILE_0_0:.*]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: %[[TILE_0_1:.*]] = amdaie.tile(%[[C0]], %[[C1]])
// CHECK: %[[TILE_0_2:.*]] = amdaie.tile(%[[C0]], %[[C2]])
// CHECK: %[[TILE_0_3:.*]] = amdaie.tile(%[[C0]], %[[C3]])
// CHECK: %[[TILE_0_4:.*]] = amdaie.tile(%[[C0]], %[[C4]])
// CHECK: %[[TILE_0_5:.*]] = amdaie.tile(%[[C0]], %[[C5]])
// CHECK: %[[CHANNEL_0:.*]] = amdaie.channel(%[[TILE_0_0]], 0, port_type = DMA, direction = MM2S)
// CHECK: %[[CHANNEL_1:.*]] = amdaie.channel(%[[TILE_0_0]], 0, port_type = CTRL, direction = S2MM)
// CHECK: %[[FLOW_0:.*]] = amdaie.flow({%[[CHANNEL_0]]} -> {%[[CHANNEL_1]]}) {is_packet_flow = true}
// CHECK: %[[CHANNEL_2:.*]] = amdaie.channel(%[[TILE_0_0]], 1, port_type = DMA, direction = MM2S)
// CHECK: %[[CHANNEL_3:.*]] = amdaie.channel(%[[TILE_0_1]], 0, port_type = CTRL, direction = S2MM)
// CHECK: %[[FLOW_1:.*]] = amdaie.flow({%[[CHANNEL_2]]} -> {%[[CHANNEL_3]]}) {is_packet_flow = true}
// CHECK: %[[CHANNEL_4:.*]] = amdaie.channel(%[[TILE_0_2]], 0, port_type = CTRL, direction = S2MM)
// CHECK: %[[FLOW_2:.*]] = amdaie.flow({%[[CHANNEL_0]]} -> {%[[CHANNEL_4]]}) {is_packet_flow = true}
// CHECK: %[[CHANNEL_5:.*]] = amdaie.channel(%[[TILE_0_3]], 0, port_type = CTRL, direction = S2MM)
// CHECK: %[[FLOW_3:.*]] = amdaie.flow({%[[CHANNEL_2]]} -> {%[[CHANNEL_5]]}) {is_packet_flow = true}
// CHECK: %[[CHANNEL_6:.*]] = amdaie.channel(%[[TILE_0_4]], 0, port_type = CTRL, direction = S2MM)
// CHECK: %[[FLOW_4:.*]] = amdaie.flow({%[[CHANNEL_0]]} -> {%[[CHANNEL_6]]}) {is_packet_flow = true}
// CHECK: %[[CHANNEL_7:.*]] = amdaie.channel(%[[TILE_0_5]], 0, port_type = CTRL, direction = S2MM)
// CHECK: %[[FLOW_5:.*]] = amdaie.flow({%[[CHANNEL_2]]} -> {%[[CHANNEL_7]]}) {is_packet_flow = true}
// CHECK: %[[CHANNEL_8:.*]] = amdaie.channel(%[[TILE_0_0]], 0, port_type = CTRL, direction = MM2S)
// CHECK: %[[CHANNEL_9:.*]] = amdaie.channel(%[[TILE_0_0]], 0, port_type = SOUTH, direction = S2MM)
// CHECK: %[[FLOW_6:.*]] = amdaie.flow({%[[CHANNEL_8]]} -> {%[[CHANNEL_9]]}) {is_packet_flow = false}
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @column_control_overlay() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c5 = arith.constant 5 : index
amdaie.workgroup {
%tile_0_0 = amdaie.tile(%c0, %c0)
%tile_0_1 = amdaie.tile(%c0, %c1)
%tile_0_2 = amdaie.tile(%c0, %c2)
%tile_0_3 = amdaie.tile(%c0, %c3)
%tile_0_4 = amdaie.tile(%c0, %c4)
%tile_0_5 = amdaie.tile(%c0, %c5)
amdaie.controlcode {
amdaie.end
}
}
return
}
}
2 changes: 1 addition & 1 deletion runtime/src/iree-amd-aie/aie_runtime/iree_aie_configure.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ LogicalResult configureStreamSwitch(const AMDAIEDeviceModel &deviceModel,
const TileLoc &tileLoc,
const std::vector<Connect> &connects) {
auto devInst = const_cast<XAie_DevInst *>(&deviceModel.devInst);
// FIXME hack for TCT routing
// mlir-air legacy, hack for TCT routing
// TODO copy-pasted: Support both channels
// TODO(max): find a way to keep track so that multiple calls don't
// rewrite/overwrite with same data.
Expand Down

0 comments on commit 75ea24b

Please sign in to comment.