Cuda compile fix2 #284

Merged: 6 commits, Nov 20, 2024
Changes from all commits
csrc/activation_kernels.cu (6 additions, 0 deletions)

@@ -28,6 +28,7 @@ __global__ void act_and_mul_kernel(
 }

 // Scaled activation and gating kernel template.
+#ifdef USE_ROCM
 template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void scaled_act_and_mul_kernel(
     c10::Float8_e4m3fnuz* __restrict__ out,  // [..., d]
@@ -42,6 +43,7 @@ __global__ void scaled_act_and_mul_kernel(
         hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
   }
 }
+#endif

 template <typename T>
 __device__ __forceinline__ T silu_kernel(const T& x) {
@@ -90,6 +92,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   });

 // Launch activation and gating kernel.
+#ifdef USE_ROCM
 #define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \
   int d = input.size(-1) / 2; \
   int64_t num_tokens = input.numel() / input.size(-1); \
@@ -104,6 +107,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
         input.data_ptr<scalar_t>(), d, \
         1.0 / (*scale.data_ptr<float>())); \
   });
+#endif

 void silu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
@@ -114,7 +118,9 @@ void silu_and_mul(torch::Tensor& out,  // [..., d]
 void scaled_silu_and_mul(torch::Tensor& out,    // [..., d]
                          torch::Tensor& input,  // [..., 2 * d]
                          torch::Tensor& scale) {
+#ifdef USE_ROCM
   LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+#endif
 }

 void gelu_and_mul(torch::Tensor& out,    // [..., d]
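For readers outside the diff context: the fix compiles the ROCm-only FP8 code (c10::Float8_e4m3fnuz, hip_fp8) only when USE_ROCM is defined, so nvcc never sees types that exist solely in HIP builds. Below is a minimal, self-contained sketch of the same guard pattern. It is an illustration, not vLLM's code: scaled_square_kernel and scaled_square are hypothetical names, and raising TORCH_CHECK on the CUDA side is one possible fallback, whereas the PR simply leaves the body of scaled_silu_and_mul empty under CUDA.

// Sketch: guard ROCm-only FP8 code behind USE_ROCM so CUDA builds compile.
// Hypothetical op; only USE_ROCM, c10::Float8_e4m3fnuz, and the scale
// pattern below come from the PR/PyTorch, the rest is illustrative.
#include <torch/extension.h>

#ifdef USE_ROCM
// ROCm-only kernel: c10::Float8_e4m3fnuz is an FP8 format used on AMD GPUs,
// so nvcc must never see this region of the translation unit.
__global__ void scaled_square_kernel(c10::Float8_e4m3fnuz* __restrict__ out,
                                     const float* __restrict__ in,
                                     int n, float inv_scale) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Scale, then quantize to FP8 (the PR's kernel does this via hip_fp8).
    out[i] = c10::Float8_e4m3fnuz(in[i] * in[i] * inv_scale);
  }
}
#endif

void scaled_square(torch::Tensor& out,     // [n], fp8
                   torch::Tensor& input,   // [n], fp32
                   torch::Tensor& scale) { // [1], fp32 on host
#ifdef USE_ROCM
  int n = static_cast<int>(input.numel());
  dim3 grid((n + 255) / 256);
  dim3 block(256);
  scaled_square_kernel<<<grid, block>>>(
      out.data_ptr<c10::Float8_e4m3fnuz>(), input.data_ptr<float>(), n,
      1.0f / (*scale.data_ptr<float>()));  // same inverse-scale form as the PR
#else
  // One possible CUDA-side fallback; the PR instead leaves the body empty.
  TORCH_CHECK(false, "scaled_square is only supported on ROCm builds");
#endif
}

The key design point is that the guard wraps the kernel, the launcher macro, and the call site together: guarding only the kernel would still leave the host code referencing a symbol that does not exist in CUDA builds.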