Skip to content

Commit

Permalink
[triton] Add clustering support and test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 605007757
  • Loading branch information
jax authors committed Feb 7, 2024
1 parent ddd668e commit 205a209
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions jaxlib/gpu/triton_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,14 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
/*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
/*extra=*/nullptr));
}
CUlaunchAttribute launch_attrs[2];
launch_attrs[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
launch_attrs[0].value.clusterDim.x = cluster_dims_[0];
launch_attrs[0].value.clusterDim.y = cluster_dims_[1];
launch_attrs[0].value.clusterDim.z = cluster_dims_[2];
launch_attrs[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
launch_attrs[1].value.clusterSchedulingPolicyPreference =
CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
CUlaunchConfig launch_config = {
/*gridDimX=*/grid[0] * cluster_dims_[0],
/*gridDimY=*/grid[1] * cluster_dims_[1],
Expand All @@ -305,8 +313,8 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
/*blockDimZ=*/1,
/*sharedMemBytes=*/shared_mem_bytes_,
/*hStream=*/stream,
/**attrs=*/nullptr, // TODO(giorgioa): Add attrs for block clusters.
/*numAttrs=*/0,
/**attrs=*/launch_attrs,
/*numAttrs=*/2,
};
return JAX_AS_STATUS(
cuLaunchKernelEx(&launch_config, kernel, params, /*extra=*/nullptr));
Expand Down

0 comments on commit 205a209

Please sign in to comment.