[GPU] Match TileAndFuse Matmul heuristics to Vector Distribute
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
nirvedhmeshram committed Jan 10, 2025
1 parent 106371d commit b3f2c11
Showing 2 changed files with 9 additions and 13 deletions.
@@ -135,7 +135,6 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
   GPUMMAHeuristicSeeds seeds;
   assert(problem.aType == problem.bType &&
          "expected the same aType and bType.");
-  int64_t inBitWidth = problem.aType.getIntOrFloatBitWidth();
 
   // Note that the following heuristic seeds are just placeholder values.
   // We need to clean it up and make it adjusting to different targets.
@@ -148,22 +147,19 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
     // and a larger bestKTileCountPerSubgroup.
     seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
              /*bestMNTileCountPerSubgroup=*/4,
-             /*bestKTileCountPerSubgroup=*/8,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth};
+             /*bestKTileCountPerSubgroup=*/8};
   } else {
     seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
-             /*bestMNTileCountPerSubgroup=*/16,
-             /*bestKTileCountPerSubgroup=*/4,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / 2 /
-                 inBitWidth};
+             /*bestMNTileCountPerSubgroup=*/8,
+             /*bestKTileCountPerSubgroup=*/4};
   }
 
   int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
 
   // First try to find a schedule with an exactly matching intrinsic.
   std::optional<GPUMMASchedule> schedule = deduceMMASchedule(
       problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
-      transposedLhs, transposedRhs, /*canUpcastAcc=*/false,
+      /*transposedLhs=*/false, /*transposedRhs=*/false, /*canUpcastAcc=*/false,
       /*mustBeAligned*/ mustBeAligned, doCPromotion);
   if (!schedule) {
     // Then try again by allowing upcasting accumulator.
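For context on the numbers in the updated tests below, here is a minimal sketch of how the dropped seed affected the deduced K tiling. This is not IREE's actual deduction code: the helper deducedKTileCount, its clamping logic, and the assumption that kCacheLineSizeBits is 128 bytes * 8 are all illustrative.

// Illustrative sketch only: shows why dropping bestKElementCountPerSubgroup
// doubles the deduced K tile count for an f16 MFMA_F32_16x16x16_F16 matmul.
// deducedKTileCount and the kCacheLineSizeBits value are assumptions, not IREE API.
#include <algorithm>
#include <cstdint>
#include <iostream>

constexpr int64_t kCacheLineSizeBits = 128 * 8; // assumed 128-byte cache line

int64_t deducedKTileCount(bool capByCacheLine,
                          int64_t bestKTileCountPerSubgroup,
                          int64_t inBitWidth, int64_t intrinsicK) {
  if (!capByCacheLine)
    return bestKTileCountPerSubgroup; // new behavior: the seed alone decides
  // Old behavior: also cap the K elements per subgroup at one cache line.
  int64_t bestKElementCount = kCacheLineSizeBits / inBitWidth; // 1024 / 16 = 64
  return std::min(bestKTileCountPerSubgroup, bestKElementCount / intrinsicK);
}

int main() {
  // f16 inputs (16 bits), intrinsic K = 16, seed bestKTileCountPerSubgroup = 8.
  std::cout << deducedKTileCount(true, 8, 16, 16) << "\n";  // old: min(8, 4) = 4
  std::cout << deducedKTileCount(false, 8, 16, 16) << "\n"; // new: 8
}

Under these assumptions, the deduced reduction tile count for the f16 tests goes from 4 to 8, which is exactly the CHECK-line change below.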
@@ -37,7 +37,7 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor
 // CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 // CHECK-SAME: promote_operands = [0, 1]
-// CHECK-SAME: reduction = [0, 0, 0, 0, 4]
+// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
 // CHECK-SAME: subgroup = [1, 1, 4, 1, 0]
 // CHECK-SAME: workgroup = [1, 1, 64, 64, 0]

@@ -70,7 +70,7 @@ func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4
 // CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 // CHECK-SAME: promote_operands = [0, 1]
-// CHECK-SAME: reduction = [0, 0, 0, 0, 4, 1]
+// CHECK-SAME: reduction = [0, 0, 0, 0, 8, 1]
 // CHECK-SAME: subgroup = [2, 2, 1, 1, 0, 0]
 // CHECK-SAME: workgroup = [2, 2, 32, 32, 0, 0]

@@ -130,9 +130,9 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<
 // CHECK: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 // CHECK-SAME: promote_operands = [0, 1]
-// CHECK-SAME: reduction = [0, 0, 2]
-// CHECK-SAME: subgroup = [4, 4, 0]
-// CHECK-SAME: workgroup = [128, 128, 0]
+// CHECK-SAME: reduction = [0, 0, 4]
+// CHECK-SAME: subgroup = [2, 4, 0]
+// CHECK-SAME: workgroup = [64, 128, 0]
 
 // -----

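The mfma_matmul_1024x1024x1024 changes follow the same arithmetic, applied to the else-branch seeds (again assuming the sketch above): lowering bestMNTileCountPerSubgroup from 16 to 8 halves the per-subgroup MN tiles (subgroup = [4, 4, 0] -> [2, 4, 0], since 4 * 4 = 16 and 2 * 4 = 8), which with the 16x16 intrinsic and a 2x2 subgroup layout shrinks the workgroup M tile from 128 to 64; and removing the cache-line cap (previously 1024 / 2 / 16 = 32 f16 elements, i.e. 2 tiles of K = 16) lets reduction rise from 2 to the seed value 4.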
