Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix stmatrix scheduling for persistent GEMM #3791

Merged
merged 1 commit into from
Jan 30, 2025
Merged

Conversation

rdspring1
Copy link
Collaborator

Problem

The current stmatrix scheduling assumes all iterDomains are parallelized to the left of the mma output allocation domain.
Therefore, it moves the stmatrix serial iterDomain to the 0th position. This is incompatible with persistent gemm kernels, which have a grid strided serial iterDomain.

Solution

This PR fixes this moving the stmatrix serial iterDomain back one position and using the 3rd from the end for-loop during index generation. The 3rd from the end for-loop is the 0th position from the mma output allocation domain.

@rdspring1
Copy link
Collaborator Author

!test

Copy link

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 2 🔵🔵⚪⚪⚪
🧪 No relevant tests
⚡ Recommended focus areas for review

Index Generation

The PR changes the index generation for stmatrix to use the 3rd from the end for-loop. This change may have implications for the correctness and performance of the stmatrix scheduling.

      ldst, for_loops_[for_loops_.size() - 3], m_tile, n_tile, m, n);
  break;
case MmaInputSmemSwizzle::B128:
case MmaInputSmemSwizzle::B64:
case MmaInputSmemSwizzle::B32:
  out = hardCodedIndexGenerationForStMatrixSwizzle(
      ldst, for_loops_[for_loops_.size() - 3], m_tile, n_tile, m, n);
IterDomain Reordering

The PR reorders the iterDomains to accommodate the stmatrix scheduling. This change may have implications for the correctness and performance of the mma output allocation.

  tv->reorder({{-4, -5}});
  // [128(TIDx), 2(no), 2(ni), 2, 2] -> [2(no), 128(TIDx), 8 (vectorize)]
  tv->merge(-3);
  tv->merge(-2);
} else if (tile_m == 16 && tile_n == 8) {
  // Let [M, N] be [64, 16]
  // After scheduleMmaOutputAllocation: [128(TIDx), 2, 2, 2]
  // [128(TIDx), 2, 2, 2] -> [2, 128(TIDx), 2, 2]
  tv->reorder({{-3, -4}});

Copy link
Collaborator

@jacobhinkle jacobhinkle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks!

@rdspring1 rdspring1 merged commit a3f6ba9 into main Jan 30, 2025
51 checks passed
@rdspring1 rdspring1 deleted the fix_stmatrix_for_loop branch January 30, 2025 02:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants