Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Jan 20, 2025
1 parent 4e21f8d commit 7b21be2
Show file tree
Hide file tree
Showing 9 changed files with 432 additions and 432 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,12 @@ struct FoldDmaOpUnitDims
SmallVector<OpFoldResult> targetStrides = op.getTargetMixedStrides();
SmallVector<OpFoldResult> newSourceOffsets, newSourceSizes,
newSourceStrides, newTargetOffsets, newTargetSizes, newTargetStrides;
LogicalResult sourceRes =
foldUnitDims(op.getContext(), sourceOffsets, sourceSizes, sourceStrides,
newSourceOffsets, newSourceSizes, newSourceStrides);
LogicalResult targetRes =
foldUnitDims(op.getContext(), targetOffsets, targetSizes, targetStrides,
newTargetOffsets, newTargetSizes, newTargetStrides);
LogicalResult sourceRes = foldUnitDims(
*op.getContext(), sourceOffsets, sourceSizes, sourceStrides,
newSourceOffsets, newSourceSizes, newSourceStrides);
LogicalResult targetRes = foldUnitDims(
*op.getContext(), targetOffsets, targetSizes, targetStrides,
newTargetOffsets, newTargetSizes, newTargetStrides);
if (failed(sourceRes) && failed(targetRes)) {
return failure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "iree-amd-aie/Transforms/Transforms.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-amdaie-combine-strided-ops"
Expand Down Expand Up @@ -47,8 +46,10 @@ struct CombineStridedOps

std::unique_ptr<DmaDimConfig> sourceDmaDimConfig;
std::unique_ptr<DmaDimConfig> targetDmaDimConfig;

SmallVector<Operation *> userOpsToBeErased;
AMDAIE::DoublyStridedOpInterface nextStridedOp;

if (auto npuDmaOp = dyn_cast<AMDAIE::NpuDmaCpyNdOp>(op.getOperation())) {
LLVM_DEBUG(llvm::dbgs() << "npuDmaOp: " << npuDmaOp << "\n");
// Fail if any non-wait user operations.
Expand Down Expand Up @@ -105,6 +106,10 @@ struct CombineStridedOps
return failure();
}

MLIRContext &context = *rewriter.getContext();
auto dimCountCheck = std::bind(&DmaDimConfig::exceedsNbDims,
std::ref(sourceDmaDimConfig), _1);

SmallVector<OpFoldResult> sourceOffsetsA = op.getSourceMixedOffsets();
SmallVector<OpFoldResult> sourceSizesA = op.getSourceMixedSizes();
SmallVector<OpFoldResult> sourceStridesA = op.getSourceMixedStrides();
Expand All @@ -114,11 +119,15 @@ struct CombineStridedOps
nextStridedOp.getSourceMixedSizes();
SmallVector<OpFoldResult> sourceStridesB =
nextStridedOp.getSourceMixedStrides();
bool areSourcesCombinable = areAccessPatternsCombinable(
sourceOffsetsA, sourceSizesA, sourceStridesA, sourceOffsetsB,
sourceSizesB, sourceStridesB,
std::bind(&DmaDimConfig::exceedsNbDims, std::ref(sourceDmaDimConfig),
_1));
SmallVector<OpFoldResult> newSourceOffsets;
SmallVector<OpFoldResult> newSourceSizes;
SmallVector<OpFoldResult> newSourceStrides;
if (failed(combineAccessPatterns(
context, sourceOffsetsA, sourceSizesA, sourceStridesA,
sourceOffsetsB, sourceSizesB, sourceStridesB, newSourceOffsets,
newSourceSizes, newSourceStrides, dimCountCheck))) {
return failure();
}

SmallVector<OpFoldResult> targetOffsetsA = op.getTargetMixedOffsets();
SmallVector<OpFoldResult> targetSizesA = op.getTargetMixedSizes();
Expand All @@ -129,53 +138,25 @@ struct CombineStridedOps
nextStridedOp.getTargetMixedSizes();
SmallVector<OpFoldResult> targetStridesB =
nextStridedOp.getTargetMixedStrides();
bool areTargetsCombinable = areAccessPatternsCombinable(
targetOffsetsA, targetSizesA, targetStridesA, targetOffsetsB,
targetSizesB, targetStridesB,
std::bind(&DmaDimConfig::exceedsNbDims, std::ref(targetDmaDimConfig),
_1));

LLVM_DEBUG(llvm::dbgs()
<< "areSourcesCombinable: " << areSourcesCombinable << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "areTargetsCombinable: " << areTargetsCombinable << "\n");

if (areSourcesCombinable && areTargetsCombinable) {
SmallVector<OpFoldResult> newSourceOffsets;
SmallVector<OpFoldResult> newSourceSizes;
SmallVector<OpFoldResult> newSourceStrides;
if (failed(combineAccessPatterns(
rewriter, sourceOffsetsA, sourceSizesA, sourceStridesA,
sourceOffsetsB, sourceSizesB, sourceStridesB, newSourceOffsets,
newSourceSizes, newSourceStrides,
std::bind(&DmaDimConfig::exceedsNbDims,
std::ref(sourceDmaDimConfig), _1)))) {
return failure();
}

SmallVector<OpFoldResult> newTargetOffsets;
SmallVector<OpFoldResult> newTargetSizes;
SmallVector<OpFoldResult> newTargetStrides;
if (failed(combineAccessPatterns(
rewriter, targetOffsetsA, targetSizesA, targetStridesA,
targetOffsetsB, targetSizesB, targetStridesB, newTargetOffsets,
newTargetSizes, newTargetStrides,
std::bind(&DmaDimConfig::exceedsNbDims,
std::ref(targetDmaDimConfig), _1)))) {
return failure();
}
SmallVector<OpFoldResult> newTargetOffsets;
SmallVector<OpFoldResult> newTargetSizes;
SmallVector<OpFoldResult> newTargetStrides;
if (failed(combineAccessPatterns(
context, targetOffsetsA, targetSizesA, targetStridesA,
targetOffsetsB, targetSizesB, targetStridesB, newTargetOffsets,
newTargetSizes, newTargetStrides, dimCountCheck))) {
return failure();
}

rewriter.setInsertionPoint(op);
auto newDoublyStridedOp = nextStridedOp.createDoublyStridedOp(
rewriter, newTargetOffsets, newTargetSizes, newTargetStrides,
newSourceOffsets, newSourceSizes, newSourceStrides);
rewriter.replaceOp(nextStridedOp, newDoublyStridedOp.getOperation());
rewriter.setInsertionPoint(op);
auto newDoublyStridedOp = nextStridedOp.createDoublyStridedOp(
rewriter, newTargetOffsets, newTargetSizes, newTargetStrides,
newSourceOffsets, newSourceSizes, newSourceStrides);
rewriter.replaceOp(nextStridedOp, newDoublyStridedOp.getOperation());

for (Operation *userOp : userOpsToBeErased) rewriter.eraseOp(userOp);
rewriter.eraseOp(op);
return success();
}
return failure();
for (Operation *userOp : userOpsToBeErased) rewriter.eraseOp(userOp);
rewriter.eraseOp(op);
return success();
}

template <typename T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ void AMDAIEDmaCompositionPass::runOnOperation() {
onlyZeroStrideOnOuterDim);
populateStridedOpCombinationPattern(patterns);
populateCanonicalizeDoublyStridedOpPatterns(patterns, false, deviceModel);

if (failed(applyPatternsGreedily(parentOp, std::move(patterns)))) {
parentOp->emitOpError("failed to compose strided operations");
return signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ LogicalResult AIEDeviceBuilder::foldDimsAndReturnAsStatic(
SmallVector<OpFoldResult> offsets(
strides.size(), getAsIndexOpFoldResult(rewriter.getContext(), 0));
SmallVector<OpFoldResult> unitOffsets, unitSizes, unitStrides, newOffsets;
(void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides,
(void)foldUnitDims(*rewriter.getContext(), offsets, sizes, strides,
unitOffsets, unitSizes, unitStrides);
DmaDimConfig dmaDimConfig(deviceModel, memSpace);
SmallVector<int64_t> maxSizes = dmaDimConfig.getMaxSizes(unitOffsets.size());
Expand Down
Loading

0 comments on commit 7b21be2

Please sign in to comment.