Set __launch_bounds__ in kernel whenever we are able #3794

Open

wants to merge 8 commits into main
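For context (not part of the diff): `__launch_bounds__` tells the CUDA compiler the maximum number of threads a block will be launched with, so it can budget registers per thread more aggressively without risking a launch failure. A minimal illustrative kernel showing the qualifier this PR emits; the kernel name and body below are made up for the example:

```cuda
// Illustrative only: promise the compiler this kernel never launches with
// more than 256 threads per block, so it may assign more registers per
// thread than the default worst-case assumption would allow.
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/256)
scaleKernel(float* data, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= factor;
  }
}
```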
12 changes: 10 additions & 2 deletions csrc/codegen.cpp
@@ -279,13 +279,21 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
const std::string& kernel_name,
std::optional<int64_t> num_threads_per_cta) {
code_ << "__global__ void ";
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
FusionGuard fg(const_cast<kir::Kernel*>(kernel_));
Val* num_threads =
kernel_->summary().parallel_dimension_map.getNumThreadsEachBlock();
if (num_threads->isConstInt()) {
code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
<< num_threads->evaluate().as<int64_t>() << ") ";
}
}
if (kernel_->hasManaged("enable_register_sharing") &&
kernel_->getManaged<bool>("enable_register_sharing")) {
NVF_ERROR(
num_threads_per_cta.has_value(),
"__launch_bounds__ must be set for register sharing warp specialization");
code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
<< num_threads_per_cta.value() << ") ";
}
if (kernel_->hasManaged("cluster_dims")) {
auto cluster_dims =
8 changes: 8 additions & 0 deletions csrc/parallel_dimension_map.cpp
@@ -192,6 +192,14 @@ bool ParallelDimensionMap::isExact(ParallelType pt) const {
return exact_types_.find(pt) != exact_types_.end();
}

Val* ParallelDimensionMap::getNumThreadsEachBlock() const {
Val* num_threads = FusionGuard::getCurFusion()->oneVal();
for (auto pt : kParallelTypeTIDs) {
num_threads = SimplifyingIrBuilder::mulExpr(num_threads, getRaw(pt));
}
return num_threads;
}

Val* ParallelDimensionMap::getRawCompute(ParallelType pt) const {
Val* raw = getRaw(pt);
if (warp_specialized_types_.count(pt)) {
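In effect, `getNumThreadsEachBlock()` multiplies the extents mapped to the thread parallel types (TIDx, TIDy, TIDz), and codegen only emits `__launch_bounds__` when that product simplifies to a constant. A simplified standalone analog of that check, using hypothetical names rather than the nvFuser API:

```cuda
#include <cstdint>
#include <initializer_list>
#include <optional>

// Illustrative analog: callers pass 1 for dimensions that are not
// thread-parallelized; a runtime-dependent extent (nullopt) makes the
// product unknown, which corresponds to num_threads->isConstInt()
// returning false in the codegen change above.
std::optional<int64_t> numThreadsEachBlock(
    std::optional<int64_t> tidx,
    std::optional<int64_t> tidy,
    std::optional<int64_t> tidz) {
  int64_t total = 1;
  for (const auto& dim : {tidx, tidy, tidz}) {
    if (!dim.has_value()) {
      return std::nullopt; // extent only known at runtime: emit no bound
    }
    total *= *dim;
  }
  return total; // constant product: safe to emit __launch_bounds__(total)
}
```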
3 changes: 3 additions & 0 deletions csrc/parallel_dimension_map.h
@@ -41,6 +41,9 @@ class ParallelDimensionMap {
return dim_map_;
}

//! Get the total number of threads per CTA.
Val* getNumThreadsEachBlock() const;

//! Get the "compute" parallel dimension on the given ParallelType. In case
//! of no warp specialization, this is the same as getRaw(pt). If we are doing
//! warp specialization on pt, the result is getRaw(pt) - 1, because the last
12 changes: 6 additions & 6 deletions tests/cpp/test_loop_rotation.cpp
@@ -33,7 +33,7 @@ TEST_F(LoopRotationTest, RotateInner) {
scheduler_utils::rotateLoop(tv4, -1, {tv1, tv2});

const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T4) {
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T4) {
NVFUSER_DEFINE_MAGIC_ZERO;
#pragma unroll 1
for(nvfuser_index_t i0 = 0LL; i0 < T0.logical_size[0LL]; ++i0) {
@@ -99,7 +99,7 @@ TEST_F(LoopRotationTest, RotateOuter) {
scheduler_utils::rotateLoop(tv4, 0, {tv1, tv2});

const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T4) {
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T4) {
NVFUSER_DEFINE_MAGIC_ZERO;
Array<float, 3LL, 1> T1;
Array<float, 3LL, 1> T2;
@@ -196,7 +196,7 @@ TEST_F(LoopRotationTest, NonDivisibleSplit) {
scheduler_utils::rotateLoop(tv4, 0, {tv1, tv2});

const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T4) {
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T4) {
NVFUSER_DEFINE_MAGIC_ZERO;
nvfuser_index_t i0;
i0 = T0.logical_size[0LL] * T0.logical_size[1LL];
@@ -302,7 +302,7 @@ TEST_F(LoopRotationTest, CircularBuffered) {
scheduler_utils::rotateLoop(tv4, 0, {tv2});

const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T4) {
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T4) {
NVFUSER_DEFINE_MAGIC_ZERO;
nvfuser_index_t i0;
i0 = 4LL * T0.alloc_stride[0LL];
@@ -413,7 +413,7 @@ TEST_F(LoopRotationTest, SelectCircularBufferLoad) {
scheduler_utils::rotateLoop(tv4, 0, {tv1, tv2});

const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T4) {
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T4) {
NVFUSER_DEFINE_MAGIC_ZERO;
nvfuser_index_t i0;
i0 = 4LL * T0.alloc_stride[0LL];
@@ -563,7 +563,7 @@ TEST_F(LoopRotationTest, MultipleCircularBuffer) {
scheduler_utils::rotateLoop(tv3, 0, {tv1});

const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T3) {
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T3) {
alignas(16) extern __shared__ char array[];
const unsigned smem_offset = 0;
NVFUSER_DEFINE_MAGIC_ZERO;
4 changes: 2 additions & 2 deletions tests/cpp/test_scalar_hoisting.cpp
@@ -295,7 +295,7 @@ TEST_F(ScalarHoistTest, IndexHoist3) {
auto cg_outputs = ke.run({t0});

const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T2) {
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/256) CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T2) {
nvfuser_index_t i0;
i0 = ((nvfuser_index_t)threadIdx.x) + (256LL * ((nvfuser_index_t)blockIdx.x));
Tensor<float, 2, 2> s1;
@@ -374,7 +374,7 @@ TEST_F(ScalarHoistTest, ARange) {
auto cg_outputs = ke.run({start, end, step});

const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(int64_t i0, int64_t i1, int64_t i2, Tensor<int64_t, 1, 1> T0, Tensor<int64_t, 1, 1> T1) {
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/1) CUDAGeneratedKernel(int64_t i0, int64_t i1, int64_t i2, Tensor<int64_t, 1, 1> T0, Tensor<int64_t, 1, 1> T1) {
int64_t i3;
i3 = i1 - i0;
int64_t i4;