From fe2b43e83348c5b4e83d42e5b8d7fe9bc67f756a Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Fri, 10 Jan 2025 16:35:18 -0600 Subject: [PATCH] Go back to old Heuristic but ask for more bestKElementCountPerSubgroup Signed-off-by: Nirvedh Meshram --- .../Dialect/GPU/TargetUtils/ConfigUtils.cpp | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index 995f86ef51ad..e42ba2ddaa7d 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -135,6 +135,7 @@ static std::optional getMmaScheduleFromProblemAndTarget( GPUMMAHeuristicSeeds seeds; assert(problem.aType == problem.bType && "expected the same aType and bType."); + int64_t inBitWidth = problem.aType.getIntOrFloatBitWidth(); // Note that the following heuristic seeds are just placeholder values. // We need to clean it up and make it adjusting to different targets. @@ -147,14 +148,23 @@ static std::optional getMmaScheduleFromProblemAndTarget( // and a larger bestKTileCountPerSubgroup. seeds = {/*bestSubgroupCountPerWorkgroup=*/4, /*bestMNTileCountPerSubgroup=*/4, - /*bestKTileCountPerSubgroup=*/8}; + /*bestKTileCountPerSubgroup=*/8, + /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits * 4 / + inBitWidth}; } else { seeds = {/*bestSubgroupCountPerWorkgroup=*/4, - /*bestMNTileCountPerSubgroup=*/8, - /*bestKTileCountPerSubgroup=*/4}; - } - - int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes(); + /*bestMNTileCountPerSubgroup=*/16, + /*bestKTileCountPerSubgroup=*/4, + /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits * 2 / + inBitWidth}; + } + + // We target slightly below the full available shared Memory to leave room for + // `GPUReduceBankConflictsPass` that will pad shared memory without keeping + // track of usage. We can drop this after fixing + // https://github.com/iree-org/iree/issues/19675 + int64_t maxSharedMemoryBytes = + target.getWgp().getMaxWorkgroupMemoryBytes() - 64 * inBitWidth; // First try to find a schedule with an exactly matching intrinsic. std::optional schedule = deduceMMASchedule(