From b8851e6c35706d90efcc1aad2b6b0ac18c1d4aca Mon Sep 17 00:00:00 2001
From: Hui Liu <96135754+hliuca@users.noreply.github.com>
Date: Fri, 15 Nov 2024 18:46:28 +0000
Subject: [PATCH 1/2] fix cuda compilation

---
 csrc/activation_kernels.cu               |  6 ++++++
 vllm/model_executor/layers/tuned_gemm.py | 22 ++++++++++++++++------
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 967f64a0f6b..94a1c4a678c 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -28,6 +28,7 @@ __global__ void act_and_mul_kernel(
 }
 
 // Scaled activation and gating kernel template.
+#ifdef USE_ROCM
 template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void scaled_act_and_mul_kernel(
     c10::Float8_e4m3fnuz* __restrict__ out,  // [..., d]
@@ -42,6 +43,7 @@ __global__ void scaled_act_and_mul_kernel(
         hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
   }
 }
+#endif
 
 template <typename T>
 __device__ __forceinline__ T silu_kernel(const T& x) {
@@ -90,6 +92,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
       });
 
 // Launch activation and gating kernel.
+#ifdef USE_ROCM
 #define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL)                     \
   int d = input.size(-1) / 2;                                            \
   int64_t num_tokens = input.numel() / input.size(-1);                   \
@@ -104,6 +107,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
             input.data_ptr<scalar_t>(), d,                               \
             1.0 / (*scale.data_ptr<float>()));                           \
       });
+#endif
 
 void silu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
@@ -114,7 +118,9 @@ void silu_and_mul(torch::Tensor& out,    // [..., d]
 void scaled_silu_and_mul(torch::Tensor& out,    // [..., d]
                          torch::Tensor& input,  // [..., 2 * d]
                          torch::Tensor& scale) {
+#ifdef USE_ROCM
   LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+#endif
 }
 
 void gelu_and_mul(torch::Tensor& out,    // [..., d]
diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py
index 0595ff83be2..69dfdccb314 100644
--- a/vllm/model_executor/layers/tuned_gemm.py
+++ b/vllm/model_executor/layers/tuned_gemm.py
@@ -4,14 +4,15 @@
 import pandas as pd
 import torch
 import torch.nn.functional as F
-from hipbsolidxgemm import hipb_create_extension, hipb_mm
-from rocsolidxgemm import rocb_create_extension, rocb_mm
 
 from vllm import _custom_ops as ops
 from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM
 from vllm.platforms import current_platform
 from vllm.utils import is_navi
 
+if current_platform.is_rocm():
+    from hipbsolidxgemm import hipb_create_extension, hipb_mm
+    from rocsolidxgemm import rocb_create_extension, rocb_mm
 
 class TunedGemm:
 
@@ -26,8 +27,12 @@ def __init__(self):
         self.bestsols = {}
         self.load_best_sols()
         self.create_ds()
-        self.cu_count = torch.cuda.get_device_properties(
-            device='cuda').multi_processor_count
+
+        if current_platform.is_rocm():
+            self.cu_count = torch.cuda.get_device_properties(
+                device='cuda').multi_processor_count
+        else:
+            self.cu_count = -1
 
         self.use_skinny = (current_platform.is_rocm()
                            and VLLM_USE_ROCM_SKINNY_GEMM and not is_navi())
@@ -81,6 +86,9 @@ def apply_skinny(self, m, n, k, inp_view, weights):
             return None
 
     def mm(self, inp, weights, bias=None):
+        if not current_platform.is_rocm():
+            return F.linear(inp, weights, bias)
+
         # F.Linear can take a 3 dimensional input. vllm
         # uses this for linear units. However, sampler
         # will use torch.matmul with 2 dimensions only
@@ -94,9 +102,11 @@ def mm(self, inp, weights, bias=None):
             inp_view = inp
             batched = False
         if self.extensions_created is False:
-            rocb_create_extension()
-            hipb_create_extension()
+            if current_platform.is_rocm():
+                rocb_create_extension()
+                hipb_create_extension()
             self.extensions_created = True
+
         m = weights.shape[0]
         n = inp_view.shape[0]
         k = inp_view.shape[1]

From 63efef32e2b087fea45a196ce6e3e0f74532e826 Mon Sep 17 00:00:00 2001
From: Hui Liu <96135754+hliuca@users.noreply.github.com>
Date: Sat, 16 Nov 2024 18:01:36 -0800
Subject: [PATCH 2/2] checkout tuned gemm from develop

---
 vllm/model_executor/layers/tuned_gemm.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py
index 877c0bea80d..a441ca5def0 100644
--- a/vllm/model_executor/layers/tuned_gemm.py
+++ b/vllm/model_executor/layers/tuned_gemm.py
@@ -24,6 +24,7 @@ def hipb_mm(inp, weights, solidx, bias=None):
 def rocb_mm(inp, weights, solidx):
     return torch.ops._gradlib_C.rocb_mm(inp, weights, solidx)
 
+
 class TunedGemm:
 
     def __init__(self):
@@ -35,12 +36,8 @@ def __init__(self):
         self.bestsols = {}
         self.load_best_sols()
         self.create_ds()
-
-        if current_platform.is_rocm():
-            self.cu_count = torch.cuda.get_device_properties(
-                device='cuda').multi_processor_count
-        else:
-            self.cu_count = -1
+        self.cu_count = torch.cuda.get_device_properties(
+            device='cuda').multi_processor_count
 
         self.use_skinny = (current_platform.is_rocm()
                            and VLLM_USE_ROCM_SKINNY_GEMM and not is_navi())
@@ -112,7 +109,6 @@ def mm(self, inp, weights, bias=None):
             torch.ops._gradlib_C.rocb_create_extension()
             torch.ops._gradlib_C.hipb_create_extension()
             self.extensions_created = True
-
         m = weights.shape[0]
         n = inp_view.shape[0]
         k = inp_view.shape[1]
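
Note (illustrative only, not part of the patch series): a minimal, self-contained sketch of the ROCm/CUDA guard pattern that PATCH 1 applies in tuned_gemm.py. The hipbsolidxgemm/rocsolidxgemm module names and the F.linear fallback are taken from the diff above; the torch.version.hip check is an assumption used here in place of vllm's current_platform.is_rocm() so the sketch has no vllm dependency.

    import torch
    import torch.nn.functional as F

    # Assumed stand-in for vllm.platforms.current_platform.is_rocm():
    # torch.version.hip is a version string on ROCm builds and None on CUDA builds.
    IS_ROCM = torch.version.hip is not None

    if IS_ROCM:
        # ROCm-only tuned-GEMM extensions; a CUDA build never imports them,
        # so it does not need these packages installed at all.
        from hipbsolidxgemm import hipb_create_extension, hipb_mm  # noqa: F401
        from rocsolidxgemm import rocb_create_extension, rocb_mm  # noqa: F401


    def mm(inp, weights, bias=None):
        # On CUDA (or any non-ROCm platform), fall back to stock PyTorch.
        if not IS_ROCM:
            return F.linear(inp, weights, bias)
        # On ROCm, the tuned hipb_mm/rocb_mm paths would be dispatched here;
        # that dispatch is elided in this sketch.
        raise NotImplementedError("ROCm tuned-GEMM dispatch not shown")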