Features integration without fp8 #7

Merged: 39 commits, merged on Mar 22, 2024

Changes from all commits

Commits (39)
9c2367c
[ROCm] Fixup arch checks for ROCM
dllehr-amd Jan 27, 2024
a9d752c
yapf cleanup
dllehr-amd Jan 27, 2024
20b5f10
Initial port of gradlib gemm tuner
dllehr-amd Feb 10, 2024
6f28107
Enable torchrun vs Ray
dllehr-amd Feb 11, 2024
184806e
Add custom matvec kernels and sampler matmul call tuned_gemm
dllehr-amd Feb 11, 2024
af9e9d1
Add silu gemm fusion when batch and seq_len = 1
dllehr-amd Feb 14, 2024
5f8eac3
Add tunable flags to VLLM
dllehr-amd Feb 14, 2024
22766b4
Allow benchmark_latency to take a list of input/output/batches for fa…
dllehr-amd Feb 14, 2024
87b4c1b
Add dynamic tuning feature to vllm
dllehr-amd Feb 15, 2024
694ae1d
Add rpd tracer controls to benchmark_latency.py
dllehr-amd Feb 15, 2024
0e73aed
Fix Dockerfile errors
dllehr-amd Feb 15, 2024
90df0c9
Add llama2 run script
dllehr-amd Feb 16, 2024
ab67280
Increase Partition and Num threads for attention blocks
dllehr-amd Feb 16, 2024
1d53722
Fix WORKDIR
dllehr-amd Feb 22, 2024
5148aa5
Add accuracy flag to benchmark_latency.py
dllehr-amd Feb 22, 2024
534dcff
Don't broadcast when using torchrun
dllehr-amd Feb 26, 2024
e569133
Adding new rocm triton flash attention kernel
jpvillam-amd Mar 15, 2024
8cb05bc
Merge remote-tracking branch 'origin/main' into v0.3.3_greg
gshtras Mar 18, 2024
be708d0
Removed gradlib and its tuned gemm in favor of tunable ops
gshtras Mar 18, 2024
0e63661
Small fix on dockerfile
jpvillam-amd Mar 19, 2024
c89c0e3
Merge remote-tracking branch 'origin/main' into jpvillam/v0.3.3_triton
jpvillam-amd Mar 19, 2024
d4cb905
Rebase updates and PR review changes
jpvillam-amd Mar 19, 2024
bc750fa
Introducing torchrun multi GPU support
gshtras Mar 20, 2024
a83b7ea
Update dockerfile
gshtras Mar 21, 2024
2ff59d7
Merge branch 'greg/torchrun' into integration_no_fp8
gshtras Mar 21, 2024
04fd3fb
Merge remote-tracking branch 'origin/jpvillam/v0.3.3_triton' into int…
gshtras Mar 21, 2024
e01b8cd
add use case for custom kernel for matvec operation
charlifu Mar 21, 2024
eb21ad7
limit the custom kernel under is_hip
charlifu Mar 21, 2024
1bf736f
fix custom kernel
charlifu Mar 21, 2024
7ab4a24
Rocm defaults and cleanup
gshtras Mar 21, 2024
fbea667
Remove ignored file
gshtras Mar 21, 2024
c8fce27
Refactor torchrun executor to reuse single gpu executor code
gshtras Mar 21, 2024
1fff99a
Added interleaving for MQA for triton kernel
Mar 22, 2024
0a2309a
linter
gshtras Mar 22, 2024
1ec6554
Merge remote-tracking branch 'origin/jpvillam/v0.3.3_triton' into int…
gshtras Mar 22, 2024
44c2cee
Making torchrun the default multi GPU executor on ROCm unless overrid…
gshtras Mar 22, 2024
1256bee
Make triton the default FA
jpvillam-amd Mar 22, 2024
b687795
Make workaround only applicable to triton path
jpvillam-amd Mar 22, 2024
5e3ec52
Pin ray version to 2.9.3
gshtras Mar 22, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -181,6 +181,7 @@ _build/
# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h

# Benchmark dataset
*.json
36 changes: 18 additions & 18 deletions Dockerfile.rocm
@@ -3,8 +3,6 @@ ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

FROM $BASE_IMAGE

ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

RUN echo "Base image is $BASE_IMAGE"

# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
@@ -26,22 +24,12 @@ ARG BUILD_FA="1"
# whether to build cupy on rocm
ARG BUILD_CUPY="1"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
# whether to build triton on rocm
ARG BUILD_TRITON="1"

# Install some basic utilities
RUN apt-get update && apt-get install -y \
curl \
ca-certificates \
sudo \
git \
bzip2 \
libx11-6 \
build-essential \
wget \
unzip \
nvidia-cuda-toolkit \
tmux \
sqlite3 libsqlite3-dev libfmt-dev \
&& rm -rf /var/lib/apt/lists/*

### Mount Point ###
@@ -95,6 +83,17 @@ RUN if [ "$BUILD_CUPY" = "1" ]; then \
&& cd ..; \
fi

# build triton
RUN if [ "$BUILD_TRITON" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& pip uninstall -y triton \
&& git clone https://github.com/ROCm/triton.git \
&& cd triton/python \
&& pip3 install . \
&& cd ../..; \
fi

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
@@ -104,12 +103,13 @@ RUN cd /app \
&& cd vllm \
&& pip install -U -r requirements-rocm.txt \
&& if [ "$BUILD_FA" = "1" ]; then \
bash patch_xformers.rocm.sh; fi \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
bash patch_xformers.rocm.sh; fi \
&& if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch; fi \
&& python3 setup.py install \
&& cd ..

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]
RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3

CMD ["/bin/bash"]
28 changes: 16 additions & 12 deletions benchmarks/benchmark_latency.py
@@ -16,18 +16,17 @@ def main(args: argparse.Namespace):

# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(
model=args.model,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
)
llm = LLM(model=args.model,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
worker_use_ray=args.worker_use_ray)

sampling_params = SamplingParams(
n=args.n,
@@ -151,5 +150,10 @@ def run_to_completion(profile_dir: Optional[str] = None):
action='store_true',
help="If specified, use nsight to profile ray workers",
)
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU '
'unless on ROCm where the default is torchrun')
args = parser.parse_args()
main(args)
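For context (a minimal sketch, not from the diff): the new --worker-use-ray flag is threaded straight into the LLM constructor above, so the same choice can be made from plain Python. The model name and sampling settings below are placeholders.

# Sketch of the worker_use_ray knob added in this benchmark; placeholder
# model and prompt, single GPU for simplicity.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",   # placeholder model
    tensor_parallel_size=1,
    worker_use_ray=False,        # on ROCm this fork defaults to torchrun instead of Ray
)

params = SamplingParams(temperature=0.0, max_tokens=16)
out = llm.generate(["Hello, my name is"], params)
print(out[0].outputs[0].text)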
38 changes: 24 additions & 14 deletions benchmarks/benchmark_throughput.py
@@ -75,21 +75,25 @@ def run_vllm(
device: str,
enable_prefix_caching: bool,
gpu_memory_utilization: float = 0.9,
worker_use_ray: bool = False,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
enable_prefix_caching=enable_prefix_caching)
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
enable_prefix_caching=enable_prefix_caching,
worker_use_ray=worker_use_ray,
)

# Add the requests to the engine.
for prompt, _, output_len in requests:
@@ -213,7 +217,8 @@ def main(args: argparse.Namespace):
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype, args.device,
args.enable_prefix_caching, args.gpu_memory_utilization)
args.enable_prefix_caching, args.gpu_memory_utilization,
args.worker_use_ray)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -314,6 +319,11 @@ def main(args: argparse.Namespace):
"--enable-prefix-caching",
action='store_true',
help="enable automatic prefix caching for vLLM backend.")
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU '
'unless on ROCm where the default is torchrun')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
9 changes: 9 additions & 0 deletions csrc/attention/attention_kernels.cu
@@ -602,7 +602,11 @@ template<
typename CACHE_T,
int BLOCK_SIZE,
bool IS_FP8_E5M2_KV_CACHE,
#ifdef USE_ROCM
int NUM_THREADS = 1024>
#else
int NUM_THREADS = 128>
#endif
void paged_attention_v1_launcher(
torch::Tensor& out,
torch::Tensor& query,
@@ -779,8 +783,13 @@ template<
typename CACHE_T,
int BLOCK_SIZE,
bool IS_FP8_E5M2_KV_CACHE,
#ifdef USE_ROCM
int NUM_THREADS = 1024,
int PARTITION_SIZE = 1024>
#else
int NUM_THREADS = 128,
int PARTITION_SIZE = 512>
#endif
void paged_attention_v2_launcher(
torch::Tensor& out,
torch::Tensor& exp_sums,
74 changes: 74 additions & 0 deletions csrc/custom/custom.cpp
@@ -0,0 +1,74 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <pybind11/pybind11.h>
#include <cuda_runtime.h>

namespace py = pybind11;

// declare templates for front (cpp) and back (cuda) sides of function:
//template <typename T>

void LLGemm_Silu(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block);
void LLMM_Silu(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int rows_per_block) {
int M = in_a.size(0);
int K = in_a.size(1);
LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(),
out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(),rows_per_block);
}

void LLGemm1(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream,const int rows_per_block);

//template <typename T>
void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int rows_per_block=4) {
int M = in_a.size(0);
int K = in_a.size(1);
//if (N != in_b.numel())
// throw std::invalid_argument("Size mismatch A.numel(): " + std::to_string(in_a.numel())
// + ", B.numel(): " + std::to_string(in_b.numel()));

//out_c.resize_({N});

// call the kernel function...
LLGemm1(in_a.data_ptr(), in_b.data_ptr(),
out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(),rows_per_block);
}

void LLGemmZZ(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int solidx);

void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int solidx=0) {
int M = in_a.size(0);
int K = in_a.size(1);

LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(),
out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(),solidx);
}
// instantiate the CPP template for T=float:
//template void AddGPU<float>(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c);


void MMGPUKernel(float *in_a, float *in_b, float *out_c,
int numARows, int numAColumns,
int numBRows, int numBColumns,
int numCRows, int numCColumns,
cudaStream_t stream);


void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) {
auto matA_sizes { in_a.sizes() };
auto matB_sizes { in_b.sizes() };
auto matO_sizes { out_c.sizes() };
MMGPUKernel(in_a.data_ptr<float>(), in_b.data_ptr<float>(), out_c.data_ptr<float>(),
matA_sizes[0], matA_sizes[1],
matB_sizes[0], matB_sizes[1],
matO_sizes[0], matO_sizes[1],
at::cuda::getCurrentCUDAStream());
}

// declare the extension module with the AddGPU function:
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
m.doc() = "pybind11 example plugin";
m.def("LLMM1", &LLMM1);
m.def("LLMM_Silu", &LLMM_Silu);
m.def("LLZZ", &LLZZ);
//m.def("MMCustomGPU", &MMCustomGPU);
}
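A hypothetical usage sketch of the bindings above, assuming the extension has been compiled into an importable Python module. The module name custom_ops, the tensor shapes, and the dtype are all illustrative; the real build wiring lives in the PR's setup.py and is not shown here.

# Hypothetical call into the custom matvec binding declared above.
# "custom_ops" is a placeholder name for the compiled extension module.
import torch
import custom_ops  # placeholder import; the real module name comes from the build config

M, K = 4096, 4096
weight = torch.randn(M, K, dtype=torch.float16, device="cuda")  # in_a: M x K matrix
x = torch.randn(1, K, dtype=torch.float16, device="cuda")       # in_b: length-K input vector
out = torch.empty(1, M, dtype=torch.float16, device="cuda")     # out_c: length-M result

# rows_per_block=4 mirrors the default on the C++ wrapper above.
custom_ops.LLMM1(weight, x, out, 4)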