Skip to content

Commit 1554e02

Browse files
committed
cherry-pick: fix stmatrix indexing
1 parent a470821 commit 1554e02

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

csrc/device_lower/pass/index.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2115,13 +2115,13 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
21152115
switch (swizzle) {
21162116
case MmaInputSmemSwizzle::None:
21172117
out = hardCodedIndexGenerationForStMatrix(
2118-
ldst, for_loops_[0], m_tile, n_tile, m, n);
2118+
ldst, for_loops_[for_loops_.size() - 3], m_tile, n_tile, m, n);
21192119
break;
21202120
case MmaInputSmemSwizzle::B128:
21212121
case MmaInputSmemSwizzle::B64:
21222122
case MmaInputSmemSwizzle::B32:
21232123
out = hardCodedIndexGenerationForStMatrixSwizzle(
2124-
ldst, for_loops_[0], m_tile, n_tile, m, n);
2124+
ldst, for_loops_[for_loops_.size() - 3], m_tile, n_tile, m, n);
21252125
break;
21262126
default:
21272127
NVF_ERROR("Unsupported Swizzle Type for StMatrix");

csrc/scheduler/mma_utils.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,21 +1315,23 @@ void scheduleStMatrixForMmaOutput(
13151315
dataTypeSize(tv->dtype()) == 2,
13161316
"we only support 16-bit types in stmatrix");
13171317

1318+
// NOTE: There can be iterDomains to left of the mma output if there is cta
1319+
// or warp tiling.
13181320
if (tile_m == 16 && tile_n == 16) {
13191321
// Let [M, N] be [64, 32]
13201322
// After scheduleMmaOutputAllocation: [128(TIDx), 4, 2, 2]
13211323
// [128(TIDx), 4(n), 2, 2] -> [128(TIDx), 2(no), 2(ni), 2, 2]
13221324
tv->split(-3, 2);
13231325
// [128(TIDx), 2(no), 2(ni), 2, 2] -> [2(no), 128(TIDx), 2(ni), 2, 2]
1324-
tv->reorder({{-4, 0}});
1326+
tv->reorder({{-4, -5}});
13251327
// [128(TIDx), 2(no), 2(ni), 2, 2] -> [2(no), 128(TIDx), 8 (vectorize)]
13261328
tv->merge(-3);
13271329
tv->merge(-2);
13281330
} else if (tile_m == 16 && tile_n == 8) {
13291331
// Let [M, N] be [64, 16]
13301332
// After scheduleMmaOutputAllocation: [128(TIDx), 2, 2, 2]
13311333
// [128(TIDx), 2, 2, 2] -> [2, 128(TIDx), 2, 2]
1332-
tv->reorder({{-3, 0}});
1334+
tv->reorder({{-3, -4}});
13331335
// [2, 128(TIDx), 2, 2] -> [2, 128(TIDx), 4(vectorize)]
13341336
tv->merge(-2);
13351337
}

0 commit comments

Comments
 (0)