Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set __launch_bounds__ in kernel whenever we are able #3794

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
6 changes: 4 additions & 2 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,15 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
const std::string& kernel_name,
std::optional<int64_t> num_threads_per_cta) {
code_ << "__global__ void ";
if (num_threads_per_cta.has_value()) {
code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
<< num_threads_per_cta.value() << ") ";
}
if (kernel_->hasManaged("enable_register_sharing") &&
kernel_->getManaged<bool>("enable_register_sharing")) {
NVF_ERROR(
num_threads_per_cta.has_value(),
"__launch_bounds__ must be set for register sharing warp specialization");
code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
<< num_threads_per_cta.value() << ") ";
}
if (kernel_->hasManaged("cluster_dims")) {
auto cluster_dims =
Expand Down