
Set __launch_bounds__ in kernel whenever we are able #3794


Open

wants to merge 9 commits into base: main

Conversation

@jacobhinkle (Collaborator) commented Jan 29, 2025

Currently we set the number of threads per block via __launch_bounds__ when register sharing is enabled. This PR enables it whenever possible, i.e. whenever the CTA size is known at compile time.

Adds the method ParallelDimensionMap::getNumThreadsEachBlock(), which is similar to ParallelDimensionMap::getNumComputeThreadsEachBlock() but includes all threads and does not skip DMA threads.

See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#launch-bounds for more background.
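
For illustration, here is a minimal hand-written CUDA sketch of what the qualifier expresses; the kernel name and the 256-thread bound are made up for the example, not taken from the generated code:

// Tell the compiler this kernel will never be launched with more than 256
// threads per block, so it can budget registers against that bound.
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/256) scale_kernel(
    float* data, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= factor;
  }
}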

@jacobhinkle (Collaborator Author)

!test


github-actions bot commented Jan 29, 2025

PR Reviewer Guide 🔍

(Review updated until commit b78ae42)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 2 🔵🔵⚪⚪⚪
🧪 No relevant tests
⚡ Recommended focus areas for review

Const Cast

The code uses a const_cast to avoid a const correctness issue. Consider re-designing the code to avoid the need for const_cast.

// Avoid a const_cast that would be required to use kernel_ by picking the
// fusion of the first kernel output
FusionGuard fg(kernel_->outputs().front()->fusion());
Thread Count Calculation

The getNumThreadsEachBlock method calculates the total number of threads per block by multiplying the number of threads for each parallel type. Verify that this calculation is correct and accounts for all possible parallel types.

Val* ParallelDimensionMap::getNumThreadsEachBlock() const {
  Val* num_threads = FusionGuard::getCurFusion()->oneVal();
  for (auto pt : kParallelTypeTIDs) {
    num_threads = SimplifyingIrBuilder::mulExpr(num_threads, getRaw(pt));
  }
  return num_threads;
}

@zasdfgbnm (Collaborator) left a comment

Makes sense to me, but will let Ryan decide.

@rdspring1 (Collaborator)

!test --pybench

@rdspring1 (Collaborator) left a comment

num_threads_per_cta should always have a value because it is the inferred launch bounds. https://github.com/NVIDIA/Fuser/blob/main/csrc/runtime/executor.cpp#L268-L286

This seems different from what you intended. I thought you wanted to set the launch bounds when TIDx, TIDy, and TIDz extents are constant.

@jacobhinkle (Collaborator Author)

This seems different from what you intended. I thought you wanted to set the launch bounds when TIDx, TIDy, and TIDz extents are constant.

Ah, you're right. I did intend it to not be derived from inputs since that would interfere with dynamic shapes. I'll give it another try.

@jacobhinkle (Collaborator Author)

!test --diff

@jacobhinkle (Collaborator Author)

!test --diff

@jacobhinkle (Collaborator Author)

!test --diff


github-actions bot commented Jan 31, 2025

Review updated until commit 0088161

Description

  • Set __launch_bounds__ whenever CTA size is known at compile time.

  • Add ParallelDimensionMap::getNumThreadsEachBlock() method.

  • Update tests to reflect new kernel launch bounds.


Changes walkthrough 📝

Relevant files

Enhancement

  • csrc/codegen.cpp: Add logic for setting launch bounds (+12/-0)
      • Added logic to evaluate block size and set __launch_bounds__ if CTA size is known.
      • Updated kernel declaration to include __launch_bounds__ when available.

  • csrc/parallel_dimension_map.cpp: Implement getNumThreadsEachBlock method (+8/-0)
      • Implemented getNumThreadsEachBlock() to calculate total threads per block.

  • csrc/parallel_dimension_map.h: Add getNumThreadsEachBlock declaration (+3/-0)
      • Added declaration for getNumThreadsEachBlock().

Tests

  • tests/cpp/test_loop_rotation.cpp: Update expected kernel strings (+6/-6)
      • Updated expected kernel strings to include __launch_bounds__.

  • tests/cpp/test_scalar_hoisting.cpp: Update expected kernel strings (+2/-2)
      • Updated expected kernel strings to include __launch_bounds__.

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Launch Bounds Calculation

Ensure that the calculation of num_threads_per_cta is correct and handles all edge cases, especially when num_threads_per_cta is not initially provided.

if (!num_threads_per_cta.has_value()) {
  // Try to evaluate the block size so that we can set launch bounds.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  FusionGuard fg(const_cast<kir::Kernel*>(kernel_));
  Val* num_threads =
      kernel_->summary().parallel_dimension_map.getNumThreadsEachBlock();
  if (num_threads->isConstInt()) {
    num_threads_per_cta = num_threads->evaluate().as<int64_t>();
  }
}

Thread Count Calculation

Verify that getNumThreadsEachBlock correctly calculates the total number of threads per block, including all types of threads.

Val* ParallelDimensionMap::getNumThreadsEachBlock() const {
  Val* num_threads = FusionGuard::getCurFusion()->oneVal();
  for (auto pt : kParallelTypeTIDs) {
    num_threads = SimplifyingIrBuilder::mulExpr(num_threads, getRaw(pt));
  }
  return num_threads;
}

Expected Kernel Output

Confirm that the expected kernel output strings in the tests accurately reflect the changes made to the kernel launch bounds.

const std::string expected_kernel = R"(
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/256) CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T2) {

@jacobhinkle (Collaborator Author)

!test --diff

@jacobhinkle (Collaborator Author)

!test --diff

@jacobhinkle (Collaborator Author)

Codediff doesn't show anything concerning IMO. Register usage is usually reduced, though sometimes increased by a few registers.

if (kernel_->hasManaged("enable_register_sharing") &&
    kernel_->getManaged<bool>("enable_register_sharing")) {
  NVF_ERROR(
      num_threads_per_cta.has_value(),
      "__launch_bounds__ must be set for register sharing warp specialization");
  code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
Collaborator

Isn't this still necessary because we can determine the number of threads at runtime?

How about changing the predicate to

    // Always set __launch_bounds__ when register sharing is enabled.
    if (kernel_->hasManaged("enable_register_sharing") &&
        kernel_->getManaged<bool>("enable_register_sharing")) {
      ...
    } else {
      // Enable __launch_bounds__ when the number of threads is known at compile time.
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      FusionGuard fg(const_cast<kir::Kernel*>(kernel_));
      Val* num_threads =
          kernel_->summary().parallel_dimension_map.getNumThreadsEachBlock();
      if (num_threads->isConstInt()) {
        code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
              << num_threads->evaluate().as<int64_t>() << ") ";
      }
    }


@jacobhinkle (Collaborator Author)

Good call. I now just try to set num_threads_per_cta if it is unset, then use that to set the launch bounds argument. So in the warp specialization case we will skip the check and just use the provided number of threads.
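
For reference, a rough sketch of how the pieces quoted in this thread fit together after that change, assembled from the snippets above; the final stream insertion is paraphrased rather than copied from the diff:

// If launch params did not provide a CTA size, try to infer a compile-time
// constant one from the parallel dimension map.
if (!num_threads_per_cta.has_value()) {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  FusionGuard fg(const_cast<kir::Kernel*>(kernel_));
  Val* num_threads =
      kernel_->summary().parallel_dimension_map.getNumThreadsEachBlock();
  if (num_threads->isConstInt()) {
    num_threads_per_cta = num_threads->evaluate().as<int64_t>();
  }
}
// Register sharing warp specialization requires launch bounds to be set.
if (kernel_->hasManaged("enable_register_sharing") &&
    kernel_->getManaged<bool>("enable_register_sharing")) {
  NVF_ERROR(
      num_threads_per_cta.has_value(),
      "__launch_bounds__ must be set for register sharing warp specialization");
}
// Emit __launch_bounds__ whenever a thread count is available.
if (num_threads_per_cta.has_value()) {
  code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
        << num_threads_per_cta.value() << ") ";
}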

@jacobhinkle (Collaborator Author)

!test

@jacobhinkle (Collaborator Author)

We hit an error because of this:

// Now that we have launch parameters we can compile the kernel. It's a bit
// odd we need launch parameters for compilation, need to go back and check
// why this is the case.
compiled_kernel_->compile(launch_params.nThreads());

When we compile the kernel we compile against one set of launch params, but when we reuse the kernel later it may be launched with different ones. I was thinking that for warp specialization in matmul this is probably not needed because we will have a fixed block size. @rdspring1 what is the case where we might need to determine the block size at runtime?

If needed, we could consider caching and recompiling after lowering, as part of CompiledKernel. This is a very similar challenge to index type: see #3850.

if (kernel_->hasManaged("enable_register_sharing") &&
    kernel_->getManaged<bool>("enable_register_sharing")) {
  NVF_ERROR(
      num_threads_per_cta.has_value(),
      "__launch_bounds__ must be set for register sharing warp specialization");
}
if (num_threads_per_cta.has_value()) {
Collaborator

num_threads_per_cta is always known at compile time, so setting launch bounds with a runtime-determined value should be guarded by enable_register_sharing.

@jacobhinkle (Collaborator Author)

Yeah, I get that now. But due to kernel re-use this is not safe. We might have different LaunchParams during compilation than we do for a later launch.

@rdspring1 (Collaborator) commented Feb 11, 2025

The persistent schedulers can use dynamic block sizes. There was interest in exploring warp specialization with register sharing for those schedulers.

https://github.com/NVIDIA/Fuser/blob/main/csrc/scheduler/normalization_inner_outer.cpp#L903-L908
