diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEGenerateColumnControlOverlay.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEGenerateColumnControlOverlay.cpp index 2e692a316..843006260 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEGenerateColumnControlOverlay.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEGenerateColumnControlOverlay.cpp @@ -16,18 +16,12 @@ namespace mlir::iree_compiler::AMDAIE { namespace { -/// Utility function to get available DMA channels that can be later used for -/// control packets. -LogicalResult getAvailableShimChannels( - AMDAIE::WorkgroupOp workgroupOp, ArrayRef shimTileOps, +/// Initializes the channel generators for the shim tiles, excluding any +/// channels that are already in use by existing circuit flows. +LogicalResult initializeChannelsGenerators( + AMDAIE::WorkgroupOp workgroupOp, const AMDAIEDeviceModel &deviceModel, + const DenseSet &shimTileOps, DenseMap &shimTileToGeneratorMap) { - // Get the device model. - std::optional device = getConfigAMDAIEDevice(workgroupOp); - if (!device) { - return workgroupOp->emitOpError() - << "could not find an AMDAIEDevice attribute"; - } - AMDAIEDeviceModel deviceModel = AMDAIE::getDeviceModel(device.value()); // Check the number of DMA channels available for the shim tile. uint8_t numShimDmaChannels = deviceModel.getDmaProp( AMDAIETileType::SHIMNOC, AMDAIEDmaProp::NumChannels); @@ -39,18 +33,19 @@ LogicalResult getAvailableShimChannels( workgroupOp->walk([&](AMDAIE::FlowOp flowOp) { if (flowOp.getIsPacketFlow()) return WalkResult::advance(); SmallVector sourceChannels; - for (auto value : flowOp.getSources()) { - if (auto channelOp = dyn_cast(value.getDefiningOp())) { + for (Value source : flowOp.getSources()) { + if (auto channelOp = + dyn_cast(source.getDefiningOp())) { sourceChannels.push_back(channelOp); } } - for (auto channelOp : sourceChannels) { + for (AMDAIE::ChannelOp channelOp : sourceChannels) { AMDAIE::TileOp tileOp = channelOp.getTileOp(); uint8_t channel = channelOp.getValue(); StrmSwPortType portType = channelOp.getPortType(); AMDAIE::DMAChannelDir direction = channelOp.getDirection(); - if (llvm::is_contained(shimTileOps, tileOp) && - portType == StrmSwPortType::DMA) { + if (shimTileOps.contains(tileOp) && portType == StrmSwPortType::DMA) { + // Assign to exclude. if (direction == AMDAIE::DMAChannelDir::MM2S) { shimTileToGeneratorMap[tileOp.getResult()].assignProducerDMAChannel( channel); @@ -68,6 +63,14 @@ LogicalResult getAvailableShimChannels( LogicalResult generateColumnControlOverlay(AMDAIE::WorkgroupOp workgroupOp, bool routeShimToTileCtrl, bool routeShimCtrlToTct) { + // Get the device model. + std::optional device = getConfigAMDAIEDevice(workgroupOp); + if (!device) { + return workgroupOp->emitOpError() + << "could not find an AMDAIEDevice attribute"; + } + AMDAIEDeviceModel deviceModel = AMDAIE::getDeviceModel(device.value()); + IRRewriter rewriter(workgroupOp->getContext()); DenseSet occupiedCols; DenseMap columnToShimTile; @@ -75,7 +78,8 @@ LogicalResult generateColumnControlOverlay(AMDAIE::WorkgroupOp workgroupOp, uint32_t col = getConstantIndexOrAssert(tileOp.getCol()); uint32_t row = getConstantIndexOrAssert(tileOp.getRow()); occupiedCols.insert(col); - if (row == 0) columnToShimTile[col] = tileOp; + if (deviceModel.getTileType(col, row) == AMDAIETileType::SHIMNOC) + columnToShimTile[col] = tileOp; }); // If the column is occupied, but the shim tile op is not present, then create @@ -96,16 +100,16 @@ LogicalResult generateColumnControlOverlay(AMDAIE::WorkgroupOp workgroupOp, // control packets. if (routeShimToTileCtrl) { DenseMap shimTileToGeneratorMap; - SmallVector shimTileOps = llvm::to_vector<4>(llvm::map_range( - columnToShimTile, [](auto pair) { return pair.second; })); - if (failed(getAvailableShimChannels(workgroupOp, shimTileOps, - shimTileToGeneratorMap))) { + DenseSet shimTileOps; + for (const auto &pair : columnToShimTile) shimTileOps.insert(pair.second); + if (failed(initializeChannelsGenerators( + workgroupOp, deviceModel, shimTileOps, shimTileToGeneratorMap))) { return failure(); } WalkResult res = workgroupOp->walk([&](AMDAIE::TileOp tileOp) { uint32_t col = getConstantIndexOrAssert(tileOp.getCol()); TileOp shimTileOp = columnToShimTile[col]; - // Get the available channel, but do not assigning it. Allow it to be + // Get the available channel, but do not assign it. Allow it to be // shared across multiple packet flows as needed. std::optional maybeChannel = shimTileToGeneratorMap[shimTileOp.getResult()]