Skip to content

Commit

Permalink
[ChannelGenerator] Assign the channel only for circuit flow (#1013)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen authored Jan 8, 2025
1 parent 756cac1 commit f76c245
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,25 @@ LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) {
return connectionOp.emitOpError()
<< "expected a `LogicalObjFifoOpInterface` target";
}
std::optional<AMDAIE::ConnectionType> connectionType =
connectionOp.getConnectionType();
bool isPacketFlow = connectionType && connectionType.value() ==
AMDAIE::ConnectionType::Packet;

rewriter.setInsertionPoint(connectionOp);
SmallVector<Value> sourceChannels;
for (Value tile : sourceLogicalObjFifo.getTiles()) {
assert(tileToGeneratorMap.contains(tile) &&
"no channel generator found for tile");
std::optional<uint8_t> maybeChannel =
tileToGeneratorMap[tile].getAndAssignProducerDMAChannel();
tileToGeneratorMap[tile].getProducerDMAChannel();
if (!maybeChannel) {
return connectionOp.emitOpError()
<< "no producer DMA channel available";
}
// Only assign the channel if it is for circuit flow.
if (!isPacketFlow)
tileToGeneratorMap[tile].assignProducerDMAChannel(maybeChannel.value());
auto channelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tile, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::MM2S);
Expand All @@ -81,11 +88,14 @@ LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) {
assert(tileToGeneratorMap.contains(tile) &&
"no channel generator found for tile");
std::optional<uint8_t> maybeChannel =
tileToGeneratorMap[tile].getAndAssignConsumerDMAChannel();
tileToGeneratorMap[tile].getConsumerDMAChannel();
if (!maybeChannel) {
return connectionOp.emitOpError()
<< "no consumer DMA channel available";
}
// Only assign the channel if it is for circuit flow.
if (!isPacketFlow)
tileToGeneratorMap[tile].assignConsumerDMAChannel(maybeChannel.value());
auto channelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tile, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::S2MM);
Expand Down
57 changes: 46 additions & 11 deletions runtime/src/iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,51 +16,86 @@ using namespace llvm;

namespace mlir::iree_compiler::AMDAIE {

enum class ChannelAssignmentMode { FirstAvailable, RoundRobin };

/// Utility to generate valid channels.
class ChannelGenerator {
public:
ChannelGenerator() {}
ChannelGenerator(uint8_t numProducerChannels, uint8_t numConsumerChannels)
: numProducerChannels(numProducerChannels),
numConsumerChannels(numConsumerChannels) {}
numConsumerChannels(numConsumerChannels) {
assert(numProducerChannels > 0 && numConsumerChannels > 0 &&
"Invalid number of producer/consumer channels.");
// Initialize to the last channel for round-robin usage.
lastRetrievedProducerChannel = numProducerChannels - 1;
lastRetrievedConsumerChannel = numConsumerChannels - 1;
}

/// Returns its next usable producer channel.
std::optional<uint8_t> getAndAssignProducerDMAChannel() {
for (uint8_t i = 0; i < numProducerChannels; i++) {
/// Retrieves the next producer channel using the specified strategy.
/// Defaults to round-robin for balanced load distribution, using
/// `lastRetrievedProducerChannel` to track the last channel accessed.
std::optional<uint8_t> getProducerDMAChannel(
ChannelAssignmentMode mode = ChannelAssignmentMode::RoundRobin) {
for (uint8_t offset = 1; offset <= numProducerChannels; ++offset) {
uint8_t i;
if (mode == ChannelAssignmentMode::FirstAvailable) {
i = offset - 1;
} else if (mode == ChannelAssignmentMode::RoundRobin) {
i = (lastRetrievedProducerChannel + offset) % numProducerChannels;
} else {
assert(false && "Unsupported ChannelAssignmentMode");
}
if (!assignedProducerChannels.count(i)) {
assignedProducerChannels.insert(i);
lastRetrievedProducerChannel = i;
return i;
}
}
return std::nullopt;
}

/// Returns its next usable consumer channel.
std::optional<uint8_t> getAndAssignConsumerDMAChannel() {
for (uint8_t i = 0; i < numConsumerChannels; i++) {
/// Retrieves the next consumer channel using the specified strategy.
/// Defaults to round-robin for balanced load distribution, using
/// `lastRetrievedConsumerChannel` to track the last channel accessed.
std::optional<uint8_t> getConsumerDMAChannel(
ChannelAssignmentMode mode = ChannelAssignmentMode::RoundRobin) {
for (uint8_t offset = 1; offset <= numConsumerChannels; ++offset) {
uint8_t i;
if (mode == ChannelAssignmentMode::FirstAvailable) {
i = offset - 1;
} else if (mode == ChannelAssignmentMode::RoundRobin) {
i = (lastRetrievedConsumerChannel + offset) % numConsumerChannels;
} else {
assert(false && "Unsupported ChannelAssignmentMode");
}
if (!assignedConsumerChannels.count(i)) {
assignedConsumerChannels.insert(i);
lastRetrievedConsumerChannel = i;
return i;
}
}
return std::nullopt;
}

/// Assigns the provided producer channel.
/// Assigns the provided producer channel, only used for circuit flow.
void assignProducerDMAChannel(uint8_t channel) {
assignedProducerChannels.insert(channel);
}

/// Assigns the provided consumer channel.
/// Assigns the provided consumer channel, only used for circuit flow.
void assignConsumerDMAChannel(uint8_t channel) {
assignedConsumerChannels.insert(channel);
}

private:
uint8_t numProducerChannels = 0;
uint8_t numConsumerChannels = 0;
// Tracks the channels that are used by circuit flows.
DenseSet<uint8_t> assignedProducerChannels;
DenseSet<uint8_t> assignedConsumerChannels;
// Tracks the last retrieved channel in `getProducerDMAChannel` and
// `getConsumerDMAChannel` for round-robin usage.
uint8_t lastRetrievedProducerChannel = 0;
uint8_t lastRetrievedConsumerChannel = 0;
};

} // namespace mlir::iree_compiler::AMDAIE
Expand Down
112 changes: 100 additions & 12 deletions runtime/src/iree-amd-aie/aie_runtime/Utils/test/ChannelGeneratorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,74 @@ namespace {

using namespace mlir::iree_compiler::AMDAIE;

TEST(ChannelGeneratorTest, GetFirstAvailable) {
ChannelGenerator generator(2, 2);
EXPECT_EQ(
generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable)
.value(),
0);
EXPECT_EQ(
generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable)
.value(),
0);
EXPECT_EQ(
generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable)
.value(),
0);
EXPECT_EQ(
generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable)
.value(),
0);
EXPECT_EQ(
generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable)
.value(),
0);
EXPECT_EQ(
generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable)
.value(),
0);
}

