diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 3c514c57ad4b..995f86ef51ad 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -135,7 +135,6 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
   GPUMMAHeuristicSeeds seeds;
   assert(problem.aType == problem.bType &&
          "expected the same aType and bType.");
-  int64_t inBitWidth = problem.aType.getIntOrFloatBitWidth();
 
   // Note that the following heuristic seeds are just placeholder values.
   // We need to clean it up and make it adjusting to different targets.
@@ -148,14 +147,11 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
     // and a larger bestKTileCountPerSubgroup.
     seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
              /*bestMNTileCountPerSubgroup=*/4,
-             /*bestKTileCountPerSubgroup=*/8,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth};
+             /*bestKTileCountPerSubgroup=*/8};
   } else {
     seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
-             /*bestMNTileCountPerSubgroup=*/16,
-             /*bestKTileCountPerSubgroup=*/4,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / 2 /
-                 inBitWidth};
+             /*bestMNTileCountPerSubgroup=*/8,
+             /*bestKTileCountPerSubgroup=*/4};
   }
 
   int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
@@ -163,7 +159,7 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
   // First try to find a schedule with an exactly matching intrinsic.
   std::optional<GPUMMASchedule> schedule = deduceMMASchedule(
       problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
-      transposedLhs, transposedRhs, /*canUpcastAcc=*/false,
+      /*transposedLhs=*/false, /*transposedRhs=*/false, /*canUpcastAcc=*/false,
       /*mustBeAligned*/ mustBeAligned, doCPromotion);
   if (!schedule) {
     // Then try again by allowing upcasting accumulator.
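Reviewer note, not part of the patch: a minimal standalone sketch of the arithmetic behind the dropped seed. It assumes kCacheLineSizeBits is 128 * 8 and a 16x16x16 MFMA intrinsic on f16 operands, and treats the two K seeds as a simple min, which matches the before/after test deltas below; none of this code is IREE's actual implementation.

// Standalone sketch, not IREE code; names mirror the seeds in the hunk above.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t kCacheLineSizeBits = 128 * 8; // assumed cache line: 128 bytes
  const int64_t inBitWidth = 16;              // f16 operands
  const int64_t intrinsicK = 16;              // assumed 16x16x16 MFMA

  const int64_t bestKTileCountPerSubgroup = 8;
  // Old cap: roughly one cache line of K elements per subgroup.
  const int64_t bestKElementCountPerSubgroup =
      kCacheLineSizeBits / inBitWidth;        // 1024 / 16 = 64 elements
  const int64_t oldKTiles = std::min(
      bestKTileCountPerSubgroup, bestKElementCountPerSubgroup / intrinsicK);
  const int64_t newKTiles = bestKTileCountPerSubgroup; // cap removed

  std::cout << "old K tiles: " << oldKTiles << "\n"; // 4
  std::cout << "new K tiles: " << newKTiles << "\n"; // 8
}

Under these assumptions the old cap works out to 64 / 16 = 4 K-tiles, which is why the reduction expectations in the test diff below move from 4 to 8 once the cap is removed.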
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index 47ccb627ac42..503d49d0095c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -37,7 +37,7 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor
 
 // CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout
 // CHECK-SAME: promote_operands = [0, 1]
-// CHECK-SAME: reduction = [0, 0, 0, 0, 4]
+// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
 // CHECK-SAME: subgroup = [1, 1, 4, 1, 0]
 // CHECK-SAME: workgroup = [1, 1, 64, 64, 0]
@@ -70,7 +70,7 @@ func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4
 
 // CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout
 // CHECK-SAME: promote_operands = [0, 1]
-// CHECK-SAME: reduction = [0, 0, 0, 0, 4, 1]
+// CHECK-SAME: reduction = [0, 0, 0, 0, 8, 1]
 // CHECK-SAME: subgroup = [2, 2, 1, 1, 0, 0]
 // CHECK-SAME: workgroup = [2, 2, 32, 32, 0, 0]
@@ -130,9 +130,9 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<
 
 // CHECK: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout
 // CHECK-SAME: promote_operands = [0, 1]
-// CHECK-SAME: reduction = [0, 0, 2]
-// CHECK-SAME: subgroup = [4, 4, 0]
-// CHECK-SAME: workgroup = [128, 128, 0]
+// CHECK-SAME: reduction = [0, 0, 4]
+// CHECK-SAME: subgroup = [2, 4, 0]
+// CHECK-SAME: workgroup = [64, 128, 0]
 
 // -----
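Reviewer note, again outside the patch proper: the new @mfma_matmul_1024x1024x1024 expectations are consistent with halving bestMNTileCountPerSubgroup from 16 to 8. A small sketch under stated assumptions (a 16x16x16 intrinsic and a hypothetical 2x2 arrangement of the 4 subgroups per workgroup; the 2 (M) x 4 (N) split is read off the CHECK lines, not derived from IREE's heuristic):

// Standalone sketch, not IREE code.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t intrinsicM = 16, intrinsicN = 16; // assumed 16x16x16 MFMA
  const int64_t subgroupsM = 2, subgroupsN = 2;   // 4 subgroups per workgroup

  // bestMNTileCountPerSubgroup = 8, split here as 2 (M) x 4 (N).
  const int64_t subgroupTilesM = 2, subgroupTilesN = 4;

  const int64_t workgroupM = subgroupTilesM * intrinsicM * subgroupsM; // 64
  const int64_t workgroupN = subgroupTilesN * intrinsicN * subgroupsN; // 128

  std::cout << "subgroup = [" << subgroupTilesM << ", " << subgroupTilesN
            << ", 0]\n"; // matches CHECK: subgroup = [2, 4, 0]
  std::cout << "workgroup = [" << workgroupM << ", " << workgroupN
            << ", 0]\n"; // matches CHECK: workgroup = [64, 128, 0]
}

The same arithmetic with the old seed of 16 (4 x 4 tiles per subgroup) reproduces the previous subgroup = [4, 4, 0] and workgroup = [128, 128, 0] expectations.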