diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEAssignChannels.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEAssignChannels.cpp index c2e3b9c9e..6075286fe 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEAssignChannels.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEAssignChannels.cpp @@ -70,11 +70,14 @@ LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) { assert(tileToGeneratorMap.contains(tile) && "no channel generator found for tile"); std::optional maybeChannel = - tileToGeneratorMap[tile].getAndAssignProducerDMAChannel(isPacketFlow); + tileToGeneratorMap[tile].getProducerDMAChannel(); if (!maybeChannel) { return connectionOp.emitOpError() << "no producer DMA channel available"; } + // Only assign the channel if it is for circuit flow. + if (!isPacketFlow) + tileToGeneratorMap[tile].assignProducerDMAChannel(maybeChannel.value()); auto channelOp = rewriter.create( rewriter.getUnknownLoc(), tile, maybeChannel.value(), StrmSwPortType::DMA, AMDAIE::DMAChannelDir::MM2S); @@ -85,11 +88,14 @@ LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) { assert(tileToGeneratorMap.contains(tile) && "no channel generator found for tile"); std::optional maybeChannel = - tileToGeneratorMap[tile].getAndAssignConsumerDMAChannel(isPacketFlow); + tileToGeneratorMap[tile].getConsumerDMAChannel(); if (!maybeChannel) { return connectionOp.emitOpError() << "no consumer DMA channel available"; } + // Only assign the channel if it is for circuit flow. + if (!isPacketFlow) + tileToGeneratorMap[tile].assignConsumerDMAChannel(maybeChannel.value()); auto channelOp = rewriter.create( rewriter.getUnknownLoc(), tile, maybeChannel.value(), StrmSwPortType::DMA, AMDAIE::DMAChannelDir::S2MM); diff --git a/runtime/src/iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h b/runtime/src/iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h index f7ae4e289..c9814e9a2 100644 --- a/runtime/src/iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h +++ b/runtime/src/iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h @@ -16,6 +16,8 @@ using namespace llvm; namespace mlir::iree_compiler::AMDAIE { +enum class ChannelAssignmentMode { FirstAvailable, RoundRobin }; + /// Utility to generate valid channels. class ChannelGenerator { public: @@ -30,13 +32,20 @@ class ChannelGenerator { lastUsedConsumerChannel = numConsumerChannels - 1; } - /// Returns its next usable producer channel. - std::optional getAndAssignProducerDMAChannel(bool isPacketFlow) { + /// Returns its next usable producer channel. By default, it uses round-robin + /// for load balancing. + std::optional getProducerDMAChannel( + ChannelAssignmentMode mode = ChannelAssignmentMode::RoundRobin) { for (uint8_t offset = 1; offset <= numProducerChannels; ++offset) { - uint8_t i = (lastUsedProducerChannel + offset) % numProducerChannels; + uint8_t i; + if (mode == ChannelAssignmentMode::FirstAvailable) { + i = offset - 1; + } else if (mode == ChannelAssignmentMode::RoundRobin) { + i = (lastUsedProducerChannel + offset) % numProducerChannels; + } else { + assert(false && "Unsupported ChannelAssignmentMode"); + } if (!assignedProducerChannels.count(i)) { - // Only assign the channel if it is for the circuit flow. - if (!isPacketFlow) assignedProducerChannels.insert(i); lastUsedProducerChannel = i; return i; } @@ -44,13 +53,20 @@ class ChannelGenerator { return std::nullopt; } - /// Returns its next usable consumer channel. - std::optional getAndAssignConsumerDMAChannel(bool isPacketFlow) { + /// Returns its next usable consumer channel. By default, it uses round-robin + /// for load balancing. + std::optional getConsumerDMAChannel( + ChannelAssignmentMode mode = ChannelAssignmentMode::RoundRobin) { for (uint8_t offset = 1; offset <= numConsumerChannels; ++offset) { - uint8_t i = (lastUsedConsumerChannel + offset) % numConsumerChannels; + uint8_t i; + if (mode == ChannelAssignmentMode::FirstAvailable) { + i = offset - 1; + } else if (mode == ChannelAssignmentMode::RoundRobin) { + i = (lastUsedConsumerChannel + offset) % numConsumerChannels; + } else { + assert(false && "Unsupported ChannelAssignmentMode"); + } if (!assignedConsumerChannels.count(i)) { - // Only assign the channel if it is for the circuit flow. - if (!isPacketFlow) assignedConsumerChannels.insert(i); lastUsedConsumerChannel = i; return i; } diff --git a/runtime/src/iree-amd-aie/aie_runtime/Utils/test/ChannelGeneratorTest.cpp b/runtime/src/iree-amd-aie/aie_runtime/Utils/test/ChannelGeneratorTest.cpp index 9c365e29f..6a811dd1b 100644 --- a/runtime/src/iree-amd-aie/aie_runtime/Utils/test/ChannelGeneratorTest.cpp +++ b/runtime/src/iree-amd-aie/aie_runtime/Utils/test/ChannelGeneratorTest.cpp @@ -13,53 +13,116 @@ namespace { using namespace mlir::iree_compiler::AMDAIE; -TEST(ChannelGeneratorTest, GetAssign) { +TEST(ChannelGeneratorTest, GetFirstAvailable) { ChannelGenerator generator(2, 2); - bool isPacketFlow = false; - EXPECT_EQ(generator.getAndAssignProducerDMAChannel(isPacketFlow).value(), 0); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(isPacketFlow).value(), 0); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel(isPacketFlow).value(), 1); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(isPacketFlow).value(), 1); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel(isPacketFlow), - std::nullopt); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(isPacketFlow), - std::nullopt); + EXPECT_EQ( + generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 0); + EXPECT_EQ( + generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 0); + EXPECT_EQ( + generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 0); + EXPECT_EQ( + generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 0); + EXPECT_EQ( + generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 0); + EXPECT_EQ( + generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 0); } -TEST(ChannelGeneratorTest, Occupied) { - ChannelGenerator generator(4, 4); - bool isPacketFlow = false; +TEST(ChannelGeneratorTest, GetRoundRobin) { + ChannelGenerator generator(2, 2); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 0); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 0); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 1); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 1); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 0); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 0); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 1); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 1); +} + +TEST(ChannelGeneratorTest, GetAssign) { + ChannelGenerator generator(2, 2); + EXPECT_EQ(generator.getProducerDMAChannel().value(), 0); generator.assignProducerDMAChannel(0); + EXPECT_EQ(generator.getConsumerDMAChannel().value(), 0); generator.assignConsumerDMAChannel(0); - generator.assignProducerDMAChannel(2); - generator.assignConsumerDMAChannel(2); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel(isPacketFlow).value(), 1); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(isPacketFlow).value(), 1); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel(isPacketFlow).value(), 3); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(isPacketFlow).value(), 3); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel(isPacketFlow), - std::nullopt); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(isPacketFlow), - std::nullopt); + EXPECT_EQ(generator.getProducerDMAChannel().value(), 1); + generator.assignProducerDMAChannel(1); + EXPECT_EQ(generator.getConsumerDMAChannel().value(), 1); + generator.assignConsumerDMAChannel(1); + EXPECT_EQ(generator.getProducerDMAChannel(), std::nullopt); + EXPECT_EQ(generator.getConsumerDMAChannel(), std::nullopt); } -TEST(ChannelGeneratorTest, PacketFlow) { +TEST(ChannelGeneratorTest, Occupied) { ChannelGenerator generator(4, 4); generator.assignProducerDMAChannel(0); generator.assignConsumerDMAChannel(0); generator.assignProducerDMAChannel(2); generator.assignConsumerDMAChannel(2); - bool isPacketFlow = true; - // Packet flow should not occupy the channel exclusively, and the available - // channel appears in a round-robin fashion. - EXPECT_EQ(generator.getAndAssignProducerDMAChannel(isPacketFlow).value(), 1); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(isPacketFlow).value(), 1); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel(isPacketFlow).value(), 3); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(isPacketFlow).value(), 3); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel(isPacketFlow).value(), 1); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(isPacketFlow).value(), 1); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel(isPacketFlow).value(), 3); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(isPacketFlow).value(), 3); + EXPECT_EQ( + generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 1); + EXPECT_EQ( + generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 1); + EXPECT_EQ( + generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 1); + EXPECT_EQ( + generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 1); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 3); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 3); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 1); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 1); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 3); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 3); } } // namespace