@@ -135,6 +135,7 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
135
135
GPUMMAHeuristicSeeds seeds;
136
136
assert (problem.aType == problem.bType &&
137
137
" expected the same aType and bType." );
138
+ int64_t inBitWidth = problem.aType .getIntOrFloatBitWidth ();
138
139
139
140
// Note that the following heuristic seeds are just placeholder values.
140
141
// We need to clean it up and make it adjusting to different targets.
@@ -147,14 +148,23 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
147
148
// and a larger bestKTileCountPerSubgroup.
148
149
seeds = {/* bestSubgroupCountPerWorkgroup=*/ 4 ,
149
150
/* bestMNTileCountPerSubgroup=*/ 4 ,
150
- /* bestKTileCountPerSubgroup=*/ 8 };
151
+ /* bestKTileCountPerSubgroup=*/ 8 ,
152
+ /* bestKElementCountPerSubgroup*/ kCacheLineSizeBits * 4 /
153
+ inBitWidth};
151
154
} else {
152
155
seeds = {/* bestSubgroupCountPerWorkgroup=*/ 4 ,
153
- /* bestMNTileCountPerSubgroup=*/ 8 ,
154
- /* bestKTileCountPerSubgroup=*/ 4 };
155
- }
156
-
157
- int64_t maxSharedMemoryBytes = target.getWgp ().getMaxWorkgroupMemoryBytes ();
156
+ /* bestMNTileCountPerSubgroup=*/ 16 ,
157
+ /* bestKTileCountPerSubgroup=*/ 4 ,
158
+ /* bestKElementCountPerSubgroup*/ kCacheLineSizeBits * 2 /
159
+ inBitWidth};
160
+ }
161
+
162
+ // We target slightly below the full available shared Memory to leave room for
163
+ // `GPUReduceBankConflictsPass` that will pad shared memory without keeping
164
+ // track of usage. We can drop this after fixing
165
+ // https://github.com/iree-org/iree/issues/19675
166
+ int64_t maxSharedMemoryBytes =
167
+ target.getWgp ().getMaxWorkgroupMemoryBytes () - 64 * inBitWidth;
158
168
159
169
// First try to find a schedule with an exactly matching intrinsic.
160
170
std::optional<GPUMMASchedule> schedule = deduceMMASchedule (
0 commit comments