[FuseConsumerIntoLoop] Extend to multiple results and generalize depth fusion #1027

Merged · 1 commit · Jan 14, 2025
49 changes: 27 additions & 22 deletions build_tools/ci/cpu_comparison/run.py
@@ -756,6 +756,7 @@ def __init__(
rhs,
expected_out,
run_on_target=["npu1_4col"],
tile_pipeline="pack-peel",
):
super().__init__(
run_on_target=run_on_target,
@@ -765,7 +766,7 @@ def __init__(
K=K,
input_type=input_type,
acc_type=acc_type,
tile_pipeline="pack-peel",
tile_pipeline=tile_pipeline,
n_repeats=1,
)
self.labels.append("MatmulTruncf")
@@ -776,6 +777,8 @@ def __init__(
assert expected_out.shape == (M, M)

self.name = f"matmul_truncf_{M}_{K}_{input_type}_{acc_type}"
if tile_pipeline == "pack-peel-4-level-tiling":
self.name += "_4_level_tiling"
self.lhs = lhs
self.rhs = rhs
self.expected_out = expected_out
@@ -1594,29 +1597,31 @@ def __init__(self):
self.tests = []

# Matmul with truncf test(s):
self.register(
MatmulTruncf(
16,
16,
"bf16",
"f32",
101 * np.ones([16, 16]),
3 * np.eye(16),
302 * np.ones([16, 16]),
for tile_pipeline in ["pack-peel", "pack-peel-4-level-tiling"]:
self.register(
MatmulTruncf(
16,
16,
"bf16",
"f32",
101 * np.ones([16, 16]),
3 * np.eye(16),
302 * np.ones([16, 16]),
tile_pipeline=tile_pipeline,
)
)
)

self.register(
MatmulTruncf(
128,
256,
"bf16",
"f32",
2 * np.ones([128, 256]),
3 * np.ones([256, 128]),
1536 * np.ones([128, 128]),
self.register(
MatmulTruncf(
128,
256,
"bf16",
"f32",
2 * np.ones([128, 256]),
3 * np.ones([256, 128]),
1536 * np.ones([128, 128]),
tile_pipeline=tile_pipeline,
)
)
)

# BatchMatmul test(s):
for input_type, acc_type in zip(["i32", "bf16"], ["i32", "f32"]):
@@ -98,14 +98,20 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
// with any depth instead.
for (unsigned depth = 1; depth <= fuseDepth; depth++) {
do {
Value::user_range users = producerOp->getResult(0).getUsers();
if (!llvm::hasSingleElement(users)) {
ResultRange results = producerOp->getResults();
SmallVector<Operation *> allUsers = std::accumulate(
results.begin(), results.end(), SmallVector<Operation *>{},
[](SmallVector<Operation *> init, OpResult res) {
for (Operation *op : res.getUsers()) init.push_back(op);
return init;
});
if (allUsers.size() != 1) {
LLVM_DEBUG(llvm::dbgs()
<< "Expected only one user of the compute op\n");
break;
}

Operation *candidateSliceOp = *(users.begin());
Operation *candidateSliceOp = allUsers[0];
if (!(isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
candidateSliceOp))) {
producerOp = producerOp->getParentOfType<LoopLikeOpInterface>();
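The hunk above replaces the single-result user check with a scan over every result of the producer. A minimal standalone sketch of that accumulation pattern follows; `Op` and `Result` here are mock stand-ins, not the MLIR classes, and only the `std::accumulate` shape mirrors the pass.

```cpp
// Standalone sketch of the user-gathering pattern above, using mock
// Op/Result types rather than the MLIR classes.
#include <iostream>
#include <numeric>
#include <vector>

struct Op;

struct Result {
  std::vector<Op *> users;  // operations that consume this result
};

struct Op {
  std::vector<Result> results;
};

// Collect the users of every result of `producer` into one vector, so a
// multi-result producer is no longer judged by result 0 alone.
std::vector<Op *> allUsers(const Op &producer) {
  return std::accumulate(
      producer.results.begin(), producer.results.end(),
      std::vector<Op *>{},
      [](std::vector<Op *> acc, const Result &res) {
        for (Op *user : res.users) acc.push_back(user);
        return acc;
      });
}

int main() {
  Op sliceA, sliceB;
  Op producer{{Result{{&sliceA}}, Result{{&sliceB}}}};
  // Two results with one user each: the combined count is 2, which the
  // single-user check rejects, whereas inspecting only result 0 would
  // have reported 1.
  std::cout << allUsers(producer).size() << "\n";  // prints 2
}
```

For a single-result producer the behaviour is unchanged; the combined vector simply contains that result's one user.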
@@ -127,7 +133,15 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
}
changed = true;
fusedConsumer->origConsumerOperand->getOwner()->erase();
producerOp = fusedConsumer->tiledAndFusedConsumerOperand->getOwner();
Operation *fusedOp =
fusedConsumer->tiledAndFusedConsumerOperand->getOwner();
if (getAncestorInBlock(fusedOp, computeOp->getBlock()) != nullptr) {
// The consumer is fused all the way into the producer's block, so
// operate on this op from now on, but with reduced depth.
computeOp = fusedOp;
fuseDepth -= 1;
}
producerOp = fusedOp;
break;
} while (producerOp && producerOp->getParentOp() != funcOp);
}
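The second hunk makes the pass continue from the fused consumer once it has been pulled into the same block as the original compute op, with the remaining fuse depth reduced by one. The sketch below is an assumed reading of the ancestor-in-block test behind that decision; `Block` and `Op` are mock types and `ancestorInBlock` is an illustration, not the real `getAncestorInBlock` helper.

```cpp
// Rough sketch of an ancestor-in-block walk: does `op` sit, possibly
// through nested ops, inside the block `target`?
#include <iostream>

struct Block {};

struct Op {
  Block *parentBlock = nullptr;  // block this op lives in
  Op *parentOp = nullptr;        // op that owns that block, if any
};

// Walk up through enclosing ops until one lives directly in `target`;
// return nullptr if `op` is not nested under `target` at all.
Op *ancestorInBlock(Op *op, Block *target) {
  for (; op != nullptr; op = op->parentOp) {
    if (op->parentBlock == target) return op;
  }
  return nullptr;
}

int main() {
  Block funcBody, loopBody;
  Op computeOp{&funcBody, nullptr};  // original compute op's block
  Op loop{&funcBody, nullptr};       // loop the consumer was fused into
  Op fusedOp{&loopBody, &loop};      // fused consumer inside that loop
  // The fused consumer is nested under computeOp's block (via `loop`), so
  // per the hunk above the pass would continue from fusedOp with
  // fuseDepth reduced by one.
  bool nested = ancestorInBlock(&fusedOp, computeOp.parentBlock) != nullptr;
  std::cout << nested << "\n";  // prints 1
}
```

If the walk fails, the hunk above leaves computeOp and fuseDepth untouched and simply continues from the fused op at the current depth.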
@@ -547,3 +547,106 @@ module {
return %6 : tensor<512x4096xf32>
}
}

// -----

// CHECK-LABEL: @matmul_elemwise_multiple_fusion_iterations
// CHECK: %[[EXTRACT_SLICE_0:.+]] = tensor.extract_slice %{{.+}}[%{{.+}}, %{{.+}}] [256, 256] [1, 1]
// CHECK: %[[FORALL_0:.+]]:3 = scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) = (0, 0) to (8, 8) step (4, 4) shared_outs(%[[MATMUL_OUT:.+]] = %{{.*}}, %[[ELEMWISE_OUT:.+]] = %{{.*}}, %[[UNPACK_OUT:.+]] = %{{.*}})
// CHECK: %[[FORALL_1:.+]]:3 = scf.forall (%[[ARG2:.+]], %[[ARG3:.+]]) in (4, 4) shared_outs(%[[MATMUL_LOCAL_OUT:.+]] = %{{.*}}, %[[ELEMWISE_LOCAL_OUT:.+]] = %{{.*}}, %[[UNPACK_LOCAL_OUT:.+]] = %{{.*}})
// CHECK-SAME: {
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: %[[ELEMWISE:.+]] = linalg.generic
// CHECK-SAME: ins(%[[MATMUL]] : tensor<1x1x8x8x4x4xf32>)
// CHECK: arith.truncf
// CHECK: %[[EXTRACT_SLICE_1:.+]] = tensor.extract_slice %[[UNPACK_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ELEMWISE]] outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %[[EXTRACT_SLICE_1]]
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice %[[MATMUL]] into %[[MATMUL_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: tensor.parallel_insert_slice %[[ELEMWISE]] into %[[ELEMWISE_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: tensor.parallel_insert_slice %[[UNPACK]] into %[[UNPACK_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
// CHECK: }
// CHECK: }
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice %[[FORALL_1]]#0 into %[[MATMUL_OUT]][%[[ARG0]], %[[ARG1]], 0, 0, 0, 0] [4, 4, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: tensor.parallel_insert_slice %[[FORALL_1]]#1 into %[[ELEMWISE_OUT]][%[[ARG0]], %[[ARG1]], 0, 0, 0, 0] [4, 4, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: tensor.parallel_insert_slice %[[FORALL_1]]#2 into %[[UNPACK_OUT]][%[[ARG0]], %[[ARG1]], 0, 0] [4, 4, 32, 32] [1, 1, 1, 1]
// CHECK: }
// CHECK: tensor.unpack %[[FORALL_0]]#2 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[EXTRACT_SLICE_0]]
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d2, d4, d5, d8, d7)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
module {
// expected-error @+1 {{Maximum number of iterations reached, consumer fusion is likely stuck in an infinite loop.}}
func.func @matmul_elemwise_multiple_fusion_iterations() -> tensor<512x4096xbf16> {
%alloc = memref.alloc() : memref<1x1x8x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x8x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<8x8x32x32xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<8x8x64x32xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<8x8x32x64xbf16, 1 : i32>
%0 = tensor.empty() : tensor<512x512xbf16>
%1 = tensor.empty() : tensor<512x4096xbf16>
%2 = tensor.empty() : tensor<512x4096xbf16>
%3 = scf.forall (%arg0, %arg1) = (0, 0) to (512, 4096) step (256, 256) shared_outs(%arg2 = %2) -> (tensor<512x4096xbf16>) {
%extracted_slice = tensor.extract_slice %0[%arg0, 0] [256, 512] [1, 1] : tensor<512x512xbf16> to tensor<256x512xbf16>
%extracted_slice_4 = tensor.extract_slice %1[0, %arg1] [512, 256] [1, 1] : tensor<512x4096xbf16> to tensor<512x256xbf16>
%extracted_slice_5 = tensor.extract_slice %arg2[%arg0, %arg1] [256, 256] [1, 1] : tensor<512x4096xbf16> to tensor<256x256xbf16>
%4 = bufferization.to_tensor %alloc_3 restrict writable : memref<8x8x32x64xbf16, 1 : i32> to tensor<8x8x32x64xbf16>
%pack = tensor.pack %extracted_slice outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %4 : tensor<256x512xbf16> -> tensor<8x8x32x64xbf16>
%5 = bufferization.to_tensor %alloc_2 restrict writable : memref<8x8x64x32xbf16, 1 : i32> to tensor<8x8x64x32xbf16>
%pack_6 = tensor.pack %extracted_slice_4 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %5 : tensor<512x256xbf16> -> tensor<8x8x64x32xbf16>
%6 = bufferization.to_tensor %alloc_1 restrict writable : memref<8x8x32x32xbf16, 1 : i32> to tensor<8x8x32x32xbf16>
%7 = tensor.empty() : tensor<8x8x8x8x4x4xbf16>
%8 = tensor.empty() : tensor<8x8x8x8x4x4xf32>
%9 = scf.forall (%arg3, %arg4) = (0, 0) to (8, 8) step (4, 4) shared_outs(%arg5 = %8) -> (tensor<8x8x8x8x4x4xf32>) {
%extracted_slice_8 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [4, 8, 32, 64] [1, 1, 1, 1] : tensor<8x8x32x64xbf16> to tensor<4x8x32x64xbf16>
%extracted_slice_9 = tensor.extract_slice %pack_6[%arg4, 0, 0, 0] [4, 8, 64, 32] [1, 1, 1, 1] : tensor<8x8x64x32xbf16> to tensor<4x8x64x32xbf16>
%extracted_slice_10 = tensor.extract_slice %extracted_slice_8[0, 7, 0, 0] [4, 1, 32, 64] [1, 1, 1, 1] : tensor<4x8x32x64xbf16> to tensor<4x1x32x64xbf16>
%extracted_slice_11 = tensor.extract_slice %extracted_slice_9[0, 7, 0, 0] [4, 1, 64, 32] [1, 1, 1, 1] : tensor<4x8x64x32xbf16> to tensor<4x1x64x32xbf16>
%11 = tensor.empty() : tensor<4x4x8x8x4x4xf32>
%12 = scf.forall (%arg6, %arg7) in (4, 4) shared_outs(%arg8 = %11) -> (tensor<4x4x8x8x4x4xf32>) {
%extracted_slice_12 = tensor.extract_slice %extracted_slice_10[%arg6, 0, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : tensor<4x1x32x64xbf16> to tensor<1x1x32x64xbf16>
%13 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x8x4x8xbf16, 2 : i32> to tensor<1x1x8x8x4x8xbf16>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %13 : tensor<1x1x32x64xbf16> -> tensor<1x1x8x8x4x8xbf16>
%extracted_slice_14 = tensor.extract_slice %extracted_slice_11[%arg7, 0, 0, 0] [1, 1, 64, 32] [1, 1, 1, 1] : tensor<4x1x64x32xbf16> to tensor<1x1x64x32xbf16>
%14 = bufferization.to_tensor %alloc restrict writable : memref<1x1x8x8x8x4xbf16, 2 : i32> to tensor<1x1x8x8x8x4xbf16>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %14 : tensor<1x1x64x32xbf16> -> tensor<1x1x8x8x8x4xbf16>
%extracted_slice_16 = tensor.extract_slice %arg8[%arg6, %arg7, 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<4x4x8x8x4x4xf32> to tensor<1x1x8x8x4x4xf32>
%15 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_13, %pack_15 : tensor<1x1x8x8x4x8xbf16>, tensor<1x1x8x8x8x4xbf16>) outs(%extracted_slice_16 : tensor<1x1x8x8x4x4xf32>) {
^bb0(%in: bf16, %in_17: bf16, %out: f32):
%16 = arith.extf %in : bf16 to f32
%17 = arith.extf %in_17 : bf16 to f32
%18 = arith.mulf %16, %17 : f32
%19 = arith.addf %out, %18 : f32
linalg.yield %19 : f32
} -> tensor<1x1x8x8x4x4xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %15 into %arg8[%arg6, %arg7, 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x8x4x4xf32> into tensor<4x4x8x8x4x4xf32>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.forall.in_parallel {
tensor.parallel_insert_slice %12 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [4, 4, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<4x4x8x8x4x4xf32> into tensor<8x8x8x8x4x4xf32>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
%10 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<8x8x8x8x4x4xf32>) outs(%7 : tensor<8x8x8x8x4x4xbf16>) {
^bb0(%in: f32, %out: bf16):
%11 = arith.truncf %in : f32 to bf16
linalg.yield %11 : bf16
} -> tensor<8x8x8x8x4x4xbf16>
%unpack = tensor.unpack %10 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %6 : tensor<8x8x8x8x4x4xbf16> -> tensor<8x8x32x32xbf16>
%unpack_7 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %extracted_slice_5 : tensor<8x8x32x32xbf16> -> tensor<256x256xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_7 into %arg2[%arg0, %arg1] [256, 256] [1, 1] : tensor<256x256xbf16> into tensor<512x4096xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
memref.dealloc %alloc_3 : memref<8x8x32x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<8x8x64x32xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<8x8x32x32xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x8x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x8x8x8x4xbf16, 2 : i32>
return %3 : tensor<512x4096xbf16>
}
}