
Commit e2bbc9b

Match TileAndFuse Matmul Heuristics to VectorDistribute and raise the limit of TileLargeTensorsPass
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
1 parent b5ed37c commit e2bbc9b

File tree: 6 files changed (+35 / -30 lines)

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 1 addition & 1 deletion
@@ -654,7 +654,7 @@ def TileLargeTensorsPass :
   ];
   let options = [
     Option<"maxVectorSize", "max-vector-size", "int64_t",
-      /*default=*/"64",
+      /*default=*/"256",
       "Maximum static size to tile to (i.e. all remaining ops will be smaller)">,
   ];
 }
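
For context, maxVectorSize caps how many elements any single op may still cover once TileLargeTensorsPass has run. The snippet below is a hypothetical sketch of that size arithmetic only, not the pass implementation; the helper name and the "tile outer dims to 1, clamp the innermost dim" strategy are assumptions made for illustration.

// Hypothetical sketch (not the actual pass): keep each remaining op at or
// below `maxVectorSize` elements by tiling outer dimensions to 1 and
// clamping the innermost dimension.
#include <algorithm>
#include <cstdint>
#include <vector>

static std::vector<int64_t> pickTileSizes(const std::vector<int64_t> &shape,
                                          int64_t maxVectorSize) {
  std::vector<int64_t> tileSizes(shape.size(), 1);
  if (!shape.empty())
    tileSizes.back() = std::min(shape.back(), maxVectorSize);
  return tileSizes;
}

// With the new default of 256, a tensor<64x512xf32> op gets tile sizes
// [1, 256], matching the updated test below (`scf.for ... step %c256`,
// tensor<1x256xf32>); the old default of 64 produced [1, 64] slices.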

compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir

Lines changed: 15 additions & 15 deletions
@@ -3,22 +3,22 @@
 // RUN: FileCheck %s
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
-func.func @simple_generic(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>, %5: tensor<64x256xf32>) -> tensor<64x256xf32> {
+func.func @simple_generic(%3: tensor<64x512xf32>, %4: tensor<64x512xf32>, %5: tensor<64x512xf32>) -> tensor<64x512xf32> {
   %6 = linalg.generic {
     indexing_maps = [#map, #map, #map],
     iterator_types = ["parallel", "parallel"]
-  } ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%5 : tensor<64x256xf32>) {
+  } ins(%3, %4 : tensor<64x512xf32>, tensor<64x512xf32>) outs(%5 : tensor<64x512xf32>) {
   ^bb0(%in: f32, %in_0: f32, %out: f32):
     %7 = arith.addf %in, %in_0 : f32
     linalg.yield %7 : f32
-  } -> tensor<64x256xf32>
-  return %6 : tensor<64x256xf32>
+  } -> tensor<64x512xf32>
+  return %6 : tensor<64x512xf32>
 }
 
 // CHECK-LABEL: func.func @simple_generic
 // CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1
-// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c64
-// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x64xf32>)
+// CHECK: scf.for %{{.*}} = %c0 to %c512 step %c256
+// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x256xf32>)
 
 // -----
 
@@ -65,21 +65,21 @@ func.func @in_nested_region(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: te
 
 // -----
 
-func.func @multiple_use_tilable_op(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>) -> (tensor<64x256xf32>, tensor<256x64xf32>) {
-  %add_empty = tensor.empty() : tensor<64x256xf32>
+func.func @multiple_use_tilable_op(%3: tensor<64x512xf32>, %4: tensor<64x512xf32>) -> (tensor<64x512xf32>, tensor<512x64xf32>) {
+  %add_empty = tensor.empty() : tensor<64x512xf32>
   %6 = linalg.add
-    ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>)
-    outs(%add_empty : tensor<64x256xf32>) -> tensor<64x256xf32>
-  %transpose_empty = tensor.empty() : tensor<256x64xf32>
+    ins(%3, %4 : tensor<64x512xf32>, tensor<64x512xf32>)
+    outs(%add_empty : tensor<64x512xf32>) -> tensor<64x512xf32>
+  %transpose_empty = tensor.empty() : tensor<512x64xf32>
   %7 = linalg.transpose
-    ins(%6 : tensor<64x256xf32>)
-    outs(%transpose_empty : tensor<256x64xf32>) permutation = [1, 0]
-  return %6, %7 : tensor<64x256xf32>, tensor<256x64xf32>
+    ins(%6 : tensor<64x512xf32>)
+    outs(%transpose_empty : tensor<512x64xf32>) permutation = [1, 0]
+  return %6, %7 : tensor<64x512xf32>, tensor<512x64xf32>
 }
 
 // CHECK-LABEL: func.func @multiple_use_tilable_op
 // CHECK: %[[ADD_TILING:.+]] = scf.for
-// CHECK: linalg.add {{.*}} -> tensor<1x64xf32>
+// CHECK: linalg.add {{.*}} -> tensor<1x256xf32>
 // CHECK: %[[T_TILING:.+]] = scf.for
 // CHECK: %[[FUSED_ADD:.+]] = linalg.add {{.*}} -> tensor<64x1xf32>
 // CHECK: linalg.transpose ins(%[[FUSED_ADD]]

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 9 additions & 6 deletions
@@ -149,25 +149,28 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
     seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
              /*bestMNTileCountPerSubgroup=*/4,
              /*bestKTileCountPerSubgroup=*/8,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth};
+             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits * 4 /
+                 inBitWidth};
   } else {
     seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
              /*bestMNTileCountPerSubgroup=*/16,
              /*bestKTileCountPerSubgroup=*/4,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / 2 /
+             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits * 2 /
                  inBitWidth};
   }
 
-  int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
+  // We target slightly below the full available shared Memory to leave room for
+  // `GPUReduceBankConflictsPass` that will pad shared memory without keeping
+  // track of usage. We can drop this after solving
+  // https://github.com/iree-org/iree/issues/19675
+  int64_t maxSharedMemoryBytes =
+      target.getWgp().getMaxWorkgroupMemoryBytes() - 64 * inBitWidth;
 
   // First try to find a schedule with an exactly matching intrinsic.
   std::optional<GPUMMASchedule> schedule = deduceMMASchedule(
       problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
       transposedLhs, transposedRhs, /*canUpcastAcc=*/false,
       /*mustBeAligned*/ mustBeAligned, doCPromotion);
-  // TODO (nirvedhmeshram) : Add support for upcasting accumulator schedule.
-  // Currently we dont have this for TileAndFuse path, see
-  // https://github.com/iree-org/iree/issues/19532
   return schedule;
 }
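
To make the magnitude of these changes concrete, here is a small worked example. The diff only shows the formulas; the 128-byte cache line behind kCacheLineSizeBits, the f16 operand width, and the 64 KiB workgroup memory limit are assumptions used purely for illustration.

// Worked example of the new seed and shared-memory arithmetic under the
// assumptions stated above (128-byte cache line, f16 inputs, 64 KiB limit).
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t kCacheLineSizeBits = 128 * 8; // assumed 128-byte cache line
  const int64_t inBitWidth = 16;              // assumed f16 operands

  // Small-M*N branch: previously 1024 / 16 = 64 K elements per subgroup,
  // now four cache lines worth: 1024 * 4 / 16 = 256.
  int64_t bestKElementsSmallMN = kCacheLineSizeBits * 4 / inBitWidth;

  // Default branch: previously 1024 / 2 / 16 = 32, now 1024 * 2 / 16 = 128.
  int64_t bestKElementsDefault = kCacheLineSizeBits * 2 / inBitWidth;

  // Shared-memory budget with an assumed 64 KiB limit: reserve
  // 64 * inBitWidth = 1024 bytes for the padding that
  // GPUReduceBankConflictsPass may add later.
  int64_t maxSharedMemoryBytes = 64 * 1024 - 64 * inBitWidth; // 64512

  std::printf("%ld %ld %ld\n", static_cast<long>(bestKElementsSmallMN),
              static_cast<long>(bestKElementsDefault),
              static_cast<long>(maxSharedMemoryBytes));
  return 0;
}

Under those assumptions the default K seed grows from 32 to 128 elements per subgroup and roughly 1 KiB of shared memory is held back for bank-conflict padding, which is what drives the tile-size changes in the tests further below.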

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp

Lines changed: 3 additions & 0 deletions
@@ -620,6 +620,9 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
         /*canUpcastAcc=*/true);
   }
 
+  LDBG("transposedLhs: " << transposedLhs);
+  LDBG("transposedRhs: " << transposedRhs);
+
   // Only batch_matmul is supported in the LLVMGPUPadAndVectorDistribute
   // pipeline.
   // TODO(hanchung): Support cases that there are fused producers.

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir

Lines changed: 5 additions & 5 deletions
@@ -39,7 +39,7 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor
 // CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 // CHECK-SAME: promote_operands = [0, 1]
-// CHECK-SAME: reduction = [0, 0, 0, 0, 4]
+// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
 // CHECK-SAME: subgroup = [1, 1, 4, 1, 0]
 // CHECK-SAME: workgroup = [1, 1, 64, 64, 0]
 
@@ -74,7 +74,7 @@ func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4
 // CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 // CHECK-SAME: promote_operands = [0, 1]
-// CHECK-SAME: reduction = [0, 0, 0, 0, 4, 1]
+// CHECK-SAME: reduction = [0, 0, 0, 0, 8, 1]
 // CHECK-SAME: subgroup = [2, 2, 1, 1, 0, 0]
 // CHECK-SAME: workgroup = [2, 2, 32, 32, 0, 0]
 
@@ -136,9 +136,9 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<
 // CHECK: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 // CHECK-SAME: promote_operands = [0, 1]
-// CHECK-SAME: reduction = [0, 0, 2]
-// CHECK-SAME: subgroup = [4, 4, 0]
-// CHECK-SAME: workgroup = [128, 128, 0]
+// CHECK-SAME: reduction = [0, 0, 4]
+// CHECK-SAME: subgroup = [2, 4, 0]
+// CHECK-SAME: workgroup = [64, 128, 0]
 
 // -----
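
One hedged way to read the new reduction tiles, reusing the assumptions from the example above (128-byte cache line, f16 inputs): this is an interpretation of the numbers, not a description of how deduceMMASchedule actually derives them.

// Interpretation only: 8 reduction tiles of the 16-deep MFMA K dimension give
// 128 f16 elements along K per iteration, the same value as the new default
// seed kCacheLineSizeBits * 2 / inBitWidth = 1024 * 2 / 16 = 128.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t intrinsicK = 16;    // K of MFMA_F32_16x16x16_F16
  const int64_t reductionTiles = 8; // new reduction = [0, 0, 0, 0, 8]
  std::printf("%ld\n", static_cast<long>(reductionTiles * intrinsicK)); // 128
  return 0;
}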

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir

Lines changed: 2 additions & 3 deletions
@@ -1013,9 +1013,8 @@ hal.executable public @main {
 // CHECK: scf.yield %[[REDUCE]]
 
 // CHECK: scf.for %{{.*}} = %{{.*}} to %c16 step %c1
-// CHECK: scf.for
-// CHECK-COUNT-4: arith.addf {{.*}} : vector<9xf32>
-// CHECK: vector.transfer_write {{.*}} vector<9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type<storage_buffer>>
+// CHECK-COUNT-4: arith.addf {{.*}} : vector<9x9xf32>
+// CHECK: vector.transfer_write {{.*}} vector<9x9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type<storage_buffer>>
 
 // -----
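
The disappearance of the inner scf.for here lines up with the raised TileLargeTensorsPass cap; the sketch below only restates that arithmetic, assuming the pass re-tiles ops whose static element count exceeds max-vector-size.

// The written tile is 9x9 = 81 elements: above the old cap of 64 (so it was
// previously re-tiled into vector<9xi8> pieces), below the new cap of 256
// (so the full vector<9x9xi8> is now kept).
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t tileElements = 9 * 9;   // 81
  const int64_t oldMaxVectorSize = 64;  // old default
  const int64_t newMaxVectorSize = 256; // new default
  std::printf("%d %d\n",
              tileElements > oldMaxVectorSize,  // 1: needed further tiling
              tileElements > newMaxVectorSize); // 0: kept as one 9x9 tile
  return 0;
}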
