
Commit 629f74b

Merge pull request #7 from ROCm/integration_no_fp8
Features integration:
- Custom ops and kernels from private v0.2.7_dllehr
- Triton attention kernel from jpvillam/v0.3.3_triton
- Option to run multi-GPU using torchrun instead of Ray
2 parents 54be8a0 + 5e3ec52 commit 629f74b

20 files changed: +1515 −88 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -181,6 +181,7 @@ _build/
 # hip files generated by PyTorch
 *.hip
 *_hip*
+hip_compat.h
 
 # Benchmark dataset
 *.json

Dockerfile.rocm

Lines changed: 18 additions & 18 deletions
@@ -3,8 +3,6 @@ ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
 
 FROM $BASE_IMAGE
 
-ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
-
 RUN echo "Base image is $BASE_IMAGE"
 
 # BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"

@@ -26,22 +24,12 @@ ARG BUILD_FA="1"
 # whether to build cupy on rocm
 ARG BUILD_CUPY="1"
 
-# Install some basic utilities
-RUN apt-get update && apt-get install python3 python3-pip -y
+# whether to build triton on rocm
+ARG BUILD_TRITON="1"
 
 # Install some basic utilities
 RUN apt-get update && apt-get install -y \
-    curl \
-    ca-certificates \
-    sudo \
-    git \
-    bzip2 \
-    libx11-6 \
-    build-essential \
-    wget \
-    unzip \
-    nvidia-cuda-toolkit \
-    tmux \
+    sqlite3 libsqlite3-dev libfmt-dev \
  && rm -rf /var/lib/apt/lists/*
 
 ### Mount Point ###

@@ -95,6 +83,17 @@ RUN if [ "$BUILD_CUPY" = "1" ]; then \
     && cd ..; \
     fi
 
+# build triton
+RUN if [ "$BUILD_TRITON" = "1" ]; then \
+    mkdir -p libs \
+    && cd libs \
+    && pip uninstall -y triton \
+    && git clone https://github.com/ROCm/triton.git \
+    && cd triton/python \
+    && pip3 install . \
+    && cd ../..; \
+    fi
+
 COPY ./ /app/vllm
 
 RUN python3 -m pip install --upgrade pip

@@ -104,12 +103,13 @@ RUN cd /app \
     && cd vllm \
     && pip install -U -r requirements-rocm.txt \
     && if [ "$BUILD_FA" = "1" ]; then \
-       bash patch_xformers.rocm.sh; fi \
-    && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
+       bash patch_xformers.rocm.sh; fi \
+    && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
+       patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch; fi \
     && python3 setup.py install \
     && cd ..
 
 RUN python3 -m pip install --upgrade pip
-RUN python3 -m pip install --no-cache-dir ray[all]
+RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3
 
 CMD ["/bin/bash"]

benchmarks/benchmark_latency.py

Lines changed: 16 additions & 12 deletions
@@ -16,18 +16,17 @@ def main(args: argparse.Namespace):
 
     # NOTE(woosuk): If the request cannot be processed in a single batch,
     # the engine will automatically process the request in multiple batches.
-    llm = LLM(
-        model=args.model,
-        tokenizer=args.tokenizer,
-        quantization=args.quantization,
-        tensor_parallel_size=args.tensor_parallel_size,
-        trust_remote_code=args.trust_remote_code,
-        dtype=args.dtype,
-        enforce_eager=args.enforce_eager,
-        kv_cache_dtype=args.kv_cache_dtype,
-        device=args.device,
-        ray_workers_use_nsight=args.ray_workers_use_nsight,
-    )
+    llm = LLM(model=args.model,
+              tokenizer=args.tokenizer,
+              quantization=args.quantization,
+              tensor_parallel_size=args.tensor_parallel_size,
+              trust_remote_code=args.trust_remote_code,
+              dtype=args.dtype,
+              enforce_eager=args.enforce_eager,
+              kv_cache_dtype=args.kv_cache_dtype,
+              device=args.device,
+              ray_workers_use_nsight=args.ray_workers_use_nsight,
+              worker_use_ray=args.worker_use_ray)
 
     sampling_params = SamplingParams(
         n=args.n,

@@ -151,5 +150,10 @@ def run_to_completion(profile_dir: Optional[str] = None):
         action='store_true',
         help="If specified, use nsight to profile ray workers",
     )
+    parser.add_argument('--worker-use-ray',
+                        action='store_true',
+                        help='use Ray for distributed serving, will be '
+                        'automatically set when using more than 1 GPU '
+                        'unless on ROCm where the default is torchrun')
     args = parser.parse_args()
     main(args)

benchmarks/benchmark_throughput.py

Lines changed: 24 additions & 14 deletions
@@ -75,21 +75,25 @@ def run_vllm(
     device: str,
     enable_prefix_caching: bool,
     gpu_memory_utilization: float = 0.9,
+    worker_use_ray: bool = False,
 ) -> float:
     from vllm import LLM, SamplingParams
-    llm = LLM(model=model,
-              tokenizer=tokenizer,
-              quantization=quantization,
-              tensor_parallel_size=tensor_parallel_size,
-              seed=seed,
-              trust_remote_code=trust_remote_code,
-              dtype=dtype,
-              max_model_len=max_model_len,
-              gpu_memory_utilization=gpu_memory_utilization,
-              enforce_eager=enforce_eager,
-              kv_cache_dtype=kv_cache_dtype,
-              device=device,
-              enable_prefix_caching=enable_prefix_caching)
+    llm = LLM(
+        model=model,
+        tokenizer=tokenizer,
+        quantization=quantization,
+        tensor_parallel_size=tensor_parallel_size,
+        seed=seed,
+        trust_remote_code=trust_remote_code,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        gpu_memory_utilization=gpu_memory_utilization,
+        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
+        device=device,
+        enable_prefix_caching=enable_prefix_caching,
+        worker_use_ray=worker_use_ray,
+    )
 
     # Add the requests to the engine.
     for prompt, _, output_len in requests:

@@ -213,7 +217,8 @@ def main(args: argparse.Namespace):
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
             args.trust_remote_code, args.dtype, args.max_model_len,
             args.enforce_eager, args.kv_cache_dtype, args.device,
-            args.enable_prefix_caching, args.gpu_memory_utilization)
+            args.enable_prefix_caching, args.gpu_memory_utilization,
+            args.worker_use_ray)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,

@@ -314,6 +319,11 @@ def main(args: argparse.Namespace):
         "--enable-prefix-caching",
         action='store_true',
         help="enable automatic prefix caching for vLLM backend.")
+    parser.add_argument('--worker-use-ray',
+                        action='store_true',
+                        help='use Ray for distributed serving, will be '
+                        'automatically set when using more than 1 GPU '
+                        'unless on ROCm where the default is torchrun')
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model

csrc/attention/attention_kernels.cu

Lines changed: 9 additions & 0 deletions
@@ -602,7 +602,11 @@ template<
   typename CACHE_T,
   int BLOCK_SIZE,
   bool IS_FP8_E5M2_KV_CACHE,
+#ifdef USE_ROCM
+  int NUM_THREADS = 1024>
+#else
   int NUM_THREADS = 128>
+#endif
 void paged_attention_v1_launcher(
   torch::Tensor& out,
   torch::Tensor& query,

@@ -779,8 +783,13 @@ template<
   typename CACHE_T,
   int BLOCK_SIZE,
   bool IS_FP8_E5M2_KV_CACHE,
+#ifdef USE_ROCM
+  int NUM_THREADS = 1024,
+  int PARTITION_SIZE = 1024>
+#else
   int NUM_THREADS = 128,
   int PARTITION_SIZE = 512>
+#endif
 void paged_attention_v2_launcher(
   torch::Tensor& out,
   torch::Tensor& exp_sums,

csrc/custom/custom.cpp

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <pybind11/pybind11.h>
+#include <cuda_runtime.h>
+
+namespace py = pybind11;
+
+// declare templates for front (cpp) and back (cuda) sides of function:
+//template <typename T>
+
+void LLGemm_Silu(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block);
+void LLMM_Silu(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int rows_per_block) {
+    int M = in_a.size(0);
+    int K = in_a.size(1);
+    LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(),
+        out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(), rows_per_block);
+}
+
+void LLGemm1(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block);
+
+//template <typename T>
+void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int rows_per_block=4) {
+    int M = in_a.size(0);
+    int K = in_a.size(1);
+    //if (N != in_b.numel())
+    //    throw std::invalid_argument("Size mismatch A.numel(): " + std::to_string(in_a.numel())
+    //        + ", B.numel(): " + std::to_string(in_b.numel()));
+
+    //out_c.resize_({N});
+
+    // call the kernel function...
+    LLGemm1(in_a.data_ptr(), in_b.data_ptr(),
+        out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(), rows_per_block);
+}
+
+void LLGemmZZ(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int solidx);
+
+void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int solidx=0) {
+    int M = in_a.size(0);
+    int K = in_a.size(1);
+
+    LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(),
+        out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(), solidx);
+}
+// instantiate the CPP template for T=float:
+//template void AddGPU<float>(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c);
+
+
+void MMGPUKernel(float *in_a, float *in_b, float *out_c,
+    int numARows, int numAColumns,
+    int numBRows, int numBColumns,
+    int numCRows, int numCColumns,
+    cudaStream_t stream);
+
+
+void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) {
+    auto matA_sizes { in_a.sizes() };
+    auto matB_sizes { in_b.sizes() };
+    auto matO_sizes { out_c.sizes() };
+    MMGPUKernel(in_a.data_ptr<float>(), in_b.data_ptr<float>(), out_c.data_ptr<float>(),
+        matA_sizes[0], matA_sizes[1],
+        matB_sizes[0], matB_sizes[1],
+        matO_sizes[0], matO_sizes[1],
+        at::cuda::getCurrentCUDAStream());
+}
+
+// declare the extension module with the AddGPU function:
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
+    m.doc() = "pybind11 example plugin";
+    m.def("LLMM1", &LLMM1);
+    m.def("LLMM_Silu", &LLMM_Silu);
+    m.def("LLZZ", &LLZZ);
+    //m.def("MMCustomGPU", &MMCustomGPU);
+}
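Note: for orientation, a rough sketch of how the bound ops might be driven from Python once this extension is compiled. The import name (custom_ops below) is fixed by the build via TORCH_EXTENSION_NAME and is not shown in this diff, and the tensor shapes are only assumptions inferred from the M/K reads in the wrappers; treat this as illustrative, not as the fork's actual calling convention:

    import torch
    import custom_ops  # hypothetical module name; set by TORCH_EXTENSION_NAME at build time

    M, K = 8, 4096
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)  # (M, K) operand read by the wrapper
    b = torch.randn(K, device="cuda", dtype=torch.float16)     # second operand (shape assumed)
    c = torch.empty(M, device="cuda", dtype=torch.float16)     # preallocated output (shape assumed)

    # LLMM1(in_a, in_b, out_c, rows_per_block) launches the custom GEMM kernel
    # on the current CUDA/HIP stream; 4 mirrors the C++ default argument.
    custom_ops.LLMM1(a, b, c, 4)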
