WCC Grouped Gemm Implementation #24
base: main
Conversation
Pull Request Overview
This PR implements a Work-Centric Grouped GEMM (General Matrix Multiply) operation for the tritonblas library using Triton kernels. The implementation provides an efficient way to execute multiple matrix multiplications in a single kernel launch by distributing work across processing units.
- Adds a new work-centric grouped GEMM implementation with partial result accumulation (the scheduling idea is sketched after this list)
- Includes comprehensive test coverage for various matrix sizes and block configurations
- Exposes the grouped GEMM functionality through the tritonblas public API
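To make the work-centric idea concrete, here is a minimal host-side sketch of how output tiles from all groups can be flattened into one work list and striped across a fixed pool of persistent workers. This models the scheduling only; the names and the round-robin policy are illustrative assumptions, not code from this PR.

```python
# A host-side model of work-centric scheduling, not the Triton kernel itself.
def assign_tiles(group_shapes, block_m, block_n, num_workers):
    # Flatten every output tile of every group into one global work list.
    tiles = []
    for g, (m, n, _k) in enumerate(group_shapes):
        tiles_m = (m + block_m - 1) // block_m
        tiles_n = (n + block_n - 1) // block_n
        tiles += [(g, t) for t in range(tiles_m * tiles_n)]
    # Round-robin: worker w takes tiles w, w + num_workers, w + 2*num_workers, ...
    return {w: tiles[w::num_workers] for w in range(num_workers)}

# Two uneven groups still spread evenly across 8 persistent workers.
work = assign_tiles([(512, 512, 128), (64, 64, 64)], 256, 256, num_workers=8)
```

This striping is what lets small and large problems in the same batch share one kernel launch instead of launching one grid per GEMM.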
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| tests/test_grouped_gemm.py | Test file with parametrized tests for different matrix sizes, block sizes, and group counts |
| include/tritonblas/internal/wcc_grouped_gemm.py | Core implementation containing WCC grouped GEMM kernel and helper functions |
| include/tritonblas/grouped_gemm.py | Public API wrapper that prepares data structures and launches the kernel |
| include/tritonblas/__init__.py | Exports the grouped_gemm function for public use |
```python
current_device_index = torch.cuda.current_device()
current_device = torch.cuda.get_device_properties(current_device_index)
MAX_SMS = current_device.multi_processor_count
#TODO: 256x256 for fp16/bf16, need adjust for fp8/fp4
```
Copilot AI commented on Sep 12, 2025:
TODO comment should be formatted with proper spacing: '# TODO:' and provide more specific guidance about the adjustment needed for fp8/fp4 data types.
Suggested change:
```diff
-#TODO: 256x256 for fp16/bf16, need adjust for fp8/fp4
+# TODO: 256x256 block size is suitable for fp16/bf16; for fp8/fp4, consider reducing block size (e.g., 128x128) due to hardware and data type constraints. Investigate optimal values.
```
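As a hedged illustration of that suggestion, the block size could be selected from the element type. The values below echo the review comment and are assumptions, not constants taken from this PR.

```python
import torch

def pick_block_size(dtype: torch.dtype) -> tuple[int, int]:
    # Assumed policy: keep 256x256 for fp16/bf16 and drop to 128x128 for
    # narrower types such as fp8 (e.g. torch.float8_e4m3fn). Illustrative only.
    if dtype in (torch.float16, torch.bfloat16):
        return (256, 256)
    return (128, 128)
```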
```python
A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_SIZE_K * stride_ak * remainder
B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_SIZE_K * stride_bk * remainder
"""
A_BASE = A + rm[:, None] * stride_am + rk[None, :] + (BLOCK_SIZE_K * tile_offset)
```
Copilot AI commented on Sep 12, 2025:
Matrix A pointer calculation is missing the stride_ak multiplication. It should be rk[None, :] * stride_ak to properly handle different stride patterns for matrix A.
Suggested change:
```diff
-A_BASE = A + rm[:, None] * stride_am + rk[None, :] + (BLOCK_SIZE_K * tile_offset)
+A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + (BLOCK_SIZE_K * tile_offset)
```
@ryanswann-amd can you take a look if it's still correct after adding this stride?
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
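For context on why the stride_ak factor matters, here is a small host-side check (illustrative, not from this PR). The kernel-style flat offset `i * stride_am + k * stride_ak` addresses `A[i, k]` correctly only when both stride terms are applied; dropping `stride_ak` happens to work for row-major inputs, where `stride_ak == 1`, and silently reads the wrong elements for transposed ones.

```python
import torch

A = torch.arange(12.0).reshape(3, 4)  # row-major view: strides (4, 1)
At = A.t()                            # transposed view:  strides (1, 4)
storage = A.flatten()                 # the underlying buffer shared by both views

for M in (A, At):
    s_am, s_ak = M.stride()
    i, k = 2, 1
    # The kernel-style flat offset; correct only with both stride terms.
    flat = i * s_am + k * s_ak
    assert storage[flat] == M[i, k]
```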
Motivation
Work-Centric Grouped GEMM Implementation
Test Plan
A testing script is added to verify correctness; a sketch of such a test follows.
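This is a sketch of what a parametrized correctness check might look like. The real tests live in tests/test_grouped_gemm.py, and the grouped_gemm signature used here (lists of A, B, and preallocated C tensors) is an assumption.

```python
import pytest
import torch
import tritonblas

@pytest.mark.parametrize("m,n,k", [(64, 64, 64), (256, 512, 128)])
@pytest.mark.parametrize("num_groups", [1, 4])
def test_grouped_gemm_matches_torch(m, n, k, num_groups):
    # Build num_groups independent GEMM problems on the GPU.
    As = [torch.randn(m, k, device="cuda", dtype=torch.float16) for _ in range(num_groups)]
    Bs = [torch.randn(k, n, device="cuda", dtype=torch.float16) for _ in range(num_groups)]
    Cs = [torch.empty(m, n, device="cuda", dtype=torch.float16) for _ in range(num_groups)]
    tritonblas.grouped_gemm(As, Bs, Cs)
    # Compare each group's output against a per-group torch reference.
    for a, b, c in zip(As, Bs, Cs):
        torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)
```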
Test Result
All tests pass
Submission Checklist