TEST(ChannelGeneratorTest, GetRoundRobin) {
ChannelGenerator generator(2, 2);
EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
0);
EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
0);
EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
1);
EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
1);
EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
0);
EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
0);
EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
1);
EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
1);
}

TEST(ChannelGeneratorTest, GetAssign) {
ChannelGenerator generator(2, 2);
EXPECT_EQ(generator.getAndAssignProducerDMAChannel().value(), 0);
EXPECT_EQ(generator.getAndAssignConsumerDMAChannel().value(), 0);
EXPECT_EQ(generator.getAndAssignProducerDMAChannel().value(), 1);
EXPECT_EQ(generator.getAndAssignConsumerDMAChannel().value(), 1);
EXPECT_EQ(generator.getAndAssignProducerDMAChannel(), std::nullopt);
EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(), std::nullopt);
EXPECT_EQ(generator.getProducerDMAChannel().value(), 0);
generator.assignProducerDMAChannel(0);
EXPECT_EQ(generator.getConsumerDMAChannel().value(), 0);
generator.assignConsumerDMAChannel(0);
EXPECT_EQ(generator.getProducerDMAChannel().value(), 1);
generator.assignProducerDMAChannel(1);
EXPECT_EQ(generator.getConsumerDMAChannel().value(), 1);
generator.assignConsumerDMAChannel(1);
EXPECT_EQ(generator.getProducerDMAChannel(), std::nullopt);
EXPECT_EQ(generator.getConsumerDMAChannel(), std::nullopt);
}

TEST(ChannelGeneratorTest, Occupied) {
Expand All @@ -29,12 +89,40 @@ TEST(ChannelGeneratorTest, Occupied) {
generator.assignConsumerDMAChannel(0);
generator.assignProducerDMAChannel(2);
generator.assignConsumerDMAChannel(2);
EXPECT_EQ(generator.getAndAssignProducerDMAChannel().value(), 1);
EXPECT_EQ(generator.getAndAssignConsumerDMAChannel().value(), 1);
EXPECT_EQ(generator.getAndAssignProducerDMAChannel().value(), 3);
EXPECT_EQ(generator.getAndAssignConsumerDMAChannel().value(), 3);
EXPECT_EQ(generator.getAndAssignProducerDMAChannel(), std::nullopt);
EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(), std::nullopt);
EXPECT_EQ(
generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable)
.value(),
1);
EXPECT_EQ(
generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable)
.value(),
1);
EXPECT_EQ(
generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable)
.value(),
1);
EXPECT_EQ(
generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable)
.value(),
1);
EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
3);
EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
3);
EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
1);
EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
1);
EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
3);
EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin)
.value(),
3);
}

} // namespace
Expand Down

0 comments on commit f76c245

Please sign in to comment.