diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 967f64a0f6b4c..94a1c4a678cf7 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -28,6 +28,7 @@ __global__ void act_and_mul_kernel(
 }
 
 // Scaled activation and gating kernel template.
+#ifdef USE_ROCM
 template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void scaled_act_and_mul_kernel(
     c10::Float8_e4m3fnuz* __restrict__ out,  // [..., d]
@@ -42,6 +43,7 @@ __global__ void scaled_act_and_mul_kernel(
         hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
   }
 }
+#endif
 
 template <typename T>
 __device__ __forceinline__ T silu_kernel(const T& x) {
@@ -90,6 +92,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
       });
 
 // Launch activation and gating kernel.
+#ifdef USE_ROCM
 #define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL)                  \
   int d = input.size(-1) / 2;                                         \
   int64_t num_tokens = input.numel() / input.size(-1);                \
@@ -104,6 +107,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
                  input.data_ptr<scalar_t>(), d,                       \
                  1.0 / (*scale.data_ptr<float>()));                   \
       });
+#endif
 
 void silu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
@@ -114,7 +118,9 @@ void silu_and_mul(torch::Tensor& out,    // [..., d]
 void scaled_silu_and_mul(torch::Tensor& out,    // [..., d]
                          torch::Tensor& input,  // [..., 2 * d]
                          torch::Tensor& scale) {
+#ifdef USE_ROCM
   LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+#endif
 }
 
 void gelu_and_mul(torch::Tensor& out,    // [..., d]
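
The patch compiles the FP8 scaled-activation kernel, its launch macro, and the call site only when USE_ROCM is defined, so scaled_silu_and_mul is left with an empty body on non-ROCm builds. As a minimal standalone sketch of that conditional-compilation pattern (not part of the patch; launch_scaled_path and scaled_silu_and_mul_sketch are hypothetical names used only for illustration):

#include <cstdio>

#ifdef USE_ROCM
// Stand-in for the ROCm-only scaled kernel launch guarded in the diff.
static void launch_scaled_path() { std::printf("ROCm FP8 scaled path\n"); }
#endif

void scaled_silu_and_mul_sketch() {
#ifdef USE_ROCM
  launch_scaled_path();
#else
  // Compiled out on non-ROCm builds: the body is empty, mirroring the
  // guarded scaled_silu_and_mul in the diff above.
#endif
}

int main() {
  scaled_silu_and_mul_sketch();  // prints only when compiled with -DUSE_ROCM
  return 0;
}

Guarding both the definitions and the call site keeps references to c10::Float8_e4m3fnuz and hip_fp8 out of non-ROCm builds while still defining scaled_silu_and_mul, so the function's signature remains available to callers on every platform.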