Skip to content

Commit

Permalink
[COMMON] Push up the extract slice op
Browse files Browse the repository at this point in the history
Push the extract_slice ops to the beginning of the block if all of their
operands are block arguments. This lets the bufferization framework know
about the presence of a subset buffer that can be reused.
  • Loading branch information
pashu123 committed Jan 11, 2025
1 parent 9f93691 commit 4ae2cb5
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,23 @@ void OptimizeTensorInsertExtractSlicesPass::runOnOperation() {
auto funcOp = getOperation();
IRRewriter rewriter(funcOp->getContext());

// Push the extract_slice ops to the beginning of the block if all its
// operands are block arguments. This lets the bufferization framework
// know the presence of subset buffer that can be reused.
funcOp.walk([&](tensor::ExtractSliceOp extractSliceOp) {
// Check that all operands of tensor extractSliceOp are block arguments
// and ensure they belong to the same block as the extractSliceOp.
auto currBlock = extractSliceOp.getOperation()->getBlock();
if (llvm::all_of(extractSliceOp.getOperands(), [&](Value operand) {
auto blockArg = dyn_cast<BlockArgument>(operand);
return blockArg && blockArg.getParentBlock() == currBlock;
})) {
// Move the extractSliceOp to the beginning of the block.
auto &firstOp = currBlock->getOperations().front();
extractSliceOp->moveBefore(&firstOp);
}
});

funcOp.walk([&](scf::ForOp forOp) { moveLoopInvariantCode(forOp); });
LDBG("after hoisting loop invariant code\n" << funcOp);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,18 @@ func.func @fold_identity_extract_slice(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-LABEL: @fold_identity_extract_slice
// CHECK: %[[ARG0:.+]]: tensor<?xf32>
// CHECK: return %[[ARG0]]

// -----

// Test case for pushing a tensor.extract_slice up within its block. The
// slice's only SSA operands (%arg2 as the source and %arg0 as the dynamic
// offset) are arguments of the function's entry block, while in the input it
// is placed after the arith.truncf and vector.transfer_write ops. The
// expectations after this function require the extract_slice to be printed
// before the vector.transfer_write in the output.
func.func @push_up_extract_slice(%arg0: index, %arg1: vector<64x64xf32>, %arg2: tensor<2x4096x10x64xf16>) -> tensor<1x64x1x64xf16> {
%c0 = arith.constant 0 : index
%0 = tensor.empty() : tensor<64x64xf16>
%1 = arith.truncf %arg1 : vector<64x64xf32> to vector<64x64xf16>
%2 = vector.transfer_write %1, %0[%c0, %c0] {in_bounds = [true, true]} : vector<64x64xf16>, tensor<64x64xf16>
%extracted_slice = tensor.extract_slice %arg2[%arg0, 0, 0, 0] [1, 64, 1, 64] [1, 1, 1, 1] : tensor<2x4096x10x64xf16> to tensor<1x64x1x64xf16>
%inserted_slice = tensor.insert_slice %2 into %extracted_slice[0, 0, 0, 0] [1, 64, 1, 64] [1, 1, 1, 1] : tensor<64x64xf16> into tensor<1x64x1x64xf16>
return %inserted_slice : tensor<1x64x1x64xf16>
}
// CHECK-LABEL: @push_up_extract_slice
// CHECK: tensor.extract_slice
// CHECK: vector.transfer_write
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,6 @@ hal.executable private @main {
// CHECK: scf.forall ({{.*}}) in (17, 81) {
// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C721]] step %[[C1]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
// CHECK: gpu.barrier
// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<1xf16>
// CHECK-DAG: vector.transfer_write %[[LHS_RD]]
// Note that to simplify the test we are not showing the mapping of the RHS_RD
// to its buffer as it goes through an scf.if/else control structure
// involving allocas.
// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read {{.*}} vector<1xf16>
// CHECK-DAG: vector.transfer_write %[[RHS_RD]]
// CHECK: gpu.barrier
// CHECK-DAG: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<4xf16>
// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<4x1x1xf16>
// CHECK-COUNT-1: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1151,11 +1151,6 @@ hal.executable public @main {
// CHECK: scf.forall ({{.*}}) in (12, 37, 10) {
// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c145 step %c1 {{.*}} -> (vector<1x1x1x4x1xf32>)
// CHECK: gpu.barrier
// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read {{.*}} vector<4xf32>
// CHECK-DAG: vector.transfer_write %[[LHS_RD]]
// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read {{.*}} vector<1xf32>
// CHECK-DAG: vector.transfer_write %[[RHS_RD]]
// CHECK: gpu.barrier
// CHECK-DAG: vector.transfer_read {{.*}} #gpu.address_space<workgroup>>, vector<1xf32>
// CHECK-DAG: vector.transfer_read {{.*}} #gpu.address_space<workgroup>>, vector<1xf32>
// CHECK-COUNT-1: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32
Expand Down

0 comments on commit 4ae2cb5

Please sign in to comment.