Skip to content

Commit

Permalink
[FuseConsumerIntoLoop] Extend to multiple results and generalize depth fusion (#1027)
Browse files Browse the repository at this point in the history

This PR extends the consumer fusion to work for matmul + elemwise in
case of an additional level of tiling (see added lit test).
  • Loading branch information
jtuyls authored Jan 14, 2025
1 parent 2f86199 commit 68d6527
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 26 deletions.
49 changes: 27 additions & 22 deletions build_tools/ci/cpu_comparison/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ def __init__(
rhs,
expected_out,
run_on_target=["npu1_4col"],
tile_pipeline="pack-peel",
):
super().__init__(
run_on_target=run_on_target,
Expand All @@ -765,7 +766,7 @@ def __init__(
K=K,
input_type=input_type,
acc_type=acc_type,
tile_pipeline="pack-peel",
tile_pipeline=tile_pipeline,
n_repeats=1,
)
self.labels.append("MatmulTruncf")
Expand All @@ -776,6 +777,8 @@ def __init__(
assert expected_out.shape == (M, M)

self.name = f"matmul_truncf_{M}_{K}_{input_type}_{acc_type}"
if tile_pipeline == "pack-peel-4-level-tiling":
self.name += "_4_level_tiling"
self.lhs = lhs
self.rhs = rhs
self.expected_out = expected_out
Expand Down Expand Up @@ -1594,29 +1597,31 @@ def __init__(self):
self.tests = []

# Matmul with truncf test(s):
self.register(
MatmulTruncf(
16,
16,
"bf16",
"f32",
101 * np.ones([16, 16]),
3 * np.eye(16),
302 * np.ones([16, 16]),
for tile_pipeline in ["pack-peel", "pack-peel-4-level-tiling"]:
self.register(
MatmulTruncf(
16,
16,
"bf16",
"f32",
101 * np.ones([16, 16]),
3 * np.eye(16),
302 * np.ones([16, 16]),
tile_pipeline=tile_pipeline,
)
)
)

self.register(
MatmulTruncf(
128,
256,
"bf16",
"f32",
2 * np.ones([128, 256]),
3 * np.ones([256, 128]),
1536 * np.ones([128, 128]),
self.register(
MatmulTruncf(
128,
256,
"bf16",
"f32",
2 * np.ones([128, 256]),
3 * np.ones([256, 128]),
1536 * np.ones([128, 128]),
tile_pipeline=tile_pipeline,
)
)
)

# BatchMatmul test(s):
for input_type, acc_type in zip(["i32", "bf16"], ["i32", "f32"]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,20 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
// with any depth instead.
for (unsigned depth = 1; depth <= fuseDepth; depth++) {
do {
Value::user_range users = producerOp->getResult(0).getUsers();
if (!llvm::hasSingleElement(users)) {
ResultRange results = producerOp->getResults();
SmallVector<Operation *> allUsers = std::accumulate(
results.begin(), results.end(), SmallVector<Operation *>{},
[](SmallVector<Operation *> init, OpResult res) {
for (Operation *op : res.getUsers()) init.push_back(op);
return init;
});
if (allUsers.size() != 1) {
LLVM_DEBUG(llvm::dbgs()
<< "Expected only one user of the compute op\n");
break;
}

Operation *candidateSliceOp = *(users.begin());
Operation *candidateSliceOp = allUsers[0];
if (!(isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
candidateSliceOp))) {
producerOp = producerOp->getParentOfType<LoopLikeOpInterface>();
Expand All @@ -127,7 +133,15 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
}
changed = true;
fusedConsumer->origConsumerOperand->getOwner()->erase();
producerOp = fusedConsumer->tiledAndFusedConsumerOperand->getOwner();
Operation *fusedOp =
fusedConsumer->tiledAndFusedConsumerOperand->getOwner();
if (getAncestorInBlock(fusedOp, computeOp->getBlock()) != nullptr) {
// The consumer is fused all the way into the producer's block, so
// operate on this op from now on, but with reduced depth.
computeOp = fusedOp;
fuseDepth -= 1;
}
producerOp = fusedOp;
break;
} while (producerOp && producerOp->getParentOp() != funcOp);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,3 +547,106 @@ module {
return %6 : tensor<512x4096xf32>
}
}

// -----

// CHECK-LABEL: @matmul_elemwise_multiple_fusion_iterations
// CHECK: %[[EXTRACT_SLICE_0:.+]] = tensor.extract_slice %{{.+}}[%{{.+}}, %{{.+}}] [256, 256] [1, 1]
// CHECK: %[[FORALL_0:.+]]:3 = scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) = (0, 0) to (8, 8) step (4, 4) shared_outs(%[[MATMUL_OUT:.+]] = %{{.*}}, %[[ELEMWISE_OUT:.+]] = %{{.*}}, %[[UNPACK_OUT:.+]] = %{{.*}})
// CHECK: %[[FORALL_1:.+]]:3 = scf.forall (%[[ARG2:.+]], %[[ARG3:.+]]) in (4, 4) shared_outs(%[[MATMUL_LOCAL_OUT:.+]] = %{{.*}}, %[[ELEMWISE_LOCAL_OUT:.+]] = %{{.*}}, %[[UNPACK_LOCAL_OUT:.+]] = %{{.*}})
// CHECK-SAME: {
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: %[[ELEMWISE:.+]] = linalg.generic
// CHECK-SAME: ins(%[[MATMUL]] : tensor<1x1x8x8x4x4xf32>)
// CHECK: arith.truncf
// CHECK: %[[EXTRACT_SLICE_1:.+]] = tensor.extract_slice %[[UNPACK_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ELEMWISE]] outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %[[EXTRACT_SLICE_1]]
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice %[[MATMUL]] into %[[MATMUL_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: tensor.parallel_insert_slice %[[ELEMWISE]] into %[[ELEMWISE_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: tensor.parallel_insert_slice %[[UNPACK]] into %[[UNPACK_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
// CHECK: }
// CHECK: }
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice %[[FORALL_1]]#0 into %[[MATMUL_OUT]][%[[ARG0]], %[[ARG1]], 0, 0, 0, 0] [4, 4, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: tensor.parallel_insert_slice %[[FORALL_1]]#1 into %[[ELEMWISE_OUT]][%[[ARG0]], %[[ARG1]], 0, 0, 0, 0] [4, 4, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: tensor.parallel_insert_slice %[[FORALL_1]]#2 into %[[UNPACK_OUT]][%[[ARG0]], %[[ARG1]], 0, 0] [4, 4, 32, 32] [1, 1, 1, 1]
// CHECK: }
// CHECK: tensor.unpack %[[FORALL_0]]#2 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[EXTRACT_SLICE_0]]
// Indexing maps for the packed matmul: 9 loops total, matching the
// iterator_types on the linalg.generic below (6 parallel + 3 reduction).
// #map/#map1/#map2 index the packed LHS, RHS, and accumulator respectively.
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d2, d4, d5, d8, d7)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>
// #map3 is the identity map for the elementwise (truncf) op on the 6-D packed tensor.
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
module {
  // Packed matmul (bf16 inputs, f32 accumulation) followed by an elementwise
  // truncf back to bf16 and a two-stage tensor.unpack, tiled with an extra
  // level of scf.forall nesting relative to the earlier tests in this file
  // (outer block-mapped forall with a step, plus an inner thread-mapped
  // forall). NOTE(review): the expected-error below must stay immediately
  // above the func.func — the diagnostic verifier resolves `@+1` by line.
  // expected-error @+1 {{Maximum number of iterations reached, consumer fusion is likely stuck in an infinite loop.}}
  func.func @matmul_elemwise_multiple_fusion_iterations() -> tensor<512x4096xbf16> {
    // Pre-allocated L2 (address space 1) and L1 (address space 2) buffers
    // that back the packed operands/results via bufferization.to_tensor.
    %alloc = memref.alloc() : memref<1x1x8x8x8x4xbf16, 2 : i32>
    %alloc_0 = memref.alloc() : memref<1x1x8x8x4x8xbf16, 2 : i32>
    %alloc_1 = memref.alloc() : memref<8x8x32x32xbf16, 1 : i32>
    %alloc_2 = memref.alloc() : memref<8x8x64x32xbf16, 1 : i32>
    %alloc_3 = memref.alloc() : memref<8x8x32x64xbf16, 1 : i32>
    %0 = tensor.empty() : tensor<512x512xbf16>
    %1 = tensor.empty() : tensor<512x4096xbf16>
    %2 = tensor.empty() : tensor<512x4096xbf16>
    // Outermost tiling: 256x256 output tiles of the 512x4096 result.
    %3 = scf.forall (%arg0, %arg1) = (0, 0) to (512, 4096) step (256, 256) shared_outs(%arg2 = %2) -> (tensor<512x4096xbf16>) {
      %extracted_slice = tensor.extract_slice %0[%arg0, 0] [256, 512] [1, 1] : tensor<512x512xbf16> to tensor<256x512xbf16>
      %extracted_slice_4 = tensor.extract_slice %1[0, %arg1] [512, 256] [1, 1] : tensor<512x4096xbf16> to tensor<512x256xbf16>
      %extracted_slice_5 = tensor.extract_slice %arg2[%arg0, %arg1] [256, 256] [1, 1] : tensor<512x4096xbf16> to tensor<256x256xbf16>
      // Pack LHS/RHS tiles into 4-D tiled layouts staged in the L2 buffers.
      %4 = bufferization.to_tensor %alloc_3 restrict writable : memref<8x8x32x64xbf16, 1 : i32> to tensor<8x8x32x64xbf16>
      %pack = tensor.pack %extracted_slice outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %4 : tensor<256x512xbf16> -> tensor<8x8x32x64xbf16>
      %5 = bufferization.to_tensor %alloc_2 restrict writable : memref<8x8x64x32xbf16, 1 : i32> to tensor<8x8x64x32xbf16>
      %pack_6 = tensor.pack %extracted_slice_4 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %5 : tensor<512x256xbf16> -> tensor<8x8x64x32xbf16>
      %6 = bufferization.to_tensor %alloc_1 restrict writable : memref<8x8x32x32xbf16, 1 : i32> to tensor<8x8x32x32xbf16>
      %7 = tensor.empty() : tensor<8x8x8x8x4x4xbf16>
      %8 = tensor.empty() : tensor<8x8x8x8x4x4xf32>
      // Second tiling level: block-mapped forall over the 8x8 outer tile
      // grid in steps of 4 — the extra level this test exercises.
      %9 = scf.forall (%arg3, %arg4) = (0, 0) to (8, 8) step (4, 4) shared_outs(%arg5 = %8) -> (tensor<8x8x8x8x4x4xf32>) {
        %extracted_slice_8 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [4, 8, 32, 64] [1, 1, 1, 1] : tensor<8x8x32x64xbf16> to tensor<4x8x32x64xbf16>
        %extracted_slice_9 = tensor.extract_slice %pack_6[%arg4, 0, 0, 0] [4, 8, 64, 32] [1, 1, 1, 1] : tensor<8x8x64x32xbf16> to tensor<4x8x64x32xbf16>
        // NOTE(review): only the last K slice (index 7) is taken here, so
        // this computes a single K-step rather than the full reduction —
        // presumably intentional to keep the lit test small; confirm.
        %extracted_slice_10 = tensor.extract_slice %extracted_slice_8[0, 7, 0, 0] [4, 1, 32, 64] [1, 1, 1, 1] : tensor<4x8x32x64xbf16> to tensor<4x1x32x64xbf16>
        %extracted_slice_11 = tensor.extract_slice %extracted_slice_9[0, 7, 0, 0] [4, 1, 64, 32] [1, 1, 1, 1] : tensor<4x8x64x32xbf16> to tensor<4x1x64x32xbf16>
        %11 = tensor.empty() : tensor<4x4x8x8x4x4xf32>
        // Innermost tiling level: thread-mapped forall computing one packed
        // 1x1x8x8x4x4 matmul micro-tile per (thread y, thread x).
        %12 = scf.forall (%arg6, %arg7) in (4, 4) shared_outs(%arg8 = %11) -> (tensor<4x4x8x8x4x4xf32>) {
          %extracted_slice_12 = tensor.extract_slice %extracted_slice_10[%arg6, 0, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : tensor<4x1x32x64xbf16> to tensor<1x1x32x64xbf16>
          %13 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x8x4x8xbf16, 2 : i32> to tensor<1x1x8x8x4x8xbf16>
          // Second-level pack into the 6-D register-tile layout (L1 buffers).
          %pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %13 : tensor<1x1x32x64xbf16> -> tensor<1x1x8x8x4x8xbf16>
          %extracted_slice_14 = tensor.extract_slice %extracted_slice_11[%arg7, 0, 0, 0] [1, 1, 64, 32] [1, 1, 1, 1] : tensor<4x1x64x32xbf16> to tensor<1x1x64x32xbf16>
          %14 = bufferization.to_tensor %alloc restrict writable : memref<1x1x8x8x8x4xbf16, 2 : i32> to tensor<1x1x8x8x8x4xbf16>
          %pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %14 : tensor<1x1x64x32xbf16> -> tensor<1x1x8x8x8x4xbf16>
          %extracted_slice_16 = tensor.extract_slice %arg8[%arg6, %arg7, 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<4x4x8x8x4x4xf32> to tensor<1x1x8x8x4x4xf32>
          // Packed matmul body: extend bf16 -> f32, multiply, accumulate.
          %15 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_13, %pack_15 : tensor<1x1x8x8x4x8xbf16>, tensor<1x1x8x8x8x4xbf16>) outs(%extracted_slice_16 : tensor<1x1x8x8x4x4xf32>) {
          ^bb0(%in: bf16, %in_17: bf16, %out: f32):
            %16 = arith.extf %in : bf16 to f32
            %17 = arith.extf %in_17 : bf16 to f32
            %18 = arith.mulf %16, %17 : f32
            %19 = arith.addf %out, %18 : f32
            linalg.yield %19 : f32
          } -> tensor<1x1x8x8x4x4xf32>
          scf.forall.in_parallel {
            tensor.parallel_insert_slice %15 into %arg8[%arg6, %arg7, 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x8x4x4xf32> into tensor<4x4x8x8x4x4xf32>
          }
        } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %12 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [4, 4, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<4x4x8x8x4x4xf32> into tensor<8x8x8x8x4x4xf32>
        }
      } {mapping = [#gpu.block<y>, #gpu.block<x>]}
      // Elementwise consumer of the matmul: truncate f32 -> bf16. This is
      // the op the FuseConsumerIntoLoop pass must pull into the loop nest.
      %10 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<8x8x8x8x4x4xf32>) outs(%7 : tensor<8x8x8x8x4x4xbf16>) {
      ^bb0(%in: f32, %out: bf16):
        %11 = arith.truncf %in : f32 to bf16
        linalg.yield %11 : bf16
      } -> tensor<8x8x8x8x4x4xbf16>
      // Two-stage unpack: 6-D register layout -> 4-D tile layout -> plain
      // 256x256 tile, then parallel-insert into the full result.
      %unpack = tensor.unpack %10 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %6 : tensor<8x8x8x8x4x4xbf16> -> tensor<8x8x32x32xbf16>
      %unpack_7 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %extracted_slice_5 : tensor<8x8x32x32xbf16> -> tensor<256x256xbf16>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %unpack_7 into %arg2[%arg0, %arg1] [256, 256] [1, 1] : tensor<256x256xbf16> into tensor<512x4096xbf16>
      }
    } {mapping = [#gpu.block<y>, #gpu.block<x>]}
    memref.dealloc %alloc_3 : memref<8x8x32x64xbf16, 1 : i32>
    memref.dealloc %alloc_2 : memref<8x8x64x32xbf16, 1 : i32>
    memref.dealloc %alloc_1 : memref<8x8x32x32xbf16, 1 : i32>
    memref.dealloc %alloc_0 : memref<1x1x8x8x4x8xbf16, 2 : i32>
    memref.dealloc %alloc : memref<1x1x8x8x8x4xbf16, 2 : i32>
    return %3 : tensor<512x4096xbf16>
  }
}

0 comments on commit 68d6527

Please sign in to comment.