diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index f57e69b168f9..cbd16b688e9b 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -296,6 +296,14 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], /*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params, /*extra=*/nullptr)); } + CUlaunchAttribute launch_attrs[2]; + launch_attrs[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launch_attrs[0].value.clusterDim.x = cluster_dims_[0]; + launch_attrs[0].value.clusterDim.y = cluster_dims_[1]; + launch_attrs[0].value.clusterDim.z = cluster_dims_[2]; + launch_attrs[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + launch_attrs[1].value.clusterSchedulingPolicyPreference = + CU_CLUSTER_SCHEDULING_POLICY_SPREAD; CUlaunchConfig launch_config = { /*gridDimX=*/grid[0] * cluster_dims_[0], /*gridDimY=*/grid[1] * cluster_dims_[1], @@ -305,8 +313,8 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], /*blockDimZ=*/1, /*sharedMemBytes=*/shared_mem_bytes_, /*hStream=*/stream, - /**attrs=*/nullptr, // TODO(giorgioa): Add attrs for block clusters. - /*numAttrs=*/0, + /**attrs=*/launch_attrs, + /*numAttrs=*/2, }; return JAX_AS_STATUS( cuLaunchKernelEx(&launch_config, kernel, params, /*extra=*/nullptr));