From b98f522f5defb6962111c252b5df3aa262ebfc77 Mon Sep 17 00:00:00 2001
From: Jacob Hinkle
Date: Wed, 29 Jan 2025 13:16:33 -0500
Subject: [PATCH 1/7] Set __launch_bounds__ in kernel whenever we are able

Currently we only set the number of threads per block via
`__launch_bounds__` when register sharing is enabled. This PR enables it
whenever possible, i.e. whenever the CTA size is known at compile time.
See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#launch-bounds
for more background.
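For background, `__launch_bounds__` promises the compiler that a kernel
will never be launched with more threads per block than the stated
bound, so the register allocator can plan for that bound instead of the
architectural maximum (launches exceeding the bound fail). A minimal
hand-written illustration, not generated output; the kernel body and
the bound of 128 are made up:

    // The attribute promises that
    // blockDim.x * blockDim.y * blockDim.z <= 128 at every launch.
    __global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/128)
        scale(const float* in, float* out, float s, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        out[i] = s * in[i];
      }
    }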
---
 csrc/codegen.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp
index d9a5ba6bdcd..390196dd732 100644
--- a/csrc/codegen.cpp
+++ b/csrc/codegen.cpp
@@ -279,13 +279,15 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
       const std::string& kernel_name,
       std::optional<int64_t> num_threads_per_cta) {
     code_ << "__global__ void ";
+    if (num_threads_per_cta.has_value()) {
+      code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
+            << num_threads_per_cta.value() << ") ";
+    }
     if (kernel_->hasManaged("enable_register_sharing") &&
         kernel_->getManaged<bool>("enable_register_sharing")) {
       NVF_ERROR(
           num_threads_per_cta.has_value(),
           "__launch_bounds__ must be set for register sharing warp specialization");
-      code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
-            << num_threads_per_cta.value() << ") ";
     }
     if (kernel_->hasManaged("cluster_dims")) {
       auto cluster_dims =

From eb80538313da9d148c8e34e63e96f2150f1d4baa Mon Sep 17 00:00:00 2001
From: Jacob Hinkle
Date: Thu, 30 Jan 2025 15:20:55 -0500
Subject: [PATCH 2/7] Compute num threads per block properly
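
The block size is the product of the extents of the thread parallel
types (TIDx, TIDy, TIDz) recorded in the kernel's ParallelDimensionMap,
so compute it there rather than relying on the caller-provided value.
A scalar sketch of the product (the helper name and numbers are
illustrative; the real code below multiplies symbolic Vals):

    #include <array>
    #include <cstdint>

    // Hypothetical scalar analogue of getNumThreadsEachBlock():
    // one factor per thread parallel type.
    int64_t numThreadsPerCta(const std::array<int64_t, 3>& tid_extents) {
      int64_t num_threads = 1;
      for (int64_t extent : tid_extents) {
        num_threads *= extent;
      }
      return num_threads;  // e.g. {128, 2, 1} -> 256
    }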
---
 csrc/codegen.cpp                | 8 ++++++--
 csrc/parallel_dimension_map.cpp | 8 ++++++++
 csrc/parallel_dimension_map.h   | 3 +++
 3 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp
index 390196dd732..0930648d5c3 100644
--- a/csrc/codegen.cpp
+++ b/csrc/codegen.cpp
@@ -279,9 +279,13 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
       const std::string& kernel_name,
       std::optional<int64_t> num_threads_per_cta) {
     code_ << "__global__ void ";
-    if (num_threads_per_cta.has_value()) {
+    PolymorphicValue num_threads =
+        kernel_->summary()
+            .parallel_dimension_map.getNumThreadsEachBlock()
+            ->evaluate();
+    if (num_threads.hasValue()) {
       code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
-            << num_threads_per_cta.value() << ") ";
+            << num_threads.as<int64_t>() << ") ";
     }
     if (kernel_->hasManaged("enable_register_sharing") &&
         kernel_->getManaged<bool>("enable_register_sharing")) {
diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp
index c3595c3d671..27748b616c9 100644
--- a/csrc/parallel_dimension_map.cpp
+++ b/csrc/parallel_dimension_map.cpp
@@ -192,6 +192,14 @@ bool ParallelDimensionMap::isExact(ParallelType pt) const {
   return exact_types_.find(pt) != exact_types_.end();
 }
 
+Val* ParallelDimensionMap::getNumThreadsEachBlock() const {
+  Val* num_threads = FusionGuard::getCurFusion()->oneVal();
+  for (auto pt : kParallelTypeTIDs) {
+    num_threads = SimplifyingIrBuilder::mulExpr(num_threads, getRaw(pt));
+  }
+  return num_threads;
+}
+
 Val* ParallelDimensionMap::getRawCompute(ParallelType pt) const {
   Val* raw = getRaw(pt);
   if (warp_specialized_types_.count(pt)) {
diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h
index e2e9de423c1..fa407e00132 100644
--- a/csrc/parallel_dimension_map.h
+++ b/csrc/parallel_dimension_map.h
@@ -41,6 +41,9 @@ class ParallelDimensionMap {
     return dim_map_;
   }
 
+  //! Get the total number of threads in each CTA.
+  Val* getNumThreadsEachBlock() const;
+
   //! Get the "compute" parallel dimension on the given ParallelType. In case
   //! of no warp specialization, this is the same as getRaw(pt). If we are doing
   //! warp specialization on pt, the result is getRaw(pt) - 1, because the last

From b78ae421412d82e3acde856910b01dec91ce8c70 Mon Sep 17 00:00:00 2001
From: Jacob Hinkle
Date: Thu, 30 Jan 2025 20:22:52 -0500
Subject: [PATCH 3/7] Fix error caused by using evaluate() without checking if const
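
Val::evaluate() raises an error when the value is not a compile-time
constant, which is exactly the case when a thread dimension is symbolic
(e.g. chosen at run time from the input shape). Guard the call with
isConstInt() and skip the attribute otherwise. Condensed, the pattern
is (names as in the diff below):

    Val* num_threads =
        kernel_->summary().parallel_dimension_map.getNumThreadsEachBlock();
    if (num_threads->isConstInt()) {
      // Safe to evaluate: the product folded to a known constant.
      code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
            << num_threads->evaluate().as<int64_t>() << ") ";
    }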
---
 csrc/codegen.cpp | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp
index 0930648d5c3..4e8a47f077d 100644
--- a/csrc/codegen.cpp
+++ b/csrc/codegen.cpp
@@ -279,13 +279,16 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
       const std::string& kernel_name,
       std::optional<int64_t> num_threads_per_cta) {
     code_ << "__global__ void ";
-    PolymorphicValue num_threads =
-        kernel_->summary()
-            .parallel_dimension_map.getNumThreadsEachBlock()
-            ->evaluate();
-    if (num_threads.hasValue()) {
-      code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
-            << num_threads.as<int64_t>() << ") ";
+    {
+      // Avoid a const_cast that would be required to use kernel_ by picking the
+      // fusion of the first kernel output
+      FusionGuard fg(kernel_->outputs().front()->fusion());
+      Val* num_threads =
+          kernel_->summary().parallel_dimension_map.getNumThreadsEachBlock();
+      if (num_threads->isConstInt()) {
+        code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
+              << num_threads->evaluate().as<int64_t>() << ") ";
+      }
     }
     if (kernel_->hasManaged("enable_register_sharing") &&
         kernel_->getManaged<bool>("enable_register_sharing")) {

From 254a8a91d4731a46b3733de3d31b9594e13636f1 Mon Sep 17 00:00:00 2001
From: Jacob Hinkle
Date: Fri, 31 Jan 2025 08:11:42 -0500
Subject: [PATCH 4/7] Fix failing tests with expected cuda code

---
 tests/cpp/test_loop_rotation.cpp   | 12 ++++++------
 tests/cpp/test_scalar_hoisting.cpp |  4 ++--
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/cpp/test_loop_rotation.cpp b/tests/cpp/test_loop_rotation.cpp
index d8ae9d49bda..154b922c1c4 100644
--- a/tests/cpp/test_loop_rotation.cpp
+++ b/tests/cpp/test_loop_rotation.cpp
@@ -33,7 +33,7 @@ TEST_F(LoopRotationTest, RotateInner) {
   scheduler_utils::rotateLoop(tv4, -1, {tv1, tv2});
 
   const std::string expected_kernel = R"(
-__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) {
+__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(Tensor T0, Tensor T4) {
   NVFUSER_DEFINE_MAGIC_ZERO;
   #pragma unroll 1
   for(nvfuser_index_t i0 = 0LL; i0 < T0.logical_size[0LL]; ++i0) {
@@ -99,7 +99,7 @@ TEST_F(LoopRotationTest, RotateOuter) {
   scheduler_utils::rotateLoop(tv4, 0, {tv1, tv2});
 
   const std::string expected_kernel = R"(
-__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) {
+__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(Tensor T0, Tensor T4) {
   NVFUSER_DEFINE_MAGIC_ZERO;
   Array T1;
   Array T2;
@@ -196,7 +196,7 @@ TEST_F(LoopRotationTest, NonDivisibleSplit) {
   scheduler_utils::rotateLoop(tv4, 0, {tv1, tv2});
 
   const std::string expected_kernel = R"(
-__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) {
+__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(Tensor T0, Tensor T4) {
   NVFUSER_DEFINE_MAGIC_ZERO;
   nvfuser_index_t i0;
   i0 = T0.logical_size[0LL] * T0.logical_size[1LL];
@@ -302,7 +302,7 @@ TEST_F(LoopRotationTest, CircularBuffered) {
   scheduler_utils::rotateLoop(tv4, 0, {tv2});
 
   const std::string expected_kernel = R"(
-__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) {
+__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(Tensor T0, Tensor T4) {
   NVFUSER_DEFINE_MAGIC_ZERO;
   nvfuser_index_t i0;
   i0 = 4LL * T0.alloc_stride[0LL];
@@ -413,7 +413,7 @@ TEST_F(LoopRotationTest, SelectCircularBufferLoad) {
   scheduler_utils::rotateLoop(tv4, 0, {tv1, tv2});
 
   const std::string expected_kernel = R"(
-__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) {
+__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(Tensor T0, Tensor T4) {
   NVFUSER_DEFINE_MAGIC_ZERO;
   nvfuser_index_t i0;
   i0 = 4LL * T0.alloc_stride[0LL];
@@ -563,7 +563,7 @@ TEST_F(LoopRotationTest, MultipleCircularBuffer) {
   scheduler_utils::rotateLoop(tv3, 0, {tv1});
 
   const std::string expected_kernel = R"(
-__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T3) {
+__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(Tensor T0, Tensor T3) {
   alignas(16) extern __shared__ char array[];
   const unsigned smem_offset = 0;
   NVFUSER_DEFINE_MAGIC_ZERO;
diff --git a/tests/cpp/test_scalar_hoisting.cpp b/tests/cpp/test_scalar_hoisting.cpp
index ae23b3e5593..177a8e72cda 100644
--- a/tests/cpp/test_scalar_hoisting.cpp
+++ b/tests/cpp/test_scalar_hoisting.cpp
@@ -295,7 +295,7 @@ TEST_F(ScalarHoistTest, IndexHoist3) {
   auto cg_outputs = ke.run({t0});
 
   const std::string expected_kernel = R"(
-__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T2) {
+__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/256) CUDAGeneratedKernel(Tensor T0, Tensor T2) {
   nvfuser_index_t i0;
   i0 = ((nvfuser_index_t)threadIdx.x) + (256LL * ((nvfuser_index_t)blockIdx.x));
   Tensor s1;
@@ -374,7 +374,7 @@ TEST_F(ScalarHoistTest, ARange) {
   auto cg_outputs = ke.run({start, end, step});
 
   const std::string expected_kernel = R"(
-__global__ void CUDAGeneratedKernel(int64_t i0, int64_t i1, int64_t i2, Tensor T0, Tensor T1) {
+__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(int64_t i0, int64_t i1, int64_t i2, Tensor T0, Tensor T1) {
   int64_t i3;
   i3 = i1 - i0;
   int64_t i4;

From 4495654da134b004e8362b5b174f0b4043233bad Mon Sep 17 00:00:00 2001
From: Jacob Hinkle
Date: Fri, 31 Jan 2025 08:21:40 -0500
Subject: [PATCH 5/7] Just use a const_cast

---
 csrc/codegen.cpp | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp
index 4e8a47f077d..7e30a166060 100644
--- a/csrc/codegen.cpp
+++ b/csrc/codegen.cpp
@@ -280,9 +280,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
       std::optional<int64_t> num_threads_per_cta) {
     code_ << "__global__ void ";
     {
-      // Avoid a const_cast that would be required to use kernel_ by picking the
-      // fusion of the first kernel output
-      FusionGuard fg(kernel_->outputs().front()->fusion());
+      FusionGuard fg(const_cast<kir::Kernel*>(kernel_));
       Val* num_threads =
           kernel_->summary().parallel_dimension_map.getNumThreadsEachBlock();
       if (num_threads->isConstInt()) {

From 508d7ae3ee4a0a5d555f34c51b97572924853780 Mon Sep 17 00:00:00 2001
From: Jacob Hinkle
Date: Fri, 31 Jan 2025 08:44:31 -0500
Subject: [PATCH 6/7] Disable clang-tidy warning on const_cast

---
 csrc/codegen.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp
index 7e30a166060..c6a68f42519 100644
--- a/csrc/codegen.cpp
+++ b/csrc/codegen.cpp
@@ -280,6 +280,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
       std::optional<int64_t> num_threads_per_cta) {
     code_ << "__global__ void ";
     {
+      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
       FusionGuard fg(const_cast<kir::Kernel*>(kernel_));
       Val* num_threads =
           kernel_->summary().parallel_dimension_map.getNumThreadsEachBlock();

From 008816176990148d1e5bdf2381fd1c24f83fb86d Mon Sep 17 00:00:00 2001
From: Jacob Hinkle
Date: Tue, 11 Feb 2025 09:03:39 -0500
Subject: [PATCH 7/7] Update logic

---
 csrc/codegen.cpp | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp
index de55a2fe2c8..cfa2b456b7d 100644
--- a/csrc/codegen.cpp
+++ b/csrc/codegen.cpp
@@ -279,14 +279,14 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
       const std::string& kernel_name,
       std::optional<int64_t> num_threads_per_cta) {
     code_ << "__global__ void ";
-    {
+    if (!num_threads_per_cta.has_value()) {
+      // Try to evaluate the block size so that we can set launch bounds.
       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
       FusionGuard fg(const_cast<kir::Kernel*>(kernel_));
       Val* num_threads =
           kernel_->summary().parallel_dimension_map.getNumThreadsEachBlock();
       if (num_threads->isConstInt()) {
-        code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
-              << num_threads->evaluate().as<int64_t>() << ") ";
+        num_threads_per_cta = num_threads->evaluate().as<int64_t>();
       }
     }
     if (kernel_->hasManaged("enable_register_sharing") &&
         kernel_->getManaged<bool>("enable_register_sharing")) {
       NVF_ERROR(
           num_threads_per_cta.has_value(),
           "__launch_bounds__ must be set for register sharing warp specialization");
     }
+    if (num_threads_per_cta.has_value()) {
+      code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
+            << num_threads_per_cta.value() << ") ";
+    }
     if (kernel_->hasManaged("cluster_dims")) {
       auto cluster_dims =
           kernel_->getManaged>(