[CK_TILE] Conv bwd splitN support #3047
Conversation
Pull Request Overview
This PR implements Split-N support for the backward convolution (data) kernel, enabling the batch dimension to be split across multiple GPU blocks when tensor sizes exceed 2GB. This parallels the Split-N implementation previously added for forward convolution in PR #2776.
Key changes:
- Enables Split-N computation by distributing the batch dimension across `blockIdx.z`
- Adds logic to calculate the optimal number of N splits based on the tensor memory footprint (a sketch of this calculation follows below)
- Adds safeguards to prevent simultaneous use of Split-K and Split-N (both use `blockIdx.z`)
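The PR implements this calculation in `GetSplitedNSize()` (see the file table below). The following is only a rough sketch of the idea, under assumed names and a simplified policy, not the PR's actual code: pick the smallest number of even batch splits that brings each slice's footprint under the 2GB limit.

```cpp
#include <cstdint>

using index_t      = int32_t;
using long_index_t = int64_t;

// Hypothetical helper: given the byte footprint of the largest tensor for
// the full batch N, return the batch size of each slice after splitting.
index_t get_split_n_size(long_index_t full_tensor_bytes, index_t n)
{
    constexpr long_index_t limit = long_index_t{1} << 31; // 2 GB

    if(full_tensor_bytes <= limit)
        return n; // fits as-is: no split needed

    // Smallest split count that brings each slice under the limit.
    const index_t min_splits =
        static_cast<index_t>((full_tensor_bytes + limit - 1) / limit);

    // Prefer a split count that divides N evenly so all slices are equal.
    for(index_t splits = min_splits; splits <= n; ++splits)
    {
        if(n % splits == 0)
            return n / splits;
    }
    return 1; // fall back to one image per slice
}
```

Restricting the search to even divisors of N keeps every grid launch identical in shape; the PR's actual policy may differ.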
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp | Implements `GetSplitedNSize()` to calculate optimal batch splits and adds tracking of original vs. split batch sizes |
| include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp | Updates the kernel to use Split-N grid dimensions, applies batch offsets to input/output pointers, and adds Split-K/Split-N conflict detection |
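The kernel-side pointer offsetting described in the second row could look roughly like the sketch below. The helper, its parameter names, and the `float` element type are assumptions for illustration; the PR performs the equivalent offsetting inside the kernel itself.

```cpp
#include <cstdint>
#include <hip/hip_runtime.h>

using index_t      = int32_t;
using long_index_t = int64_t;

// Hypothetical device helper: derive this block's batch slice from
// blockIdx.z and advance the tensor pointers to it.
__device__ void apply_split_n_offset(const float*& in_ptr,
                                     float*& out_ptr,
                                     long_index_t input_batch_stride,
                                     long_index_t output_batch_stride,
                                     index_t n_per_split)
{
    // blockIdx.z enumerates batch slices of n_per_split images each.
    const long_index_t n_idx =
        static_cast<long_index_t>(blockIdx.z) * n_per_split;

    in_ptr  += n_idx * input_batch_stride;  // advance input to this slice
    out_ptr += n_idx * output_batch_stride; // advance output to this slice
}
```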
Pull Request Overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
```cpp
for(index_t i = num_dims - 2; i >= 1; i--)
{
    a_g_n_c_wis_strides[i] = a_g_n_c_wis_strides[i + 1] * a_g_n_c_wis_lengths[i + 1];
}
```
Copilot AI, Oct 17, 2025:
As written, the loop terminates correctly: with a signed `index_t`, the condition `i >= 1` becomes false once `i` reaches 0. The pattern is still fragile, though: if `index_t` were ever changed to an unsigned type, the initializer `num_dims - 2` would wrap around for `num_dims < 2`. Consider using `i > 0` for clarity, or switching to an unsigned type with explicit bounds checking.
```cpp
for(index_t i = num_dims - 2; i >= 1; i--)
{
    c_g_n_k_wos_strides[i] = c_g_n_k_wos_strides[i + 1] * c_g_n_k_wos_lengths[i + 1];
}
```
Copilot AI, Oct 17, 2025:
Same loop pattern as in the input stride calculation, with the same fragility around the signedness of `index_t`. Use `i > 0` or ensure proper type handling.
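A minimal sketch of the suggested fix, assuming the same `index_t`, `num_dims`, and stride/length arrays as in the quoted hunks:

```cpp
// `i > 0` is equivalent to `i >= 1` here but reads as an explicit
// zero-exclusive bound; guarding on num_dims >= 2 additionally protects
// the initializer `num_dims - 2` if index_t were ever made unsigned.
if(num_dims >= 2)
{
    for(index_t i = num_dims - 2; i > 0; i--)
    {
        a_g_n_c_wis_strides[i] = a_g_n_c_wis_strides[i + 1] * a_g_n_c_wis_lengths[i + 1];
    }
}
```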
```cpp
static_cast<long_index_t>(kargs.input_batch_stride);

// SplitK
// TODO: Implement SplitK support
```
Copilot AI, Oct 17, 2025:
The TODO comment mentions SplitK support but doesn't explain its relationship to the current SplitN implementation or why it's commented out. Consider adding context about whether SplitK conflicts with SplitN, or if it's a future enhancement that works independently.
Suggested change:

```cpp
// TODO: Implement SplitK support.
// Note: SplitK and SplitN are mutually exclusive in the current implementation.
// Only SplitN is supported at this time; SplitK is commented out because
// enabling both simultaneously would require additional logic to handle
// their interaction. SplitK support is planned as a future enhancement.
```
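For context, the conflict detection mentioned in the review table might look roughly like the following; the function and parameter names are hypothetical, not the kernel's actual API.

```cpp
#include <cstdint>

using index_t = int32_t;

// Hypothetical argument check: Split-K and Split-N both map their extra
// work dimension onto blockIdx.z, so only one of them may be active.
bool is_supported_argument(index_t k_batch, index_t n_splits)
{
    if(k_batch > 1 && n_splits > 1)
        return false; // would double-book blockIdx.z
    return true;
}
```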
* Conv bwd splitN support
* Adjust splitting calculations to lengths format
* Prepare indexing for future splitK support
Proposed changes
Implements SplitN support for backwards convolution (data), similar to what #2776 did for forward convolution.
Test results (with the element-size limit reduced to 16 MB to make CPU reference comparison feasible): [result listings not reproduced here]
The above results show that the implementation is correct for the tested cases.
Checklist
Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- [ ] `clang-format` on all changed files
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.