diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEAssignChannels.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEAssignChannels.cpp index 4bf66f282..6075286fe 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEAssignChannels.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEAssignChannels.cpp @@ -59,6 +59,10 @@ LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) { return connectionOp.emitOpError() << "expected a `LogicalObjFifoOpInterface` target"; } + std::optional connectionType = + connectionOp.getConnectionType(); + bool isPacketFlow = connectionType && connectionType.value() == + AMDAIE::ConnectionType::Packet; rewriter.setInsertionPoint(connectionOp); SmallVector sourceChannels; @@ -66,11 +70,14 @@ LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) { assert(tileToGeneratorMap.contains(tile) && "no channel generator found for tile"); std::optional maybeChannel = - tileToGeneratorMap[tile].getAndAssignProducerDMAChannel(); + tileToGeneratorMap[tile].getProducerDMAChannel(); if (!maybeChannel) { return connectionOp.emitOpError() << "no producer DMA channel available"; } + // Only assign the channel if it is for circuit flow. + if (!isPacketFlow) + tileToGeneratorMap[tile].assignProducerDMAChannel(maybeChannel.value()); auto channelOp = rewriter.create( rewriter.getUnknownLoc(), tile, maybeChannel.value(), StrmSwPortType::DMA, AMDAIE::DMAChannelDir::MM2S); @@ -81,11 +88,14 @@ LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) { assert(tileToGeneratorMap.contains(tile) && "no channel generator found for tile"); std::optional maybeChannel = - tileToGeneratorMap[tile].getAndAssignConsumerDMAChannel(); + tileToGeneratorMap[tile].getConsumerDMAChannel(); if (!maybeChannel) { return connectionOp.emitOpError() << "no consumer DMA channel available"; } + // Only assign the channel if it is for circuit flow. + if (!isPacketFlow) + tileToGeneratorMap[tile].assignConsumerDMAChannel(maybeChannel.value()); auto channelOp = rewriter.create( rewriter.getUnknownLoc(), tile, maybeChannel.value(), StrmSwPortType::DMA, AMDAIE::DMAChannelDir::S2MM); diff --git a/runtime/src/iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h b/runtime/src/iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h index 52b0c3bb1..d491d90d3 100644 --- a/runtime/src/iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h +++ b/runtime/src/iree-amd-aie/aie_runtime/Utils/ChannelGenerator.h @@ -16,42 +16,72 @@ using namespace llvm; namespace mlir::iree_compiler::AMDAIE { +enum class ChannelAssignmentMode { FirstAvailable, RoundRobin }; + /// Utility to generate valid channels. class ChannelGenerator { public: ChannelGenerator() {} ChannelGenerator(uint8_t numProducerChannels, uint8_t numConsumerChannels) : numProducerChannels(numProducerChannels), - numConsumerChannels(numConsumerChannels) {} + numConsumerChannels(numConsumerChannels) { + assert(numProducerChannels > 0 && numConsumerChannels > 0 && + "Invalid number of producer/consumer channels."); + // Initialize to the last channel for round-robin usage. + lastRetrievedProducerChannel = numProducerChannels - 1; + lastRetrievedConsumerChannel = numConsumerChannels - 1; + } - /// Returns its next usable producer channel. - std::optional getAndAssignProducerDMAChannel() { - for (uint8_t i = 0; i < numProducerChannels; i++) { + /// Retrieves the next producer channel using the specified strategy. + /// Defaults to round-robin for balanced load distribution, using + /// `lastRetrievedProducerChannel` to track the last channel accessed. + std::optional getProducerDMAChannel( + ChannelAssignmentMode mode = ChannelAssignmentMode::RoundRobin) { + for (uint8_t offset = 1; offset <= numProducerChannels; ++offset) { + uint8_t i; + if (mode == ChannelAssignmentMode::FirstAvailable) { + i = offset - 1; + } else if (mode == ChannelAssignmentMode::RoundRobin) { + i = (lastRetrievedProducerChannel + offset) % numProducerChannels; + } else { + assert(false && "Unsupported ChannelAssignmentMode"); + } if (!assignedProducerChannels.count(i)) { - assignedProducerChannels.insert(i); + lastRetrievedProducerChannel = i; return i; } } return std::nullopt; } - /// Returns its next usable consumer channel. - std::optional getAndAssignConsumerDMAChannel() { - for (uint8_t i = 0; i < numConsumerChannels; i++) { + /// Retrieves the next consumer channel using the specified strategy. + /// Defaults to round-robin for balanced load distribution, using + /// `lastRetrievedConsumerChannel` to track the last channel accessed. + std::optional getConsumerDMAChannel( + ChannelAssignmentMode mode = ChannelAssignmentMode::RoundRobin) { + for (uint8_t offset = 1; offset <= numConsumerChannels; ++offset) { + uint8_t i; + if (mode == ChannelAssignmentMode::FirstAvailable) { + i = offset - 1; + } else if (mode == ChannelAssignmentMode::RoundRobin) { + i = (lastRetrievedConsumerChannel + offset) % numConsumerChannels; + } else { + assert(false && "Unsupported ChannelAssignmentMode"); + } if (!assignedConsumerChannels.count(i)) { - assignedConsumerChannels.insert(i); + lastRetrievedConsumerChannel = i; return i; } } return std::nullopt; } - /// Assigns the provided producer channel. + /// Assigns the provided producer channel, only used for circuit flow. void assignProducerDMAChannel(uint8_t channel) { assignedProducerChannels.insert(channel); } - /// Assigns the provided consumer channel. + /// Assigns the provided consumer channel, only used for circuit flow. void assignConsumerDMAChannel(uint8_t channel) { assignedConsumerChannels.insert(channel); } @@ -59,8 +89,13 @@ class ChannelGenerator { private: uint8_t numProducerChannels = 0; uint8_t numConsumerChannels = 0; + // Tracks the channels that are used by circuit flows. DenseSet assignedProducerChannels; DenseSet assignedConsumerChannels; + // Tracks the last retrieved channel in `getProducerDMAChannel` and + // `getConsumerDMAChannel` for round-robin usage. + uint8_t lastRetrievedProducerChannel = 0; + uint8_t lastRetrievedConsumerChannel = 0; }; } // namespace mlir::iree_compiler::AMDAIE diff --git a/runtime/src/iree-amd-aie/aie_runtime/Utils/test/ChannelGeneratorTest.cpp b/runtime/src/iree-amd-aie/aie_runtime/Utils/test/ChannelGeneratorTest.cpp index 85c106f9b..6a811dd1b 100644 --- a/runtime/src/iree-amd-aie/aie_runtime/Utils/test/ChannelGeneratorTest.cpp +++ b/runtime/src/iree-amd-aie/aie_runtime/Utils/test/ChannelGeneratorTest.cpp @@ -13,14 +13,74 @@ namespace { using namespace mlir::iree_compiler::AMDAIE; +TEST(ChannelGeneratorTest, GetFirstAvailable) { + ChannelGenerator generator(2, 2); + EXPECT_EQ( + generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 0); + EXPECT_EQ( + generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 0); + EXPECT_EQ( + generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 0); + EXPECT_EQ( + generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 0); + EXPECT_EQ( + generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 0); + EXPECT_EQ( + generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 0); +} + +TEST(ChannelGeneratorTest, GetRoundRobin) { + ChannelGenerator generator(2, 2); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 0); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 0); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 1); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 1); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 0); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 0); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 1); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 1); +} + TEST(ChannelGeneratorTest, GetAssign) { ChannelGenerator generator(2, 2); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel().value(), 0); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel().value(), 0); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel().value(), 1); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel().value(), 1); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel(), std::nullopt); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(), std::nullopt); + EXPECT_EQ(generator.getProducerDMAChannel().value(), 0); + generator.assignProducerDMAChannel(0); + EXPECT_EQ(generator.getConsumerDMAChannel().value(), 0); + generator.assignConsumerDMAChannel(0); + EXPECT_EQ(generator.getProducerDMAChannel().value(), 1); + generator.assignProducerDMAChannel(1); + EXPECT_EQ(generator.getConsumerDMAChannel().value(), 1); + generator.assignConsumerDMAChannel(1); + EXPECT_EQ(generator.getProducerDMAChannel(), std::nullopt); + EXPECT_EQ(generator.getConsumerDMAChannel(), std::nullopt); } TEST(ChannelGeneratorTest, Occupied) { @@ -29,12 +89,40 @@ TEST(ChannelGeneratorTest, Occupied) { generator.assignConsumerDMAChannel(0); generator.assignProducerDMAChannel(2); generator.assignConsumerDMAChannel(2); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel().value(), 1); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel().value(), 1); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel().value(), 3); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel().value(), 3); - EXPECT_EQ(generator.getAndAssignProducerDMAChannel(), std::nullopt); - EXPECT_EQ(generator.getAndAssignConsumerDMAChannel(), std::nullopt); + EXPECT_EQ( + generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 1); + EXPECT_EQ( + generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 1); + EXPECT_EQ( + generator.getProducerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 1); + EXPECT_EQ( + generator.getConsumerDMAChannel(ChannelAssignmentMode::FirstAvailable) + .value(), + 1); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 3); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 3); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 1); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 1); + EXPECT_EQ(generator.getProducerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 3); + EXPECT_EQ(generator.getConsumerDMAChannel(ChannelAssignmentMode::RoundRobin) + .value(), + 3); } } // namespace