From 68d652792db8589cdd7f379c03501c2908a63f8b Mon Sep 17 00:00:00 2001
From: Jorn Tuyls
Date: Tue, 14 Jan 2025 19:27:46 +0100
Subject: [PATCH] [FuseConsumerIntoLoop] Extend to multiple results and
 generalize depth fusion (#1027)

This PR extends consumer fusion to work for matmul + elementwise when an
additional level of tiling is present (see the added lit test).
---
 build_tools/ci/cpu_comparison/run.py          |  49 +++++----
 .../Transforms/AMDAIEFuseConsumerIntoLoop.cpp |  22 +++-
 .../test/fuse_consumer_into_loop.mlir         | 103 ++++++++++++++++++
 3 files changed, 148 insertions(+), 26 deletions(-)

diff --git a/build_tools/ci/cpu_comparison/run.py b/build_tools/ci/cpu_comparison/run.py
index 2117554de..ef5546745 100755
--- a/build_tools/ci/cpu_comparison/run.py
+++ b/build_tools/ci/cpu_comparison/run.py
@@ -756,6 +756,7 @@ def __init__(
         rhs,
         expected_out,
         run_on_target=["npu1_4col"],
+        tile_pipeline="pack-peel",
     ):
         super().__init__(
             run_on_target=run_on_target,
@@ -765,7 +766,7 @@
             K=K,
             input_type=input_type,
             acc_type=acc_type,
-            tile_pipeline="pack-peel",
+            tile_pipeline=tile_pipeline,
             n_repeats=1,
         )
         self.labels.append("MatmulTruncf")
@@ -776,6 +777,8 @@
         assert expected_out.shape == (M, M)
 
         self.name = f"matmul_truncf_{M}_{K}_{input_type}_{acc_type}"
+        if tile_pipeline == "pack-peel-4-level-tiling":
+            self.name += "_4_level_tiling"
         self.lhs = lhs
         self.rhs = rhs
         self.expected_out = expected_out
@@ -1594,29 +1597,31 @@ def __init__(self):
         self.tests = []
 
         # Matmul with truncf test(s):
-        self.register(
-            MatmulTruncf(
-                16,
-                16,
-                "bf16",
-                "f32",
-                101 * np.ones([16, 16]),
-                3 * np.eye(16),
-                302 * np.ones([16, 16]),
+        for tile_pipeline in ["pack-peel", "pack-peel-4-level-tiling"]:
+            self.register(
+                MatmulTruncf(
+                    16,
+                    16,
+                    "bf16",
+                    "f32",
+                    101 * np.ones([16, 16]),
+                    3 * np.eye(16),
+                    302 * np.ones([16, 16]),
+                    tile_pipeline=tile_pipeline,
+                )
             )
-        )
-
-        self.register(
-            MatmulTruncf(
-                128,
-                256,
-                "bf16",
-                "f32",
-                2 * np.ones([128, 256]),
-                3 * np.ones([256, 128]),
-                1536 * np.ones([128, 128]),
+            self.register(
+                MatmulTruncf(
+                    128,
+                    256,
+                    "bf16",
+                    "f32",
+                    2 * np.ones([128, 256]),
+                    3 * np.ones([256, 128]),
+                    1536 * np.ones([128, 128]),
+                    tile_pipeline=tile_pipeline,
+                )
             )
-        )
 
         # BatchMatmul test(s):
         for input_type, acc_type in zip(["i32", "bf16"], ["i32", "f32"]):
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFuseConsumerIntoLoop.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFuseConsumerIntoLoop.cpp
index 2d7a28980..84b61d9c9 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFuseConsumerIntoLoop.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFuseConsumerIntoLoop.cpp
@@ -98,14 +98,20 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
       // with any depth instead.
      for (unsigned depth = 1; depth <= fuseDepth; depth++) {
        do {
-          Value::user_range users = producerOp->getResult(0).getUsers();
-          if (!llvm::hasSingleElement(users)) {
+          ResultRange results = producerOp->getResults();
+          SmallVector<Operation *> allUsers = std::accumulate(
+              results.begin(), results.end(), SmallVector<Operation *>{},
+              [](SmallVector<Operation *> init, OpResult res) {
+                for (Operation *op : res.getUsers()) init.push_back(op);
+                return init;
+              });
+          if (allUsers.size() != 1) {
             LLVM_DEBUG(llvm::dbgs()
                        << "Expected only one user of the compute op\n");
             break;
           }
-          Operation *candidateSliceOp = *(users.begin());
+          Operation *candidateSliceOp = allUsers[0];
           if (!(isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
                   candidateSliceOp))) {
             producerOp = producerOp->getParentOfType<LoopLikeOpInterface>();
             continue;
@@ -127,7 +133,15 @@
          }
          changed = true;
          fusedConsumer->origConsumerOperand->getOwner()->erase();
-          producerOp = fusedConsumer->tiledAndFusedConsumerOperand->getOwner();
+          Operation *fusedOp =
+              fusedConsumer->tiledAndFusedConsumerOperand->getOwner();
+          if (getAncestorInBlock(fusedOp, computeOp->getBlock()) != nullptr) {
+            // The consumer is fused all the way into the producer's block, so
+            // operate on this op from now on, but with reduced depth.
+            computeOp = fusedOp;
+            fuseDepth -= 1;
+          }
+          producerOp = fusedOp;
          break;
        } while (producerOp && producerOp->getParentOp() != funcOp);
      }
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_consumer_into_loop.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_consumer_into_loop.mlir
index 33470fac3..b8d91fc27 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_consumer_into_loop.mlir
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_consumer_into_loop.mlir
@@ -547,3 +547,106 @@ module {
     return %6 : tensor<512x4096xf32>
   }
 }
+
+// -----
+
+// CHECK-LABEL: @matmul_elemwise_multiple_fusion_iterations
+// CHECK: %[[EXTRACT_SLICE_0:.+]] = tensor.extract_slice %{{.+}}[%{{.+}}, %{{.+}}] [256, 256] [1, 1]
+// CHECK: %[[FORALL_0:.+]]:3 = scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) = (0, 0) to (8, 8) step (4, 4) shared_outs(%[[MATMUL_OUT:.+]] = %{{.*}}, %[[ELEMWISE_OUT:.+]] = %{{.*}}, %[[UNPACK_OUT:.+]] = %{{.*}})
+// CHECK: %[[FORALL_1:.+]]:3 = scf.forall (%[[ARG2:.+]], %[[ARG3:.+]]) in (4, 4) shared_outs(%[[MATMUL_LOCAL_OUT:.+]] = %{{.*}}, %[[ELEMWISE_LOCAL_OUT:.+]] = %{{.*}}, %[[UNPACK_LOCAL_OUT:.+]] = %{{.*}})
+// CHECK-SAME: {
+// CHECK: %[[MATMUL:.+]] = linalg.generic
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: %[[ELEMWISE:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[MATMUL]] : tensor<1x1x8x8x4x4xf32>)
+// CHECK: arith.truncf
+// CHECK: %[[EXTRACT_SLICE_1:.+]] = tensor.extract_slice %[[UNPACK_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ELEMWISE]] outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %[[EXTRACT_SLICE_1]]
+// CHECK: scf.forall.in_parallel {
+// CHECK-DAG: tensor.parallel_insert_slice %[[MATMUL]] into %[[MATMUL_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
+// CHECK-DAG: tensor.parallel_insert_slice %[[ELEMWISE]] into %[[ELEMWISE_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
+// CHECK-DAG: tensor.parallel_insert_slice %[[UNPACK]] into %[[UNPACK_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
+// CHECK: }
+// CHECK: }
+// CHECK: scf.forall.in_parallel {
+// CHECK-DAG: tensor.parallel_insert_slice %[[FORALL_1]]#0 into %[[MATMUL_OUT]][%[[ARG0]], %[[ARG1]], 0, 0, 0, 0] [4, 4, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
+// CHECK-DAG: tensor.parallel_insert_slice %[[FORALL_1]]#1 into %[[ELEMWISE_OUT]][%[[ARG0]], %[[ARG1]], 0, 0, 0, 0] [4, 4, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
+// CHECK-DAG: tensor.parallel_insert_slice %[[FORALL_1]]#2 into %[[UNPACK_OUT]][%[[ARG0]], %[[ARG1]], 0, 0] [4, 4, 32, 32] [1, 1, 1, 1]
+// CHECK: }
+// CHECK: tensor.unpack %[[FORALL_0]]#2 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[EXTRACT_SLICE_0]]
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d2, d4, d5, d8, d7)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+module {
+  // expected-error @+1 {{Maximum number of iterations reached, consumer fusion is likely stuck in an infinite loop.}}
+  func.func @matmul_elemwise_multiple_fusion_iterations() -> tensor<512x4096xbf16> {
+    %alloc = memref.alloc() : memref<1x1x8x8x8x4xbf16, 2 : i32>
+    %alloc_0 = memref.alloc() : memref<1x1x8x8x4x8xbf16, 2 : i32>
+    %alloc_1 = memref.alloc() : memref<8x8x32x32xbf16, 1 : i32>
+    %alloc_2 = memref.alloc() : memref<8x8x64x32xbf16, 1 : i32>
+    %alloc_3 = memref.alloc() : memref<8x8x32x64xbf16, 1 : i32>
+    %0 = tensor.empty() : tensor<512x512xbf16>
+    %1 = tensor.empty() : tensor<512x4096xbf16>
+    %2 = tensor.empty() : tensor<512x4096xbf16>
+    %3 = scf.forall (%arg0, %arg1) = (0, 0) to (512, 4096) step (256, 256) shared_outs(%arg2 = %2) -> (tensor<512x4096xbf16>) {
+      %extracted_slice = tensor.extract_slice %0[%arg0, 0] [256, 512] [1, 1] : tensor<512x512xbf16> to tensor<256x512xbf16>
+      %extracted_slice_4 = tensor.extract_slice %1[0, %arg1] [512, 256] [1, 1] : tensor<512x4096xbf16> to tensor<512x256xbf16>
+      %extracted_slice_5 = tensor.extract_slice %arg2[%arg0, %arg1] [256, 256] [1, 1] : tensor<512x4096xbf16> to tensor<256x256xbf16>
+      %4 = bufferization.to_tensor %alloc_3 restrict writable : memref<8x8x32x64xbf16, 1 : i32> to tensor<8x8x32x64xbf16>
+      %pack = tensor.pack %extracted_slice outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %4 : tensor<256x512xbf16> -> tensor<8x8x32x64xbf16>
+      %5 = bufferization.to_tensor %alloc_2 restrict writable : memref<8x8x64x32xbf16, 1 : i32> to tensor<8x8x64x32xbf16>
+      %pack_6 = tensor.pack %extracted_slice_4 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %5 : tensor<512x256xbf16> -> tensor<8x8x64x32xbf16>
+      %6 = bufferization.to_tensor %alloc_1 restrict writable : memref<8x8x32x32xbf16, 1 : i32> to tensor<8x8x32x32xbf16>
+      %7 = tensor.empty() : tensor<8x8x8x8x4x4xbf16>
+      %8 = tensor.empty() : tensor<8x8x8x8x4x4xf32>
+      %9 = scf.forall (%arg3, %arg4) = (0, 0) to (8, 8) step (4, 4) shared_outs(%arg5 = %8) -> (tensor<8x8x8x8x4x4xf32>) {
+        %extracted_slice_8 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [4, 8, 32, 64] [1, 1, 1, 1] : tensor<8x8x32x64xbf16> to tensor<4x8x32x64xbf16>
+        %extracted_slice_9 = tensor.extract_slice %pack_6[%arg4, 0, 0, 0] [4, 8, 64, 32] [1, 1, 1, 1] : tensor<8x8x64x32xbf16> to tensor<4x8x64x32xbf16>
+        %extracted_slice_10 = tensor.extract_slice %extracted_slice_8[0, 7, 0, 0] [4, 1, 32, 64] [1, 1, 1, 1] : tensor<4x8x32x64xbf16> to tensor<4x1x32x64xbf16>
+        %extracted_slice_11 = tensor.extract_slice %extracted_slice_9[0, 7, 0, 0] [4, 1, 64, 32] [1, 1, 1, 1] : tensor<4x8x64x32xbf16> to tensor<4x1x64x32xbf16>
+        %11 = tensor.empty() : tensor<4x4x8x8x4x4xf32>
+        %12 = scf.forall (%arg6, %arg7) in (4, 4) shared_outs(%arg8 = %11) -> (tensor<4x4x8x8x4x4xf32>) {
+          %extracted_slice_12 = tensor.extract_slice %extracted_slice_10[%arg6, 0, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : tensor<4x1x32x64xbf16> to tensor<1x1x32x64xbf16>
+          %13 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x8x4x8xbf16, 2 : i32> to tensor<1x1x8x8x4x8xbf16>
+          %pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %13 : tensor<1x1x32x64xbf16> -> tensor<1x1x8x8x4x8xbf16>
+          %extracted_slice_14 = tensor.extract_slice %extracted_slice_11[%arg7, 0, 0, 0] [1, 1, 64, 32] [1, 1, 1, 1] : tensor<4x1x64x32xbf16> to tensor<1x1x64x32xbf16>
+          %14 = bufferization.to_tensor %alloc restrict writable : memref<1x1x8x8x8x4xbf16, 2 : i32> to tensor<1x1x8x8x8x4xbf16>
+          %pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %14 : tensor<1x1x64x32xbf16> -> tensor<1x1x8x8x8x4xbf16>
+          %extracted_slice_16 = tensor.extract_slice %arg8[%arg6, %arg7, 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<4x4x8x8x4x4xf32> to tensor<1x1x8x8x4x4xf32>
+          %15 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_13, %pack_15 : tensor<1x1x8x8x4x8xbf16>, tensor<1x1x8x8x8x4xbf16>) outs(%extracted_slice_16 : tensor<1x1x8x8x4x4xf32>) {
+          ^bb0(%in: bf16, %in_17: bf16, %out: f32):
+            %16 = arith.extf %in : bf16 to f32
+            %17 = arith.extf %in_17 : bf16 to f32
+            %18 = arith.mulf %16, %17 : f32
+            %19 = arith.addf %out, %18 : f32
+            linalg.yield %19 : f32
+          } -> tensor<1x1x8x8x4x4xf32>
+          scf.forall.in_parallel {
+            tensor.parallel_insert_slice %15 into %arg8[%arg6, %arg7, 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x8x4x4xf32> into tensor<4x4x8x8x4x4xf32>
+          }
+        } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+        scf.forall.in_parallel {
+          tensor.parallel_insert_slice %12 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [4, 4, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<4x4x8x8x4x4xf32> into tensor<8x8x8x8x4x4xf32>
+        }
+      } {mapping = [#gpu.block<y>, #gpu.block<x>]}
+      %10 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<8x8x8x8x4x4xf32>) outs(%7 : tensor<8x8x8x8x4x4xbf16>) {
+      ^bb0(%in: f32, %out: bf16):
+        %11 = arith.truncf %in : f32 to bf16
+        linalg.yield %11 : bf16
+      } -> tensor<8x8x8x8x4x4xbf16>
+      %unpack = tensor.unpack %10 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %6 : tensor<8x8x8x8x4x4xbf16> -> tensor<8x8x32x32xbf16>
+      %unpack_7 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %extracted_slice_5 : tensor<8x8x32x32xbf16> -> tensor<256x256xbf16>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %unpack_7 into %arg2[%arg0, %arg1] [256, 256] [1, 1] : tensor<256x256xbf16> into tensor<512x4096xbf16>
+      }
+    } {mapping = [#gpu.block<y>, #gpu.block<x>]}
+    memref.dealloc %alloc_3 : memref<8x8x32x64xbf16, 1 : i32>
+    memref.dealloc %alloc_2 : memref<8x8x64x32xbf16, 1 : i32>
+    memref.dealloc %alloc_1 : memref<8x8x32x32xbf16, 1 : i32>
+    memref.dealloc %alloc_0 : memref<1x1x8x8x4x8xbf16, 2 : i32>
+    memref.dealloc %alloc : memref<1x1x8x8x8x4xbf16, 2 : i32>
+    return %3 : tensor<512x4096xbf16>
+  }
+}
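Reviewer note (not part of the patch): the key C++ change is the user-collection fold. Instead of inspecting only result 0, the pass now concatenates the users of every producer result into one vector and attempts fusion only when exactly one user exists across all results combined. Below is a minimal standalone sketch of that accumulate-users idiom; it uses std::string stand-ins for Operation* so it compiles without MLIR, and every name in it is illustrative only.

// accumulate_users_sketch.cpp -- illustrative only, no MLIR dependency.
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

int main() {
  // Stand-in for producerOp->getResults(): each inner vector holds the
  // "users" of one result. Here only the first result has a user.
  std::vector<std::vector<std::string>> resultUsers = {
      {"tensor.parallel_insert_slice"}, {}, {}};
  // The same fold the patch performs with std::accumulate: concatenate the
  // users of all results into a single vector.
  std::vector<std::string> allUsers = std::accumulate(
      resultUsers.begin(), resultUsers.end(), std::vector<std::string>{},
      [](std::vector<std::string> init,
         const std::vector<std::string> &users) {
        for (const std::string &u : users) init.push_back(u);
        return init;
      });
  // Mirrors the pass's guard: fusion is attempted only for a single user
  // across all results; otherwise the pass bails out.
  if (allUsers.size() == 1)
    std::cout << "single user (" << allUsers[0] << "): attempt fusion\n";
  else
    std::cout << "zero or multiple users: bail out\n";
  return 0;
}

Generalizing the old single-result check (getResult(0).getUsers() plus llvm::hasSingleElement) in this way is what lets multi-result compute ops, such as the matmul + elementwise pair in the added lit test, participate in consumer fusion.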