Skip to content

Commit 282e5a7

Browse files
Go back to old Heuristic but ask for more bestKElementCountPerSubgroup
1 parent b3f2c11 commit 282e5a7

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
135135
GPUMMAHeuristicSeeds seeds;
136136
assert(problem.aType == problem.bType &&
137137
"expected the same aType and bType.");
138+
int64_t inBitWidth = problem.aType.getIntOrFloatBitWidth();
138139

139140
// Note that the following heuristic seeds are just placeholder values.
140141
// We need to clean it up and make it adjusting to different targets.
@@ -147,14 +148,23 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
147148
// and a larger bestKTileCountPerSubgroup.
148149
seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
149150
/*bestMNTileCountPerSubgroup=*/4,
150-
/*bestKTileCountPerSubgroup=*/8};
151+
/*bestKTileCountPerSubgroup=*/8,
152+
/*bestKElementCountPerSubgroup*/ kCacheLineSizeBits * 4 /
153+
inBitWidth};
151154
} else {
152155
seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
153-
/*bestMNTileCountPerSubgroup=*/8,
154-
/*bestKTileCountPerSubgroup=*/4};
155-
}
156-
157-
int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
156+
/*bestMNTileCountPerSubgroup=*/16,
157+
/*bestKTileCountPerSubgroup=*/4,
158+
/*bestKElementCountPerSubgroup*/ kCacheLineSizeBits * 2 /
159+
inBitWidth};
160+
}
161+
162+
// We target slightly below the full available shared Memory to leave room for
163+
// `GPUReduceBankConflictsPass` that will pad shared memory without keeping
164+
// track of usage. We can drop this after fixing
165+
// https://github.com/iree-org/iree/issues/19675
166+
int64_t maxSharedMemoryBytes =
167+
target.getWgp().getMaxWorkgroupMemoryBytes() - 64 * inBitWidth;
158168

159169
// First try to find a schedule with an exactly matching intrinsic.
160170
std::optional<GPUMMASchedule> schedule = deduceMMASchedule(

0 commit comments

Comments
 (0)