diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIECanonicalizeDoublyStridedOp.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIECanonicalizeDoublyStridedOp.cpp index 894c3e14f..42b4e37b2 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIECanonicalizeDoublyStridedOp.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIECanonicalizeDoublyStridedOp.cpp @@ -156,12 +156,12 @@ struct FoldDmaOpUnitDims SmallVector targetStrides = op.getTargetMixedStrides(); SmallVector newSourceOffsets, newSourceSizes, newSourceStrides, newTargetOffsets, newTargetSizes, newTargetStrides; - LogicalResult sourceRes = - foldUnitDims(op.getContext(), sourceOffsets, sourceSizes, sourceStrides, - newSourceOffsets, newSourceSizes, newSourceStrides); - LogicalResult targetRes = - foldUnitDims(op.getContext(), targetOffsets, targetSizes, targetStrides, - newTargetOffsets, newTargetSizes, newTargetStrides); + LogicalResult sourceRes = foldUnitDims( + *op.getContext(), sourceOffsets, sourceSizes, sourceStrides, + newSourceOffsets, newSourceSizes, newSourceStrides); + LogicalResult targetRes = foldUnitDims( + *op.getContext(), targetOffsets, targetSizes, targetStrides, + newTargetOffsets, newTargetSizes, newTargetStrides); if (failed(sourceRes) && failed(targetRes)) { return failure(); } diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIECombineStridedOps.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIECombineStridedOps.cpp index 04a1d1663..b9f2b305d 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIECombineStridedOps.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIECombineStridedOps.cpp @@ -16,7 +16,6 @@ #include "iree-amd-aie/Transforms/Transforms.h" #include "iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.h" #include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h" -#include "llvm/ADT/STLExtras.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "iree-amdaie-combine-strided-ops" @@ -47,8 +46,10 @@ struct CombineStridedOps std::unique_ptr sourceDmaDimConfig; std::unique_ptr targetDmaDimConfig; + SmallVector userOpsToBeErased; AMDAIE::DoublyStridedOpInterface nextStridedOp; + if (auto npuDmaOp = dyn_cast(op.getOperation())) { LLVM_DEBUG(llvm::dbgs() << "npuDmaOp: " << npuDmaOp << "\n"); // Fail if any non-wait user operations. 
@@ -105,6 +106,10 @@ struct CombineStridedOps return failure(); } + MLIRContext &context = *rewriter.getContext(); + auto dimCountCheck = std::bind(&DmaDimConfig::exceedsNbDims, + std::ref(sourceDmaDimConfig), _1); + SmallVector sourceOffsetsA = op.getSourceMixedOffsets(); SmallVector sourceSizesA = op.getSourceMixedSizes(); SmallVector sourceStridesA = op.getSourceMixedStrides(); @@ -114,11 +119,15 @@ struct CombineStridedOps nextStridedOp.getSourceMixedSizes(); SmallVector sourceStridesB = nextStridedOp.getSourceMixedStrides(); - bool areSourcesCombinable = areAccessPatternsCombinable( - sourceOffsetsA, sourceSizesA, sourceStridesA, sourceOffsetsB, - sourceSizesB, sourceStridesB, - std::bind(&DmaDimConfig::exceedsNbDims, std::ref(sourceDmaDimConfig), - _1)); + SmallVector newSourceOffsets; + SmallVector newSourceSizes; + SmallVector newSourceStrides; + if (failed(combineAccessPatterns( + context, sourceOffsetsA, sourceSizesA, sourceStridesA, + sourceOffsetsB, sourceSizesB, sourceStridesB, newSourceOffsets, + newSourceSizes, newSourceStrides, dimCountCheck))) { + return failure(); + } SmallVector targetOffsetsA = op.getTargetMixedOffsets(); SmallVector targetSizesA = op.getTargetMixedSizes(); @@ -129,53 +138,25 @@ struct CombineStridedOps nextStridedOp.getTargetMixedSizes(); SmallVector targetStridesB = nextStridedOp.getTargetMixedStrides(); - bool areTargetsCombinable = areAccessPatternsCombinable( - targetOffsetsA, targetSizesA, targetStridesA, targetOffsetsB, - targetSizesB, targetStridesB, - std::bind(&DmaDimConfig::exceedsNbDims, std::ref(targetDmaDimConfig), - _1)); - - LLVM_DEBUG(llvm::dbgs() - << "areSourcesCombinable: " << areSourcesCombinable << "\n"); - LLVM_DEBUG(llvm::dbgs() - << "areTargetsCombinable: " << areTargetsCombinable << "\n"); - - if (areSourcesCombinable && areTargetsCombinable) { - SmallVector newSourceOffsets; - SmallVector newSourceSizes; - SmallVector newSourceStrides; - if (failed(combineAccessPatterns( - rewriter, sourceOffsetsA, sourceSizesA, sourceStridesA, - sourceOffsetsB, sourceSizesB, sourceStridesB, newSourceOffsets, - newSourceSizes, newSourceStrides, - std::bind(&DmaDimConfig::exceedsNbDims, - std::ref(sourceDmaDimConfig), _1)))) { - return failure(); - } - - SmallVector newTargetOffsets; - SmallVector newTargetSizes; - SmallVector newTargetStrides; - if (failed(combineAccessPatterns( - rewriter, targetOffsetsA, targetSizesA, targetStridesA, - targetOffsetsB, targetSizesB, targetStridesB, newTargetOffsets, - newTargetSizes, newTargetStrides, - std::bind(&DmaDimConfig::exceedsNbDims, - std::ref(targetDmaDimConfig), _1)))) { - return failure(); - } + SmallVector newTargetOffsets; + SmallVector newTargetSizes; + SmallVector newTargetStrides; + if (failed(combineAccessPatterns( + context, targetOffsetsA, targetSizesA, targetStridesA, + targetOffsetsB, targetSizesB, targetStridesB, newTargetOffsets, + newTargetSizes, newTargetStrides, dimCountCheck))) { + return failure(); + } - rewriter.setInsertionPoint(op); - auto newDoublyStridedOp = nextStridedOp.createDoublyStridedOp( - rewriter, newTargetOffsets, newTargetSizes, newTargetStrides, - newSourceOffsets, newSourceSizes, newSourceStrides); - rewriter.replaceOp(nextStridedOp, newDoublyStridedOp.getOperation()); + rewriter.setInsertionPoint(op); + auto newDoublyStridedOp = nextStridedOp.createDoublyStridedOp( + rewriter, newTargetOffsets, newTargetSizes, newTargetStrides, + newSourceOffsets, newSourceSizes, newSourceStrides); + rewriter.replaceOp(nextStridedOp, newDoublyStridedOp.getOperation()); - for 
(Operation *userOp : userOpsToBeErased) rewriter.eraseOp(userOp); - rewriter.eraseOp(op); - return success(); - } - return failure(); + for (Operation *userOp : userOpsToBeErased) rewriter.eraseOp(userOp); + rewriter.eraseOp(op); + return success(); } template diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEDmaComposition.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEDmaComposition.cpp index 543bcf7da..e0bcfdcc9 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEDmaComposition.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEDmaComposition.cpp @@ -55,7 +55,6 @@ void AMDAIEDmaCompositionPass::runOnOperation() { onlyZeroStrideOnOuterDim); populateStridedOpCombinationPattern(patterns); populateCanonicalizeDoublyStridedOpPatterns(patterns, false, deviceModel); - if (failed(applyPatternsGreedily(parentOp, std::move(patterns)))) { parentOp->emitOpError("failed to compose strided operations"); return signalPassFailure(); diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp index 7a0ce4ed2..f0c3cc145 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp @@ -345,7 +345,7 @@ LogicalResult AIEDeviceBuilder::foldDimsAndReturnAsStatic( SmallVector<OpFoldResult> offsets( strides.size(), getAsIndexOpFoldResult(rewriter.getContext(), 0)); SmallVector<OpFoldResult> unitOffsets, unitSizes, unitStrides, newOffsets; - (void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides, + (void)foldUnitDims(*rewriter.getContext(), offsets, sizes, strides, unitOffsets, unitSizes, unitStrides); DmaDimConfig dmaDimConfig(deviceModel, memSpace); SmallVector<int64_t> maxSizes = dmaDimConfig.getMaxSizes(unitOffsets.size()); diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.cpp index 89657708d..b10c1358e 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.cpp @@ -7,201 +7,231 @@ #include "AMDAIEDmaUtils.h" #include "AMDAIEUtils.h" +#include "iree-amd-aie/IR/AMDAIEOps.h" #include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #define DEBUG_TYPE "iree-amdaie-dma-utils" namespace mlir::iree_compiler::AMDAIE { -static bool isEqualConstantIntOrValueArrayFromIndices( - ArrayRef<OpFoldResult> ofrsA, ArrayRef<OpFoldResult> ofrsB, - size_t indexA = 0, size_t indexB = 0) { - if ((ofrsA.size() - indexA) != (ofrsB.size() - indexB)) return false; - return isEqualConstantIntOrValueArray(ofrsA.drop_front(indexA), - ofrsB.drop_front(indexB)); +namespace { + +/// Update the strides of `X` to match the strides of `Y` if it is possible to +/// do so without changing the underlying access pattern of `X`. For example if +/// +/// X : offset = [5] sizes = [1] strides = [6] +/// Y : offset = [a] sizes = [b] strides = [3] +/// +/// Then the access pattern for X can be changed to +/// +/// X : offset = [10] sizes = [1] strides = [3] +/// +/// For this to be possible in dimension `d`: +/// 1) the size in `d` of X must be 1, and +/// 2) the updated offset in `d` of `X` (offset * strideX / strideY) must be an +/// integer.
+void equalizeStrides(MLIRContext &ctx, ArrayRef sizesX, + SmallVector &stridesX, + SmallVector &offsetsX, + ArrayRef stridesY) { + for (int i = 0; i < sizesX.size(); ++i) { + if (stridesX[i] != stridesY[i]) { + if (sizesX[i] == 1) { + auto maybeConstantOffset = getConstantIntValue(offsetsX[i]); + if (maybeConstantOffset.has_value()) { + auto offset = maybeConstantOffset.value(); + int64_t offsetDenominator = stridesY[i]; + int64_t offsetNumerator = offset * stridesX[i]; + if (offsetNumerator % offsetDenominator == 0) { + offsetsX[i] = getAsIndexOpFoldResult( + &ctx, offsetNumerator / offsetDenominator); + stridesX[i] = stridesY[i]; + } + } else { + // This is the case: + // 1) strides are different + // 2) size of X is 1 + // 3) offset of X is not-constant. + // To handle this, we need to use arith ops to perform the + // calculation. For now, we're ignoring this case. + } + } + } + } } -bool areAccessPatternsEqualFromIndices(ArrayRef offsetsA, - ArrayRef sizesA, - ArrayRef stridesA, - ArrayRef offsetsB, - ArrayRef sizesB, - ArrayRef stridesB, - size_t indexA, size_t indexB) { - return isEqualConstantIntOrValueArrayFromIndices(offsetsA, offsetsB, indexA, - indexB) && - isEqualConstantIntOrValueArrayFromIndices(sizesA, sizesB, indexA, - indexB) && - isEqualConstantIntOrValueArrayFromIndices(stridesA, stridesB, indexA, - indexB); +/// Offsets, sizes, and strides define an access pattern into an array, where +/// the i'th element accessed, for 0 <= i < prod_{d=0}^{D-1} sizes[d], is at +/// index +/// +/// sum_{d=0}^{D-1} (d[i] + offset[d[i]]) * stride[d[i]] (1) +/// +/// where d[i] is the decomposition of the index into the D dimensions: +/// +/// i = sum_{d=0}^{D-1} d[i] * size[d] (2) +/// +/// +/// Equation (1) can be rewritten with a 'global' offset 'global_offset' as +/// +/// global_offset + sum_{d=0}^{D-1} d[i] * stride[d[i]] (3) +/// +/// where the global offset is +/// +/// global_offset = sum_{d=0}^{D-1} offset[d] * stride[d]. +/// +/// This function computes the difference between the global offsets of two +/// access patterns. If it is not constant, i.e. if the difference contains +/// an MLIR value which is not obviously a constant, then nullopt is returned. +/// +/// This function is useful when determining if the access pattern A, followed +/// by the access pattern B, can be merged into a single access pattern. +/// +/// \return global_offset(X) - global_offset(Y). +std::optional getGlobalOffsetDifference( + ArrayRef offsetsX, ArrayRef stridesX, + ArrayRef offsetsY, ArrayRef stridesY) { + int64_t globalOffset{0}; + DenseMap offsetToStride; + + auto increment = [&](Value v, int64_t signedStride) { + auto iter = offsetToStride.find(v); + if (iter == offsetToStride.end()) { + offsetToStride[v] = signedStride; + } else { + iter->second += signedStride; + } + }; + + auto update = [&](ArrayRef offsets, ArrayRef strides, + int64_t sign) { + for (auto [offset, stride] : llvm::zip(offsets, strides)) { + auto off = getConstantIntValue(offset); + if (off.has_value()) { + globalOffset += sign * off.value() * stride; + } else { + increment(cast(offset), sign * stride); + } + } + }; + + update(offsetsX, stridesX, 1); + update(offsetsY, stridesY, -1); + + for (auto [offset, stride] : offsetToStride) { + // There is a non-constant offset with a stride that is not zero. + // This means that the global offset difference is not a constant. 
+ if (stride != 0) return std::nullopt; + } + + return globalOffset; } -bool areAccessPatternsCombinable(const SmallVector &offsetsA, - const SmallVector &sizesA, - const SmallVector &stridesA, - const SmallVector &offsetsB, - const SmallVector &sizesB, - const SmallVector &stridesB, - function_ref exceedsNbDims) { - assert(offsetsA.size() == sizesA.size() && - "expected same number of source offsets and sizes"); - assert(offsetsA.size() == stridesA.size() && - "expected same number of source offsets and strides"); - assert(offsetsB.size() == sizesB.size() && - "expected same number of source offsets and sizes"); - assert(offsetsB.size() == stridesB.size() && - "expected same number of source offsets and strides"); - if (std::abs((ssize_t)offsetsA.size() - (ssize_t)offsetsB.size()) > 1) - return false; - // Empty access patterns are always combinable as they effectively mean: - // 'don't perform any or change the addressing'. - if (offsetsA.empty() && offsetsB.empty()) return true; - // In case both access patterns have the same number of dimension, a new - // dimension will need to be added, so fail if there aren't enough - // dimensions. - if (offsetsA.size() == offsetsB.size() && - exceedsNbDims(offsetsA.size() + 1)) { - LLVM_DEBUG(llvm::dbgs() << "Exceeded maximum number of dimensions\n"); - return false; +/// Inputs come from 2 access patterns, where the access pattern for Y +/// has one more dimension than the access pattern for X. This function +/// inserts a singleton dimension into the access pattern for X, at the +/// first dimension from the back where it differs from the access pattern +/// for Y. Thus the ranks of the access patterns for X and Y will be made the +/// same. +void insertUnitDimension(MLIRContext &ctx, SmallVector &offsetsX, + SmallVector &sizesX, + SmallVector &stridesX, + ArrayRef stridesY) { + assert(stridesY.size() == stridesX.size() + 1); + uint32_t index = stridesX.size(); + while (index > 0) { + if (stridesY[index] != stridesX[index - 1]) break; + index--; } + OpFoldResult zeroFoldResult = getAsIndexOpFoldResult(&ctx, 0); + sizesX.insert(sizesX.begin() + index, 1); + stridesX.insert(stridesX.begin() + index, stridesY[index]); + offsetsX.insert(offsetsX.begin() + index, zeroFoldResult); +} - // Equality of the last N elements of the access patterns of A and B with N = - // min(sizeA, sizeB) results in some simple cases in which the access - // patterns are combinable. Note that abs(sizeB - sizeA) should <= 1 and this - // is checked for earlier, so just asserted here. - assert(std::abs((ssize_t)offsetsB.size() - (ssize_t)offsetsA.size()) <= 1 && - "The distance between the indices should be smaller or equal to one."); - size_t indexA = offsetsA.size() > offsetsB.size() ? 1 : 0; - size_t indexB = offsetsB.size() > offsetsA.size() ? 1 : 0; - if (areAccessPatternsEqualFromIndices(offsetsA, sizesA, stridesA, offsetsB, - sizesB, stridesB, indexA, indexB)) { - if (offsetsA.size() == offsetsB.size()) { - return true; - } else if (offsetsA.size() > offsetsB.size()) { - // The access pattern A has N repetitions of access pattern B, so they can - // be combined together into N+1 repetitions. - return isConstantIntValue(stridesA[0], 0); - } else { - // offsetsB.size() > offsetsA.size() - // The access pattern B has N repetitions of access pattern A, so they can - // be combined together into N+1 repetitions. - if (isConstantIntValue(stridesB[0], 0)) return true; - // The access pattern of B is the same as the access pattern of A, but at - // a different offset. 
They can be combined by reducing the offset of B to - zero. - if (isConstantIntValue(offsetsB[0], 1)) return true; - return false; - } +/// If pattern `A` followed by `B` can be merged, merge `B` into `A` and return +/// true. Otherwise return false, without a guarantee that operands aren't +/// mutated. +bool mergeInFirst(MLIRContext &ctx, SmallVector<OpFoldResult> &offsetsA, + SmallVector<int64_t> &sizesA, SmallVector<int64_t> &stridesA, + SmallVector<OpFoldResult> offsetsB, + SmallVector<int64_t> sizesB, SmallVector<int64_t> stridesB) { + if (offsetsA.size() == 0 && offsetsB.size() == 0) return true; + + // Local canonicalization to improve the chances of merging. + if (sizesA.size() + 1 == sizesB.size()) { + insertUnitDimension(ctx, offsetsA, sizesA, stridesA, stridesB); + } else if (sizesB.size() + 1 == sizesA.size()) { + insertUnitDimension(ctx, offsetsB, sizesB, stridesB, stridesA); + } else if (sizesA.size() != sizesB.size()) { + // If the ranks of the accesses differ by more than 1, it is impossible + // to merge them (unless the higher ranked access pattern has 2+ leading + // dimensions of size 1, which we're ignoring for now). + return false; + } + equalizeStrides(ctx, sizesA, stridesA, offsetsA, stridesB); + equalizeStrides(ctx, sizesB, stridesB, offsetsB, stridesA); - for (auto &&[strideA, strideB] : - llvm::zip(llvm::reverse(stridesA), llvm::reverse(stridesB))) { - std::optional<int64_t> maybeStrideA = getConstantIntValue(strideA); - std::optional<int64_t> maybeStrideB = getConstantIntValue(strideB); - // Handle static and constant value with same int value. - if (maybeStrideA && maybeStrideB && - maybeStrideA.value() == maybeStrideB.value()) { - continue; - } - if (strideA != strideB) return false; + // Check that strides and sizes are compatible for merging. + if (stridesA != stridesB) return false; + if (ArrayRef<int64_t>(sizesA).drop_front() != + ArrayRef<int64_t>(sizesB).drop_front()) { + return false; } - // Don't check the outermost dimension of size at this point. - SmallVector<OpFoldResult> innerSizesA; - SmallVector<OpFoldResult> innerSizesB; - std::copy(sizesA.begin() + 1, sizesA.end(), std::back_inserter(innerSizesA)); - std::copy(sizesB.begin() + 1, sizesB.end(), std::back_inserter(innerSizesB)); - for (auto &&[sizeA, sizeB] : - llvm::zip(llvm::reverse(innerSizesA), llvm::reverse(innerSizesB))) { - std::optional<int64_t> maybeSizeA = getConstantIntValue(sizeA); - std::optional<int64_t> maybeSizeB = getConstantIntValue(sizeB); - // Handle static and constant value with same int value. - if (maybeSizeA && maybeSizeB && maybeSizeA.value() == maybeSizeB.value()) { - continue; - } - if (sizeA != sizeB) return false; + std::optional<int64_t> maybeOffsetDifference = + getGlobalOffsetDifference(offsetsB, stridesB, offsetsA, stridesA); + + // The case where the global offset difference is not constant is difficult to + // handle, unless we can prove that it is non-negative. Leaving this edge case + // for future work. + if (!maybeOffsetDifference.has_value()) return false; + + auto offsetDifference = maybeOffsetDifference.value(); + + // The special case where the global offset difference exactly matches the + // span of A's outermost dimension (sizesA[0] * stridesA[0]); in this case no + // new dimension is needed when merging the patterns. + if (offsetDifference == sizesA[0] * stridesA[0]) { + sizesA[0] += sizesB[0]; + return true; } - // Edge case for sizesA[0] != sizesB[0].
- if (offsetsB.size() == offsetsA.size() && sizesA[0] != sizesB[0]) { - std::optional constOffsetA = getConstantIntValue(offsetsA[0]); - std::optional constSizeA = getConstantIntValue(sizesA[0]); - std::optional constOffsetB = getConstantIntValue(offsetsB[0]); - std::optional constSizeB = getConstantIntValue(sizesB[0]); - if (constOffsetA && constOffsetB && constSizeA && constSizeB) { - int64_t offsetDiff = constOffsetB.value() - constOffsetA.value(); - if (constSizeA.value() != offsetDiff) return false; - } else { - return false; - } + // This is the case where the 2 patterns don't connect seamlessly, and we need + // to introduce a new dimension to contain the new offset. + else if (sizesA[0] == sizesB[0] && offsetDifference >= 0) { + sizesA.insert(sizesA.begin(), 2); + stridesA.insert(stridesA.begin(), offsetDifference); + OpFoldResult zeroFoldResult = getAsIndexOpFoldResult(&ctx, 0); + offsetsA.insert(offsetsA.begin(), zeroFoldResult); + return true; } - bool foundDiff{false}; - for (auto iter : llvm::enumerate( - llvm::zip(llvm::reverse(offsetsA), llvm::reverse(offsetsB)))) { - const OpFoldResult &offsetA = std::get<0>(iter.value()); - const OpFoldResult &offsetB = std::get<1>(iter.value()); - if (offsetA == offsetB) continue; - std::optional maybeOffsetA = getConstantIntValue(offsetA); - std::optional maybeOffsetB = getConstantIntValue(offsetB); - if (maybeOffsetA && maybeOffsetB && - maybeOffsetA.value() == maybeOffsetB.value()) { - continue; - } - // Retrieve the corresponding stride for this dimension. - std::optional maybeStride = - getConstantIntValue(stridesA[stridesA.size() - 1 - iter.index()]); - if (maybeOffsetA && maybeOffsetB && maybeStride) { - int64_t diff = - (maybeOffsetB.value() - maybeOffsetA.value()) * maybeStride.value(); - // Handle the three different size cases. Return early in case of an - // incompatibility. - if (offsetsA.size() > offsetsB.size()) { - std::optional constOffset = getConstantIntValue(offsetsA[0]); - std::optional constStride = getConstantIntValue(stridesA[0]); - std::optional constSize = getConstantIntValue(sizesA[0]); - if (constOffset && constStride && constSize && - constOffset.value() == 0 && - (constStride.value() * constSize.value()) == diff) { - if (foundDiff) return false; - foundDiff = true; - } else { - return false; - } - } else if (offsetsB.size() > offsetsA.size()) { - std::optional constOffset = getConstantIntValue(offsetsB[0]); - std::optional constStride = getConstantIntValue(stridesB[0]); - if (constOffset && constStride && constOffset.value() == 0 && - constStride.value() == diff) { - if (foundDiff) return false; - foundDiff = true; - } else { - return false; - } - } else { - if (foundDiff) return false; - foundDiff = true; - } - } else { - return false; - } + return false; +} + +/// If all elements are int64_ts, then return a vector of int64_ts. 
+std::optional> getIntVals(ArrayRef ofrs) { + if (std::any_of(ofrs.begin(), ofrs.end(), [](OpFoldResult ofr) { + return !getConstantIntValue(ofr).has_value(); + })) { + return std::nullopt; } - return foundDiff; + return llvm::to_vector<4>(llvm::map_range( + ofrs, [](OpFoldResult ofr) { return getConstantIntValue(ofr).value(); })); } -LogicalResult combineAccessPatterns(RewriterBase &rewriter, - const SmallVector &offsetsA, - const SmallVector &sizesA, - const SmallVector &stridesA, - const SmallVector &offsetsB, - const SmallVector &sizesB, - const SmallVector &stridesB, - SmallVector &newOffsets, - SmallVector &newSizes, - SmallVector &newStrides, - function_ref exceedsNbDims) { +} // namespace + +LogicalResult combineAccessPatterns( + MLIRContext &ctx, ArrayRef offsetsA, + ArrayRef sizesA, ArrayRef stridesA, + ArrayRef offsetsB, ArrayRef sizesB, + ArrayRef stridesB, SmallVector &newOffsets, + SmallVector &newSizes, SmallVector &newStrides, + function_ref exceedsNbDims) { assert(offsetsA.size() == sizesA.size() && "expected same number of source offsets and sizes"); assert(offsetsA.size() == stridesA.size() && @@ -210,89 +240,44 @@ LogicalResult combineAccessPatterns(RewriterBase &rewriter, "expected same number of source offsets and sizes"); assert(offsetsB.size() == stridesB.size() && "expected same number of source offsets and strides"); - if (!areAccessPatternsCombinable(offsetsA, sizesA, stridesA, offsetsB, sizesB, - stridesB, exceedsNbDims)) { - return failure(); - } - if (offsetsA.empty() && offsetsB.empty()) return success(); - if (offsetsB.size() > offsetsA.size()) { - newOffsets = offsetsB; - newSizes = sizesB; - newStrides = stridesB; - // If the offset on the first dimension of B is larger than zero, we can - // just decrease that one by one to accomplish the access pattern merge. - // Otherwise, we check for and update the other differing offsets. - std::optional offset = getConstantIntValue(newOffsets[0]); - if (offset && offset.value() > 0) { - newOffsets[0] = rewriter.getI64IntegerAttr(offset.value() - 1); - } else { - for (int i = 1; i <= offsetsA.size(); i++) { - if (offsetsA[offsetsA.size() - i] != offsetsB[offsetsB.size() - i]) { - newOffsets[newOffsets.size() - i] = offsetsA[offsetsA.size() - i]; - break; - } - } - } - std::optional size = getConstantIntValue(newSizes[0]); - if (!size) return failure(); - newSizes[0] = rewriter.getI64IntegerAttr(size.value() + 1); - } else if (offsetsA.size() > offsetsB.size()) { - newOffsets = offsetsA; - newSizes = sizesA; - newStrides = stridesA; - std::optional size = getConstantIntValue(newSizes[0]); - if (!size) return failure(); - newSizes[0] = rewriter.getI64IntegerAttr(size.value() + 1); - } else { - // Edge case for sizesA[0] != sizesB[0]. - if (sizesA[0] != sizesB[0]) { - newOffsets = offsetsA; - newSizes = sizesA; - newStrides = stridesA; - std::optional sizeA = getConstantIntValue(sizesA[0]); - std::optional sizeB = getConstantIntValue(sizesB[0]); - if (!sizeA || !sizeB) return failure(); - newSizes[0] = rewriter.getI64IntegerAttr(sizeA.value() + sizeB.value()); - } else { - // All dims of sizes are the same, so add a new dimension with - // 'offset == 0', 'size == 2' and 'stride == offsetDiff'. 
- newOffsets.push_back(rewriter.getI64IntegerAttr(0)); - int64_t offsetDiff{0}; - int64_t strideMultiplier{0}; - for (auto iter : llvm::enumerate(llvm::zip(offsetsA, offsetsB))) { - const OpFoldResult &offsetA = std::get<0>(iter.value()); - const OpFoldResult &offsetB = std::get<1>(iter.value()); - newOffsets.push_back(offsetA); - if (offsetA != offsetB) { - std::optional constOffsetA = getConstantIntValue(offsetA); - std::optional constOffsetB = getConstantIntValue(offsetB); - if (!constOffsetA || !constOffsetB) { - return emitError(rewriter.getUnknownLoc()) - << "differing offsets should be constants"; - } - offsetDiff = constOffsetB.value() - constOffsetA.value(); - std::optional maybeStride = - getConstantIntValue(stridesA[iter.index()]); - if (!maybeStride) { - return emitError(rewriter.getUnknownLoc()) - << "no constant stride found at the same index where the " - "offset " - "difference occurs"; - } - strideMultiplier = maybeStride.value(); - } - } - newSizes.push_back(rewriter.getI64IntegerAttr(2)); - newSizes.append(sizesA.begin(), sizesA.end()); - newStrides.push_back( - rewriter.getI64IntegerAttr(offsetDiff * strideMultiplier)); - newStrides.append(stridesA.begin(), stridesA.end()); - } + + auto maybeSizesA = getIntVals(sizesA); + if (!maybeSizesA.has_value()) return failure(); + SmallVector iSizesA = std::move(maybeSizesA.value()); + + auto maybeStridesA = getIntVals(stridesA); + if (!maybeStridesA.has_value()) return failure(); + SmallVector iStridesA = std::move(maybeStridesA.value()); + + auto maybeSizesB = getIntVals(sizesB); + if (!maybeSizesB.has_value()) return failure(); + SmallVector iSizesB = std::move(maybeSizesB.value()); + + auto maybeStridesB = getIntVals(stridesB); + if (!maybeStridesB.has_value()) return failure(); + SmallVector iStridesB = std::move(maybeStridesB.value()); + + SmallVector mergedOffsets(offsetsA.begin(), offsetsA.end()); + + bool combined = + mergeInFirst(ctx, mergedOffsets, iSizesA, iStridesA, + SmallVector(offsetsB.begin(), offsetsB.end()), + iSizesB, iStridesB); + + if (!combined) return failure(); + + SmallVector mergedSizes; + SmallVector mergedStrides; + for (int i = 0; i < mergedOffsets.size(); ++i) { + mergedSizes.push_back(getAsIndexOpFoldResult(&ctx, iSizesA[i])); + mergedStrides.push_back(getAsIndexOpFoldResult(&ctx, iStridesA[i])); } - assert(newOffsets.size() == newSizes.size() && - "expected same number of new offsets and sizes"); - assert(newOffsets.size() == newStrides.size() && - "expected same number of new offsets and strides"); + + (void)foldUnitDims(ctx, mergedOffsets, mergedSizes, mergedStrides, newOffsets, + newSizes, newStrides); + + if (exceedsNbDims(newOffsets.size())) return failure(); + return success(); } @@ -430,53 +415,76 @@ LogicalResult foldSingleDim(SmallVector &offsets, return success(); } -/// Fold unit dimensions within a strided access pattern. There are two cases -/// being handled here: -/// 1. If a dimension has `size == 1` and `offset == 0`, the dimension can be -/// folded entirely. -/// 2. If a dimension has `size == 1` and `offset != 0`, it can be folded into -/// another dimension with the same stride if that exists. 
-LogicalResult foldUnitDims(MLIRContext *ctx, - const SmallVector &offsets, - const SmallVector &sizes, - const SmallVector &strides, +bool mergeConstantOffsetIn(MLIRContext &ctx, int64_t offsetToMerge, + SmallVector &offsets, + ArrayRef strides) { + if (offsetToMerge == 0) return true; + for (uint32_t i = 0; i < offsets.size(); ++i) { + auto cOffset = getConstantIntValue(offsets[i]); + auto cStride = getConstantIntValue(strides[i]); + if (cOffset.has_value() && cStride.has_value()) { + auto offset = cOffset.value(); + auto stride = cStride.value(); + if (offsetToMerge % stride == 0) { + offset += offsetToMerge / stride; + offsets[i] = getAsIndexOpFoldResult(&ctx, offset); + return true; + } + } + } + return false; +} + +LogicalResult foldUnitDims(MLIRContext &ctx, ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides, SmallVector &newOffsets, SmallVector &newSizes, SmallVector &newStrides) { - bool foldableUnitDimsFound = false; - DenseMap> strideToIndexAndOffset; - for (int i = 0; i < offsets.size(); i++) { - // If a dimension has `size == 1` and `offset == 0`, the dimension can be - /// folded entirely. - if (isConstantIntValue(offsets[i], 0) && isConstantIntValue(sizes[i], 1)) { - foldableUnitDimsFound = true; - continue; + // All size-1 dimensions with constant offset will be merged into one (or + // maybe even zero) dimensions. + int64_t constantOffset{0}; + for (int i = 0; i < offsets.size(); ++i) { + auto cOffset = getConstantIntValue(offsets[i]); + auto cStride = getConstantIntValue(strides[i]); + if (cOffset.has_value() && cStride.has_value() && + isConstantIntValue(sizes[i], 1)) { + constantOffset += cOffset.value() * cStride.value(); + } else { + newOffsets.push_back(offsets[i]); + newSizes.push_back(sizes[i]); + newStrides.push_back(strides[i]); } - std::optional maybeOffset = getConstantIntValue(offsets[i]); - std::optional maybeStride = getConstantIntValue(strides[i]); - if (maybeOffset && maybeStride) { - int64_t offset = maybeOffset.value(); - int64_t stride = maybeStride.value(); - if (isConstantIntValue(sizes[i], 1) && - strideToIndexAndOffset.contains(stride)) { - foldableUnitDimsFound = true; - strideToIndexAndOffset[stride].second += offset; - // Continue to not add to newOffsets, newSizes, newStrides - continue; - } else { - strideToIndexAndOffset[stride] = {newOffsets.size(), offset}; + } + + bool mergedIntoExistingDim = + mergeConstantOffsetIn(ctx, constantOffset, newOffsets, newStrides); + + if (!mergedIntoExistingDim) { + OpFoldResult one = getAsIndexOpFoldResult(&ctx, 1); + OpFoldResult off = getAsIndexOpFoldResult(&ctx, constantOffset); + newOffsets.insert(newOffsets.begin(), one); + newSizes.insert(newSizes.begin(), one); + newStrides.insert(newStrides.begin(), off); + } + + // Ensure all size-1 dimensions are at the front. 
+ uint32_t insertionIndex = 0; + for (uint32_t i = 0; i < newOffsets.size(); ++i) { + if (i > insertionIndex) { + if (isConstantIntValue(newSizes[i], 1)) { + std::swap(newOffsets[i], newOffsets[insertionIndex]); + std::swap(newSizes[i], newSizes[insertionIndex]); + std::swap(newStrides[i], newStrides[insertionIndex]); + insertionIndex++; } } - newOffsets.push_back(offsets[i]); - newStrides.push_back(strides[i]); - newSizes.push_back(sizes[i]); - } - // Update offsets - for (auto &&[stride, indexAndOffset] : strideToIndexAndOffset) { - newOffsets[indexAndOffset.first] = - getAsIndexOpFoldResult(ctx, indexAndOffset.second); } - return success(foldableUnitDimsFound); + + // Technically we should return success if there was just a permutation of + // dimensions, but I'm nervous that we could get into an infinite loop -- we + // really need an energy function for canonicalization. + return newOffsets.size() < offsets.size() ? success() : failure(); } LogicalResult moveNpuDmaSyncUsersAfterAncestorInSameBlock( diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.h b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.h index 0c909c5f3..35bfe9dd8 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.h +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.h @@ -7,12 +7,7 @@ #ifndef IREE_AMD_AIE_TRANSFORMS_AMDAIEDMAUTILS_H_ #define IREE_AMD_AIE_TRANSFORMS_AMDAIEDMAUTILS_H_ -#include "iree-amd-aie/IR/AMDAIEAttrs.h" -#include "iree-amd-aie/IR/AMDAIEDmaOpInterface.h" -#include "iree-amd-aie/IR/AMDAIEOps.h" #include "iree-amd-aie/aie_runtime/iree_aie_runtime.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" @@ -84,33 +79,18 @@ bool areAccessPatternsEqualFromIndices(ArrayRef offsetsA, ArrayRef stridesB, size_t indexA = 0, size_t indexB = 0); -/// Check whether the two access patterns of strided operations can be combined -/// into one. Takes a `maxNbDims` argument to check whether a combined access -/// pattern would not exceed the maximum number of dimensions. -bool areAccessPatternsCombinable(const SmallVector &offsetsA, - const SmallVector &sizesA, - const SmallVector &stridesA, - const SmallVector &offsetsB, - const SmallVector &sizesB, - const SmallVector &stridesB, - function_ref exceedsNbDims); - /// Combine two access patterns into a single one. Assumes that access pattern A /// belongs to a strided op which is ordered before the strided op B. Takes a /// `maxNbDims` argument to ensure that a combined access pattern would not /// exceed the maximum number of dimensions. Returns `success` if the access /// patterns were combined successfully. 
-LogicalResult combineAccessPatterns(RewriterBase &rewriter, - const SmallVector &offsetsA, - const SmallVector &sizesA, - const SmallVector &stridesA, - const SmallVector &offsetsB, - const SmallVector &sizesB, - const SmallVector &stridesB, - SmallVector &newOffsets, - SmallVector &newSizes, - SmallVector &newStrides, - function_ref exceedsNbDims); +LogicalResult combineAccessPatterns( + MLIRContext &rewriter, ArrayRef offsetsA, + ArrayRef sizesA, ArrayRef stridesA, + ArrayRef offsetsB, ArrayRef sizesB, + ArrayRef stridesB, SmallVector &newOffsets, + SmallVector &newSizes, SmallVector &newStrides, + function_ref exceedsNbDims); /// Fold subsequent dimensions within a strided access pattern that describe a /// single linear access. Returns `success` if folding took place. @@ -186,10 +166,9 @@ LogicalResult foldSingleDim(SmallVector &offsets, /// Note that the dimensions are merged into the outermost one. Heuristically, /// this works out best with other strided access pattern transformations, but /// could be made an option in the future. -LogicalResult foldUnitDims(MLIRContext *ctx, - const SmallVector &offsets, - const SmallVector &strides, - const SmallVector &sizes, +LogicalResult foldUnitDims(MLIRContext &ctx, ArrayRef offsets, + ArrayRef strides, + ArrayRef sizes, SmallVector &newOffsets, SmallVector &newStrides, SmallVector &newSizes); diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/AMDAIEDmaUtilsTest.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/AMDAIEDmaUtilsTest.cpp index ccb08bae5..b1fe6a0ce 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/AMDAIEDmaUtilsTest.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/AMDAIEDmaUtilsTest.cpp @@ -6,9 +6,9 @@ #include "gtest/gtest.h" #include "iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" namespace { @@ -28,7 +28,7 @@ class AccessPatternCombinationTest : public ::testing::Test { SmallVector toOpFoldResult(const SmallVector &values) { return llvm::map_to_vector(values, [&](int64_t v) -> OpFoldResult { - return rewriter.getI64IntegerAttr(v); + return getAsIndexOpFoldResult(&context, v); }); } @@ -44,9 +44,13 @@ class AccessPatternCombinationTest : public ::testing::Test { SmallVector offsetsValuesB = toOpFoldResult(offsetsB); SmallVector sizesValuesB = toOpFoldResult(sizesB); SmallVector stridesValuesB = toOpFoldResult(stridesB); - return areAccessPatternsCombinable( - offsetsValuesA, sizesValuesA, stridesValuesA, offsetsValuesB, - sizesValuesB, stridesValuesB, exceedsNbDims); + SmallVector newOffsets; + SmallVector newSizes; + SmallVector newStrides; + return succeeded(combineAccessPatterns( + context, offsetsValuesA, sizesValuesA, stridesValuesA, offsetsValuesB, + sizesValuesB, stridesValuesB, newOffsets, newSizes, newStrides, + exceedsNbDims)); } bool checkAreAccessPatternsCombinable(const SmallVector &offsetsA, @@ -81,35 +85,33 @@ class AccessPatternCombinationTest : public ::testing::Test { toOpFoldResult(expectedSizes); SmallVector expectedStridesValues = toOpFoldResult(expectedStrides); + SmallVector newOffsets; SmallVector newSizes; SmallVector newStrides; if (shouldSucceed) { EXPECT_TRUE(succeeded(combineAccessPatterns( - rewriter, offsetsValuesA, sizesValuesA, stridesValuesA, - offsetsValuesB, sizesValuesB, stridesValuesB, newOffsets, newSizes, - newStrides, exceedsNbDims))); + 
context, offsetsValuesA, sizesValuesA, stridesValuesA, offsetsValuesB, + sizesValuesB, stridesValuesB, newOffsets, newSizes, newStrides, + exceedsNbDims))); EXPECT_EQ(newOffsets, expectedOffsetsValues); EXPECT_EQ(newSizes, expectedSizesValues); EXPECT_EQ(newStrides, expectedStridesValues); } else { EXPECT_TRUE(failed(combineAccessPatterns( - rewriter, offsetsValuesA, sizesValuesA, stridesValuesA, - offsetsValuesB, sizesValuesB, stridesValuesB, newOffsets, newSizes, - newStrides, exceedsNbDims))); + context, offsetsValuesA, sizesValuesA, stridesValuesA, offsetsValuesB, + sizesValuesB, stridesValuesB, newOffsets, newSizes, newStrides, + exceedsNbDims))); } } - void checkCombineAccessPatterns(const SmallVector offsetsA, - const SmallVector sizesA, - const SmallVector stridesA, - const SmallVector offsetsB, - const SmallVector sizesB, - const SmallVector stridesB, - const SmallVector expectedOffsets, - const SmallVector expectedSizes, - const SmallVector expectedStrides, - size_t maxNbDims, bool shouldSucceed = true) { + void checkCombineAccessPatterns( + SmallVector offsetsA, SmallVector sizesA, + SmallVector stridesA, SmallVector offsetsB, + SmallVector sizesB, SmallVector stridesB, + SmallVector expectedOffsets, SmallVector expectedSizes, + SmallVector expectedStrides, size_t maxNbDims, + bool shouldSucceed = true) { checkCombineAccessPatterns( offsetsA, sizesA, stridesA, offsetsB, sizesB, stridesB, expectedOffsets, expectedSizes, expectedStrides, @@ -120,14 +122,15 @@ class AccessPatternCombinationTest : public ::testing::Test { IRRewriter rewriter; Location loc; }; - TEST_F(AccessPatternCombinationTest, CombinableAccessPatterns) { EXPECT_TRUE(checkAreAccessPatternsCombinable({}, {}, {}, {}, {}, {}, 1)); // size(A) == size(B) EXPECT_TRUE( checkAreAccessPatternsCombinable({0}, {16}, {1}, {32}, {16}, {1}, 2)); + EXPECT_TRUE(checkAreAccessPatternsCombinable({0, 0}, {16, 32}, {64, 1}, {0, 32}, {16, 32}, {64, 1}, 4)); + EXPECT_TRUE(checkAreAccessPatternsCombinable({1, 0}, {16, 32}, {64, 1}, {1, 32}, {16, 32}, {64, 1}, 4)); EXPECT_TRUE(checkAreAccessPatternsCombinable({0, 0, 0}, {16, 16, 32}, @@ -136,6 +139,7 @@ TEST_F(AccessPatternCombinationTest, CombinableAccessPatterns) { EXPECT_TRUE(checkAreAccessPatternsCombinable({0, 2, 0}, {16, 16, 32}, {32, 64, 1}, {0, 2, 32}, {16, 16, 32}, {32, 64, 1}, 4)); + EXPECT_TRUE(checkAreAccessPatternsCombinable({32, 0}, {64, 64}, {128, 1}, {96, 0}, {32, 64}, {128, 1}, 4)); // Same access patterns @@ -183,6 +187,7 @@ TEST_F(AccessPatternCombinationTest, NonCombinableAccessPatterns) { // Too few dimensions EXPECT_FALSE( checkAreAccessPatternsCombinable({0}, {16}, {1}, {32}, {16}, {1}, 1)); + EXPECT_FALSE( checkAreAccessPatternsCombinable({0}, {32}, {1}, {0}, {32}, {1}, 1)); EXPECT_FALSE(checkAreAccessPatternsCombinable({0, 0}, {16, 32}, {64, 1}, @@ -244,12 +249,12 @@ TEST_F(AccessPatternCombinationTest, NoDims) { } TEST_F(AccessPatternCombinationTest, CombineAccessPatterns) { - checkCombineAccessPatterns({}, {}, {}, {}, {}, {}, {}, {}, {}, 1); // size(A) == size(B) - checkCombineAccessPatterns({0}, {16}, {1}, {32}, {16}, {1}, {0, 0}, {2, 16}, - {32, 1}, 2); checkCombineAccessPatterns({0, 0}, {8, 16}, {8, 1}, {0, 32}, {8, 16}, {8, 1}, {0, 0, 0}, {2, 8, 16}, {32, 8, 1}, 3); + checkCombineAccessPatterns({}, {}, {}, {}, {}, {}, {}, {}, {}, 1); + checkCombineAccessPatterns({0}, {16}, {1}, {32}, {16}, {1}, {0, 0}, {2, 16}, + {32, 1}, 2); checkCombineAccessPatterns({0, 32}, {8, 16}, {8, 1}, {0, 64}, {8, 16}, {8, 1}, {0, 0, 32}, {2, 8, 16}, {32, 8, 1}, 3); 
checkCombineAccessPatterns({1, 32}, {8, 16}, {8, 1}, {1, 64}, {8, 16}, {8, 1}, @@ -273,6 +278,7 @@ TEST_F(AccessPatternCombinationTest, CombineAccessPatterns) { checkCombineAccessPatterns({32, 0}, {64, 64}, {128, 1}, {96, 0}, {32, 64}, {128, 1}, {32, 0}, {96, 64}, {128, 1}, 4); // size(A) == size(B) Same access pattern + checkCombineAccessPatterns({0}, {32}, {1}, {0}, {32}, {1}, {0, 0}, {2, 32}, {0, 1}, 2); checkCombineAccessPatterns({0, 0}, {16, 32}, {16, 1}, {0, 0}, {16, 32}, @@ -299,8 +305,8 @@ TEST_F(AccessPatternCombinationTest, CombineAccessPatterns) { {8, 32}, {0, 1}, 3); checkCombineAccessPatterns({1, 0}, {7, 32}, {0, 1}, {0}, {32}, {1}, {1, 0}, {8, 32}, {0, 1}, 3); - checkCombineAccessPatterns({1, 0}, {0, 32}, {0, 1}, {0}, {32}, {1}, {1, 0}, - {1, 32}, {0, 1}, 3); + checkCombineAccessPatterns({1, 0}, {0, 32}, {0, 1}, {0}, {32}, {1}, {0}, {32}, + {1}, 3); // size(B) > size(A) checkCombineAccessPatterns({0}, {32}, {1}, {0, 64}, {2, 32}, {64, 1}, {0, 0}, {3, 32}, {64, 1}, 3); @@ -448,14 +454,14 @@ class FoldTest : public ::testing::Test { SmallVector newSizes; SmallVector newStrides; if (shouldSucceed) { - EXPECT_TRUE(succeeded(foldUnitDims(&context, offsetsValues, sizesValues, + EXPECT_TRUE(succeeded(foldUnitDims(context, offsetsValues, sizesValues, stridesValues, newOffsets, newSizes, newStrides))); EXPECT_EQ(newOffsets, expectedOffsetsValues); EXPECT_EQ(newSizes, expectedSizesValues); EXPECT_EQ(newStrides, expectedStridesValues); } else { - EXPECT_TRUE(failed(foldUnitDims(&context, offsetsValues, sizesValues, + EXPECT_TRUE(failed(foldUnitDims(context, offsetsValues, sizesValues, stridesValues, newOffsets, newSizes, newStrides))); } @@ -541,24 +547,24 @@ TEST_F(FoldTest, UnitDimsFullFold) { } TEST_F(FoldTest, UnitDimsMerge) { - checkFoldUnitDims({1, 1}, {1, 1}, {32, 32}, {2}, {1}, {32}, true); - checkFoldUnitDims({1, 2}, {1, 1}, {32, 32}, {3}, {1}, {32}, true); - checkFoldUnitDims({2, 1}, {1, 1}, {32, 32}, {3}, {1}, {32}, true); - checkFoldUnitDims({1, 0, 1, 0}, {1, 32, 1, 8}, {1024, 32, 1024, 1}, {2, 0, 0}, - {1, 32, 8}, {1024, 32, 1}, true); - checkFoldUnitDims({1, 0, 2, 0}, {1, 32, 1, 8}, {1024, 32, 1024, 1}, {3, 0, 0}, - {1, 32, 8}, {1024, 32, 1}, true); - checkFoldUnitDims({2, 0, 1, 0}, {1, 32, 1, 8}, {1024, 32, 1024, 1}, {3, 0, 0}, - {1, 32, 8}, {1024, 32, 1}, true); + checkFoldUnitDims({1, 1}, {1, 1}, {32, 32}, {1}, {1}, {64}, true); + checkFoldUnitDims({1, 2}, {1, 1}, {32, 32}, {1}, {1}, {96}, true); + checkFoldUnitDims({2, 1}, {1, 1}, {32, 32}, {1}, {1}, {96}, true); + checkFoldUnitDims({1, 0, 1, 0}, {1, 32, 1, 8}, {1024, 32, 1024, 1}, {64, 0}, + {32, 8}, {32, 1}, true); + checkFoldUnitDims({1, 0, 2, 0}, {1, 32, 1, 8}, {1024, 32, 1024, 1}, {96, 0}, + {32, 8}, {32, 1}, true); + checkFoldUnitDims({2, 0, 1, 0}, {1, 32, 1, 8}, {1024, 32, 1024, 1}, {96, 0}, + {32, 8}, {32, 1}, true); } TEST_F(FoldTest, UnitDimsFoldAndMerge) { - checkFoldUnitDims({1, 0, 1}, {1, 1, 1}, {32, 1024, 32}, {2}, {1}, {32}, true); - checkFoldUnitDims({1, 0, 1}, {1, 1, 1}, {32, 32, 32}, {2}, {1}, {32}, true); - checkFoldUnitDims({1, 0, 2, 0}, {1, 1, 1, 1}, {32, 32, 32, 32}, {3}, {1}, - {32}, true); - checkFoldUnitDims({1, 0, 1, 0}, {1, 1, 1, 8}, {1024, 32, 1024, 1}, {2, 0}, - {1, 8}, {1024, 1}, true); + checkFoldUnitDims({1, 0, 1}, {1, 1, 1}, {32, 1024, 32}, {1}, {1}, {64}, true); + checkFoldUnitDims({1, 0, 1}, {1, 1, 1}, {32, 32, 32}, {1}, {1}, {64}, true); + checkFoldUnitDims({1, 0, 2, 0}, {1, 1, 1, 1}, {32, 32, 32, 32}, {1}, {1}, + {96}, true); + checkFoldUnitDims({1, 0, 1, 0}, {1, 1, 1, 8}, {1024, 
32, 1024, 1}, {2048}, + {8}, {1}, true); } TEST_F(FoldTest, FoldRepetitionCount) { diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/canonicalize_doubly_strided_op.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/canonicalize_doubly_strided_op.mlir index 9d39c03d0..47922a760 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/canonicalize_doubly_strided_op.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/canonicalize_doubly_strided_op.mlir @@ -77,8 +77,8 @@ func.func @circular_dma_cpy_nd_unit_between_linear(%arg0: !amdaie.logicalobjectf // ----- // CHECK-LABEL: func.func @circular_dma_cpy_nd_non_zero_offset -// CHECK: amdaie.circular_dma_cpy_nd(%{{.+}}[3, 1, 1] [1, 8, 16] [128, 16, 1], %{{.+}}[1, 1, 1, 1] [1, 4, 2, 8] [64, 16, 8, 1]) -// FOLD-SINGLE-DIMS: amdaie.circular_dma_cpy_nd(%{{.+}}[3, 1, 1] [1, 8, 16] [128, 16, 1], %{{.+}}[1, 1, 1, 1] [1, 4, 2, 8] [64, 16, 8, 1]) +// CHECK: amdaie.circular_dma_cpy_nd(%{{.+}}[25, 1] [8, 16] [16, 1], %{{.+}}[5, 1, 1] [4, 2, 8] [16, 8, 1]) +// FOLD-SINGLE-DIMS: amdaie.circular_dma_cpy_nd(%{{.+}}[25, 1] [8, 16] [16, 1], %{{.+}}[5, 1, 1] [4, 2, 8] [16, 8, 1]) func.func @circular_dma_cpy_nd_non_zero_offset(%arg0: !amdaie.logicalobjectfifo>, %arg1: !amdaie.logicalobjectfifo>) { %0 = amdaie.circular_dma_cpy_nd(%arg0[2, 1, 1, 1] [1, 1, 8, 16] [128, 128, 16, 1], %arg1[1, 1, 1, 1] [1, 4, 2, 8] [64, 16, 8, 1]) : (!amdaie.logicalobjectfifo>, !amdaie.logicalobjectfifo>) "iree.keep"(%0) : (index) -> () @@ -174,8 +174,8 @@ func.func @dma_cpy_nd_unit_between_linear(%arg0: !amdaie.logicalobjectfifo>, %arg1: !amdaie.logicalobjectfifo>) { %0 = amdaie.dma_cpy_nd(%arg0[1, 2, 1, 1] [1, 1, 8, 16] [128, 128, 16, 1], %arg1[1, 1, 1, 1] [1, 4, 2, 8] [64, 16, 8, 1]) : (!amdaie.logicalobjectfifo>, !amdaie.logicalobjectfifo>) "iree.keep"(%0) : (index) -> () @@ -273,8 +273,8 @@ func.func @npu_dma_cpy_nd_unit_between_linear(%arg0: !amdaie.logicalobjectfifo>, %arg1: !amdaie.logicalobjectfifo>) { %0 = amdaie.circular_dma_cpy_nd(%arg0[] [] [], %arg1[] [] []) : (!amdaie.logicalobjectfifo>, !amdaie.logicalobjectfifo>) amdaie.npu.dma_cpy_nd %0([1, 2, 1, 1] [1, 1, 8, 16] [128, 128, 16, 1], [1, 1, 1, 1] [1, 4, 2, 8] [64, 16, 8, 1]) diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/combine_strided_ops.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/combine_strided_ops.mlir index 2bb6f0d24..6c032c265 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/combine_strided_ops.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/combine_strided_ops.mlir @@ -654,3 +654,30 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} return } } + +// ----- + +// CHECK: @with_index +// CHECK-SAME: !amdaie.logicalobjectfifo +// CHECK-SAME: !amdaie.logicalobjectfifo +// CHECK-SAME: %[[ARG2:.+]]: index +// CHECK: %[[CONNECTION:.+]] = amdaie.connection +// CHECK: amdaie.npu.circular_dma_cpy_nd +// CHECK-SAME: %[[CONNECTION]]([0, 0, 0, 0] [2, 32, 8, 8] [0, 8, 256, 1], +// CHECK-SAME: [%[[ARG2]], 0, 0, 0] [1, 2, 32, 64] [4096, 2048, 64, 1]) +#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}> +module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} { + func.func @with_index(%arg0: !amdaie.logicalobjectfifo, 2>, %arg1: !amdaie.logicalobjectfifo, 2>, %arg2: index) { + amdaie.workgroup { + %0 = 
amdaie.connection(%arg1, %arg0) : (!amdaie.logicalobjectfifo, 2>, !amdaie.logicalobjectfifo, 2>) + amdaie.controlcode { + %1 = amdaie.npu.circular_dma_cpy_nd %0([0, 0, 0] [32, 8, 8] [8, 256, 1], [%arg2, 0, 0] [1, 32, 64] [4096, 64, 1]) + %2 = amdaie.npu.circular_dma_cpy_nd %0([0, 0, 0] [32, 8, 8] [8, 256, 1], [%arg2, 1, 0, 0] [1, 1, 32, 64] [4096, 2048, 64, 1]) + // Check that the two copies above are combined into: + // amdaie.npu.circular_dma_cpy_nd %0([0, 0, 0, 0] [2, 32, 8, 8] [0, 8, 256, 1], [%arg2, 0, 0, 0] [1, 2, 32, 64] [4096, 2048, 64, 1]) + amdaie.end + } + } + return + } +}
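
Not part of the patch: a minimal standalone sketch (plain C++, no MLIR) of the merge criterion that the rewritten `combineAccessPatterns`/`mergeInFirst` implements, restricted to fully static access patterns of equal rank. The `Pattern` and `merge` names are illustrative only; the real utility works on `OpFoldResult`s, equalizes ranks and strides first, and folds unit dimensions afterwards.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <optional>
#include <vector>

// A fully static access pattern: offsets, sizes, strides per dimension.
struct Pattern {
  std::vector<int64_t> offsets, sizes, strides;
};

// global_offset = sum_d offsets[d] * strides[d].
int64_t globalOffset(const Pattern &p) {
  return std::inner_product(p.offsets.begin(), p.offsets.end(),
                            p.strides.begin(), int64_t{0});
}

// Try to merge pattern B into pattern A, where A is executed first. Rank
// equalization (inserting unit dims) is omitted; both patterns are assumed
// to already have the same rank here.
std::optional<Pattern> merge(Pattern a, const Pattern &b) {
  if (a.sizes.size() != b.sizes.size()) return std::nullopt;
  if (a.sizes.empty()) return a;
  if (a.strides != b.strides) return std::nullopt;
  // All but the outermost sizes must match.
  if (!std::equal(a.sizes.begin() + 1, a.sizes.end(), b.sizes.begin() + 1))
    return std::nullopt;
  int64_t diff = globalOffset(b) - globalOffset(a);
  if (diff == a.sizes[0] * a.strides[0]) {
    // B starts exactly where A's outermost dimension ends: grow that dim.
    a.sizes[0] += b.sizes[0];
    return a;
  }
  if (a.sizes[0] == b.sizes[0] && diff >= 0) {
    // Otherwise wrap both patterns in a new outer dimension of size 2 whose
    // stride is the global offset difference.
    a.offsets.insert(a.offsets.begin(), 0);
    a.sizes.insert(a.sizes.begin(), 2);
    a.strides.insert(a.strides.begin(), diff);
    return a;
  }
  return std::nullopt;
}

int main() {
  // Mirrors the unit test: [0][16][1] then [32][16][1] -> [0,0][2,16][32,1].
  std::optional<Pattern> m = merge({{0}, {16}, {1}}, {{32}, {16}, {1}});
  assert(m && m->offsets == std::vector<int64_t>({0, 0}) &&
         m->sizes == std::vector<int64_t>({2, 16}) &&
         m->strides == std::vector<int64_t>({32, 1}));
  // Seamless case: [32,0][64,64][128,1] then [96,0][32,64][128,1]
  // -> [32,0][96,64][128,1].
  std::optional<Pattern> n =
      merge({{32, 0}, {64, 64}, {128, 1}}, {{96, 0}, {32, 64}, {128, 1}});
  assert(n && n->sizes == std::vector<int64_t>({96, 64}));
  return 0;
}
```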
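Also not part of the patch: a worked example (again plain C++) of the new `foldUnitDims` behaviour exercised by `FoldTest.UnitDimsMerge`. Every size-1 dimension with a constant offset is folded into a single accumulated offset `sum(offset * stride)`, which is then absorbed by the first remaining dimension whose stride divides it; variable names are illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // Input pattern: offsets {1, 0, 1, 0}, sizes {1, 32, 1, 8},
  //                strides {1024, 32, 1024, 1}.
  std::vector<int64_t> offsets{1, 0, 1, 0}, sizes{1, 32, 1, 8},
      strides{1024, 32, 1024, 1};

  int64_t constantOffset = 0;
  std::vector<int64_t> newOffsets, newSizes, newStrides;
  for (std::size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] == 1) {
      constantOffset += offsets[i] * strides[i];  // 1*1024 + 1*1024 = 2048
    } else {
      newOffsets.push_back(offsets[i]);
      newSizes.push_back(sizes[i]);
      newStrides.push_back(strides[i]);
    }
  }
  // Absorb the accumulated offset into the first dim whose stride divides it.
  for (std::size_t i = 0; i < newOffsets.size(); ++i) {
    if (constantOffset % newStrides[i] == 0) {
      newOffsets[i] += constantOffset / newStrides[i];  // 0 + 2048/32 = 64
      constantOffset = 0;
      break;
    }
  }
  assert(constantOffset == 0);
  assert((newOffsets == std::vector<int64_t>{64, 0}));
  assert((newSizes == std::vector<int64_t>{32, 8}));
  assert((newStrides == std::vector<int64_t>{32, 1}));
  return 0;
}
```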