Skip to content

Commit

Permalink
resolve comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen committed Jan 8, 2025
1 parent 88605ba commit 35ef38f
Showing 1 changed file with 26 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,12 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Utility function to get available DMA channels that can be later used for
/// control packets.
LogicalResult getAvailableShimChannels(
AMDAIE::WorkgroupOp workgroupOp, ArrayRef<TileOp> shimTileOps,
/// Initializes the channel generators for the shim tiles, excluding any
/// channels that are already in use by existing circuit flows.
LogicalResult initializeChannelsGenerators(
AMDAIE::WorkgroupOp workgroupOp, const AMDAIEDeviceModel &deviceModel,
const DenseSet<TileOp> &shimTileOps,
DenseMap<Value, ChannelGenerator> &shimTileToGeneratorMap) {
// Get the device model.
std::optional<AMDAIEDevice> device = getConfigAMDAIEDevice(workgroupOp);
if (!device) {
return workgroupOp->emitOpError()
<< "could not find an AMDAIEDevice attribute";
}
AMDAIEDeviceModel deviceModel = AMDAIE::getDeviceModel(device.value());
// Check the number of DMA channels available for the shim tile.
uint8_t numShimDmaChannels = deviceModel.getDmaProp<uint8_t>(
AMDAIETileType::SHIMNOC, AMDAIEDmaProp::NumChannels);
Expand All @@ -39,18 +33,19 @@ LogicalResult getAvailableShimChannels(
workgroupOp->walk([&](AMDAIE::FlowOp flowOp) {
if (flowOp.getIsPacketFlow()) return WalkResult::advance();
SmallVector<AMDAIE::ChannelOp> sourceChannels;
for (auto value : flowOp.getSources()) {
if (auto channelOp = dyn_cast<AMDAIE::ChannelOp>(value.getDefiningOp())) {
for (Value source : flowOp.getSources()) {
if (auto channelOp =
dyn_cast<AMDAIE::ChannelOp>(source.getDefiningOp())) {
sourceChannels.push_back(channelOp);
}
}
for (auto channelOp : sourceChannels) {
for (AMDAIE::ChannelOp channelOp : sourceChannels) {
AMDAIE::TileOp tileOp = channelOp.getTileOp();
uint8_t channel = channelOp.getValue();
StrmSwPortType portType = channelOp.getPortType();
AMDAIE::DMAChannelDir direction = channelOp.getDirection();
if (llvm::is_contained(shimTileOps, tileOp) &&
portType == StrmSwPortType::DMA) {
if (shimTileOps.contains(tileOp) && portType == StrmSwPortType::DMA) {
// Assign to exclude.
if (direction == AMDAIE::DMAChannelDir::MM2S) {
shimTileToGeneratorMap[tileOp.getResult()].assignProducerDMAChannel(
channel);
Expand All @@ -68,14 +63,23 @@ LogicalResult getAvailableShimChannels(
LogicalResult generateColumnControlOverlay(AMDAIE::WorkgroupOp workgroupOp,
bool routeShimToTileCtrl,
bool routeShimCtrlToTct) {
// Get the device model.
std::optional<AMDAIEDevice> device = getConfigAMDAIEDevice(workgroupOp);
if (!device) {
return workgroupOp->emitOpError()
<< "could not find an AMDAIEDevice attribute";
}
AMDAIEDeviceModel deviceModel = AMDAIE::getDeviceModel(device.value());

IRRewriter rewriter(workgroupOp->getContext());
DenseSet<uint32_t> occupiedCols;
DenseMap<uint32_t, AMDAIE::TileOp> columnToShimTile;
workgroupOp->walk([&](AMDAIE::TileOp tileOp) {
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
occupiedCols.insert(col);
if (row == 0) columnToShimTile[col] = tileOp;
if (deviceModel.getTileType(col, row) == AMDAIETileType::SHIMNOC)
columnToShimTile[col] = tileOp;
});

// If the column is occupied, but the shim tile op is not present, then create
Expand All @@ -96,16 +100,16 @@ LogicalResult generateColumnControlOverlay(AMDAIE::WorkgroupOp workgroupOp,
// control packets.
if (routeShimToTileCtrl) {
DenseMap<Value, ChannelGenerator> shimTileToGeneratorMap;
SmallVector<TileOp> shimTileOps = llvm::to_vector<4>(llvm::map_range(
columnToShimTile, [](auto pair) { return pair.second; }));
if (failed(getAvailableShimChannels(workgroupOp, shimTileOps,
shimTileToGeneratorMap))) {
DenseSet<TileOp> shimTileOps;
for (const auto &pair : columnToShimTile) shimTileOps.insert(pair.second);
if (failed(initializeChannelsGenerators(
workgroupOp, deviceModel, shimTileOps, shimTileToGeneratorMap))) {
return failure();
}
WalkResult res = workgroupOp->walk([&](AMDAIE::TileOp tileOp) {
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
TileOp shimTileOp = columnToShimTile[col];
// Get the available channel, but do not assigning it. Allow it to be
// Get the available channel, but do not assign it. Allow it to be
// shared across multiple packet flows as needed.
std::optional<uint8_t> maybeChannel =
shimTileToGeneratorMap[shimTileOp.getResult()]
Expand Down

0 comments on commit 35ef38f

Please sign in to comment.