@@ -547,3 +547,106 @@ module {
547
547
return %6 : tensor <512 x4096 xf32 >
548
548
}
549
549
}

// -----

// CHECK-LABEL: @matmul_elemwise_multiple_fusion_iterations
// CHECK: %[[EXTRACT_SLICE_0:.+]] = tensor.extract_slice %{{.+}}[%{{.+}}, %{{.+}}] [256, 256] [1, 1]
// CHECK: %[[FORALL_0:.+]]:3 = scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) = (0, 0) to (8, 8) step (4, 4) shared_outs(%[[MATMUL_OUT:.+]] = %{{.*}}, %[[ELEMWISE_OUT:.+]] = %{{.*}}, %[[UNPACK_OUT:.+]] = %{{.*}})
// CHECK: %[[FORALL_1:.+]]:3 = scf.forall (%[[ARG2:.+]], %[[ARG3:.+]]) in (4, 4) shared_outs(%[[MATMUL_LOCAL_OUT:.+]] = %{{.*}}, %[[ELEMWISE_LOCAL_OUT:.+]] = %{{.*}}, %[[UNPACK_LOCAL_OUT:.+]] = %{{.*}})
// CHECK-SAME: {
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: %[[ELEMWISE:.+]] = linalg.generic
// CHECK-SAME: ins(%[[MATMUL]] : tensor<1x1x8x8x4x4xf32>)
// CHECK: arith.truncf
// CHECK: %[[EXTRACT_SLICE_1:.+]] = tensor.extract_slice %[[UNPACK_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ELEMWISE]] outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %[[EXTRACT_SLICE_1]]
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice %[[MATMUL]] into %[[MATMUL_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: tensor.parallel_insert_slice %[[ELEMWISE]] into %[[ELEMWISE_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: tensor.parallel_insert_slice %[[UNPACK]] into %[[UNPACK_LOCAL_OUT]][%[[ARG2]], %[[ARG3]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
// CHECK: }
// CHECK: }
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice %[[FORALL_1]]#0 into %[[MATMUL_OUT]][%[[ARG0]], %[[ARG1]], 0, 0, 0, 0] [4, 4, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: tensor.parallel_insert_slice %[[FORALL_1]]#1 into %[[ELEMWISE_OUT]][%[[ARG0]], %[[ARG1]], 0, 0, 0, 0] [4, 4, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: tensor.parallel_insert_slice %[[FORALL_1]]#2 into %[[UNPACK_OUT]][%[[ARG0]], %[[ARG1]], 0, 0] [4, 4, 32, 32] [1, 1, 1, 1]
// CHECK: }
// CHECK: tensor.unpack %[[FORALL_0]]#2 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[EXTRACT_SLICE_0]]
// Input IR for the @matmul_elemwise_multiple_fusion_iterations test case.
// #map, #map1 and #map2 are the indexing maps of the 9-loop packed matmul
// (two inputs, one output); #map3 is the 6-D identity map used by the
// elementwise truncation.
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d2, d4, d5, d8, d7)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
module {
  // NOTE: the diagnostic annotation below uses a relative offset (@+1), so it
  // must remain on the line directly above the func.func it refers to. Do not
  // insert anything between the two lines.
  // expected-error @+1 {{Maximum number of iterations reached, consumer fusion is likely stuck in an infinite loop.}}
  func.func @matmul_elemwise_multiple_fusion_iterations() -> tensor<512x4096xbf16> {
    // Buffers in memory spaces 2 and 1 that back the pack/unpack destinations.
    %alloc = memref.alloc() : memref<1x1x8x8x8x4xbf16, 2 : i32>
    %alloc_0 = memref.alloc() : memref<1x1x8x8x4x8xbf16, 2 : i32>
    %alloc_1 = memref.alloc() : memref<8x8x32x32xbf16, 1 : i32>
    %alloc_2 = memref.alloc() : memref<8x8x64x32xbf16, 1 : i32>
    %alloc_3 = memref.alloc() : memref<8x8x32x64xbf16, 1 : i32>
    %0 = tensor.empty() : tensor<512x512xbf16>
    %1 = tensor.empty() : tensor<512x4096xbf16>
    %2 = tensor.empty() : tensor<512x4096xbf16>
    // Outermost loop: walks the 512x4096 result in 256x256 tiles.
    %3 = scf.forall (%arg0, %arg1) = (0, 0) to (512, 4096) step (256, 256) shared_outs(%arg2 = %2) -> (tensor<512x4096xbf16>) {
      %extracted_slice = tensor.extract_slice %0[%arg0, 0] [256, 512] [1, 1] : tensor<512x512xbf16> to tensor<256x512xbf16>
      %extracted_slice_4 = tensor.extract_slice %1[0, %arg1] [512, 256] [1, 1] : tensor<512x4096xbf16> to tensor<512x256xbf16>
      %extracted_slice_5 = tensor.extract_slice %arg2[%arg0, %arg1] [256, 256] [1, 1] : tensor<512x4096xbf16> to tensor<256x256xbf16>
      %4 = bufferization.to_tensor %alloc_3 restrict writable : memref<8x8x32x64xbf16, 1 : i32> to tensor<8x8x32x64xbf16>
      %pack = tensor.pack %extracted_slice outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64] into %4 : tensor<256x512xbf16> -> tensor<8x8x32x64xbf16>
      %5 = bufferization.to_tensor %alloc_2 restrict writable : memref<8x8x64x32xbf16, 1 : i32> to tensor<8x8x64x32xbf16>
      %pack_6 = tensor.pack %extracted_slice_4 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %5 : tensor<512x256xbf16> -> tensor<8x8x64x32xbf16>
      %6 = bufferization.to_tensor %alloc_1 restrict writable : memref<8x8x32x32xbf16, 1 : i32> to tensor<8x8x32x32xbf16>
      %7 = tensor.empty() : tensor<8x8x8x8x4x4xbf16>
      %8 = tensor.empty() : tensor<8x8x8x8x4x4xf32>
      // Producer loop nest. Its f32 result %9 is consumed after the loop by the
      // truncf linalg.generic and then by the two tensor.unpack ops.
      %9 = scf.forall (%arg3, %arg4) = (0, 0) to (8, 8) step (4, 4) shared_outs(%arg5 = %8) -> (tensor<8x8x8x8x4x4xf32>) {
        %extracted_slice_8 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [4, 8, 32, 64] [1, 1, 1, 1] : tensor<8x8x32x64xbf16> to tensor<4x8x32x64xbf16>
        %extracted_slice_9 = tensor.extract_slice %pack_6[%arg4, 0, 0, 0] [4, 8, 64, 32] [1, 1, 1, 1] : tensor<8x8x64x32xbf16> to tensor<4x8x64x32xbf16>
        %extracted_slice_10 = tensor.extract_slice %extracted_slice_8[0, 7, 0, 0] [4, 1, 32, 64] [1, 1, 1, 1] : tensor<4x8x32x64xbf16> to tensor<4x1x32x64xbf16>
        %extracted_slice_11 = tensor.extract_slice %extracted_slice_9[0, 7, 0, 0] [4, 1, 64, 32] [1, 1, 1, 1] : tensor<4x8x64x32xbf16> to tensor<4x1x64x32xbf16>
        %11 = tensor.empty() : tensor<4x4x8x8x4x4xf32>
        %12 = scf.forall (%arg6, %arg7) in (4, 4) shared_outs(%arg8 = %11) -> (tensor<4x4x8x8x4x4xf32>) {
          %extracted_slice_12 = tensor.extract_slice %extracted_slice_10[%arg6, 0, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : tensor<4x1x32x64xbf16> to tensor<1x1x32x64xbf16>
          %13 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x8x4x8xbf16, 2 : i32> to tensor<1x1x8x8x4x8xbf16>
          %pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %13 : tensor<1x1x32x64xbf16> -> tensor<1x1x8x8x4x8xbf16>
          %extracted_slice_14 = tensor.extract_slice %extracted_slice_11[%arg7, 0, 0, 0] [1, 1, 64, 32] [1, 1, 1, 1] : tensor<4x1x64x32xbf16> to tensor<1x1x64x32xbf16>
          %14 = bufferization.to_tensor %alloc restrict writable : memref<1x1x8x8x8x4xbf16, 2 : i32> to tensor<1x1x8x8x8x4xbf16>
          %pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %14 : tensor<1x1x64x32xbf16> -> tensor<1x1x8x8x8x4xbf16>
          %extracted_slice_16 = tensor.extract_slice %arg8[%arg6, %arg7, 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<4x4x8x8x4x4xf32> to tensor<1x1x8x8x4x4xf32>
          // Packed matmul: bf16 operands are extended to f32 and accumulated in f32.
          %15 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_13, %pack_15 : tensor<1x1x8x8x4x8xbf16>, tensor<1x1x8x8x8x4xbf16>) outs(%extracted_slice_16 : tensor<1x1x8x8x4x4xf32>) {
          ^bb0(%in: bf16, %in_17: bf16, %out: f32):
            %16 = arith.extf %in : bf16 to f32
            %17 = arith.extf %in_17 : bf16 to f32
            %18 = arith.mulf %16, %17 : f32
            %19 = arith.addf %out, %18 : f32
            linalg.yield %19 : f32
          } -> tensor<1x1x8x8x4x4xf32>
          scf.forall.in_parallel {
            tensor.parallel_insert_slice %15 into %arg8[%arg6, %arg7, 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x8x4x4xf32> into tensor<4x4x8x8x4x4xf32>
          }
        } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %12 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [4, 4, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<4x4x8x8x4x4xf32> into tensor<8x8x8x8x4x4xf32>
        }
      } {mapping = [#gpu.block<y>, #gpu.block<x>]}
      // Consumer chain of %9: elementwise f32 -> bf16 truncation, followed by
      // two unpacks back to the 256x256 output tile.
      %10 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<8x8x8x8x4x4xf32>) outs(%7 : tensor<8x8x8x8x4x4xbf16>) {
      ^bb0(%in: f32, %out: bf16):
        %11 = arith.truncf %in : f32 to bf16
        linalg.yield %11 : bf16
      } -> tensor<8x8x8x8x4x4xbf16>
      %unpack = tensor.unpack %10 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %6 : tensor<8x8x8x8x4x4xbf16> -> tensor<8x8x32x32xbf16>
      %unpack_7 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %extracted_slice_5 : tensor<8x8x32x32xbf16> -> tensor<256x256xbf16>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %unpack_7 into %arg2[%arg0, %arg1] [256, 256] [1, 1] : tensor<256x256xbf16> into tensor<512x4096xbf16>
      }
    } {mapping = [#gpu.block<y>, #gpu.block<x>]}
    memref.dealloc %alloc_3 : memref<8x8x32x64xbf16, 1 : i32>
    memref.dealloc %alloc_2 : memref<8x8x64x32xbf16, 1 : i32>
    memref.dealloc %alloc_1 : memref<8x8x32x32xbf16, 1 : i32>
    memref.dealloc %alloc_0 : memref<1x1x8x8x4x8xbf16, 2 : i32>
    memref.dealloc %alloc : memref<1x1x8x8x8x4xbf16, 2 : i32>
    return %3 : tensor<512x4096xbf16>
  }
}
0 commit comments