Skip to content

Commit

Permalink
Feature add openai api for vllm integration (pytorch#3287)
Browse files Browse the repository at this point in the history
* Forward additional url segments as url_paths in request header to model

* Fix vllm test and clean preproc

* First attept to enable OpenAI api for models served via vllm

* fix streaming in openai api

* Add OpenAIServingCompletion usage example

* Add lora modules to vllm engine

* Finish openai completion integration; removed req openai client; updated lora example to llama 3.1

* fix lint

* Update mistral + llama3 vllm example

* Remove openai client from url path test

* Add openai chat api to vllm example

* Added v1/models endpoint for vllm example

* Remove accidential breakpoint()

* Add comment to new url_path
  • Loading branch information
mreso authored Aug 23, 2024
1 parent 391ee4c commit db1a003
Show file tree
Hide file tree
Showing 16 changed files with 633 additions and 149 deletions.
80 changes: 71 additions & 9 deletions examples/large_models/utils/test_llm_streaming_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,67 @@ def _predict(self):
combined_text = ""
for chunk in response.iter_content(chunk_size=None):
if chunk:
data = json.loads(chunk)
text = self._extract_text(chunk)
if self.args.demo_streaming:
print(data["text"], end="", flush=True)
print(text, end="", flush=True)
else:
combined_text += data.get("text", "")
combined_text += text
if not self.args.demo_streaming:
self.queue.put_nowait(f"payload={payload}\n, output={combined_text}\n")

def _extract_completion(self, chunk):
chunk = chunk.decode("utf-8")
if chunk.startswith("data:"):
chunk = chunk[len("data:") :].split("\n")[0].strip()
if chunk.startswith("[DONE]"):
return ""
return json.loads(chunk)["choices"][0]["text"]

def _extract_chat(self, chunk):
chunk = chunk.decode("utf-8")
if chunk.startswith("data:"):
chunk = chunk[len("data:") :].split("\n")[0].strip()
if chunk.startswith("[DONE]"):
return ""
try:
return json.loads(chunk)["choices"][0].get("message", {})["content"]
except KeyError:
return json.loads(chunk)["choices"][0].get("delta", {}).get("content", "")

def _extract_text(self, chunk):
if self.args.openai_api:
if "chat" in self.args.api_endpoint:
return self._extract_chat(chunk)
else:
return self._extract_completion(chunk)
else:
return json.loads(chunk).get("text", "")

def _get_url(self):
return f"http://localhost:8080/predictions/{self.args.model}"
if self.args.openai_api:
return f"http://localhost:8080/predictions/{self.args.model}/{self.args.model_version}/{self.args.api_endpoint}"
else:
return f"http://localhost:8080/predictions/{self.args.model}"

def _format_payload(self):
prompt_input = _load_curl_like_data(self.args.prompt_text)
if "chat" in self.args.api_endpoint:
assert self.args.prompt_json, "Use prompt json file for chat interface"
assert self.args.openai_api, "Chat only work with openai api"
prompt_input = json.loads(prompt_input)
messages = prompt_input.get("messages", None)
assert messages is not None
rt = int(prompt_input.get("max_tokens", self.args.max_tokens))
prompt_input["max_tokens"] = rt
if self.args.demo_streaming:
prompt_input["stream"] = True
return prompt_input
if self.args.prompt_json:
prompt_input = json.loads(prompt_input)
prompt = prompt_input.get("prompt", None)
assert prompt is not None
prompt_list = prompt.split(" ")
rt = int(prompt_input.get("max_new_tokens", self.args.max_tokens))
rt = int(prompt_input.get("max_tokens", self.args.max_tokens))
else:
prompt_list = prompt_input.split(" ")
rt = self.args.max_tokens
Expand All @@ -58,13 +100,15 @@ def _format_payload(self):
cur_prompt = " ".join(prompt_list)
if self.args.prompt_json:
prompt_input["prompt"] = cur_prompt
prompt_input["max_new_tokens"] = rt
return prompt_input
prompt_input["max_tokens"] = rt
else:
return {
prompt_input = {
"prompt": cur_prompt,
"max_new_tokens": rt,
"max_tokens": rt,
}
if self.args.demo_streaming and self.args.openai_api:
prompt_input["stream"] = True
return prompt_input


def _load_curl_like_data(text):
Expand Down Expand Up @@ -136,6 +180,24 @@ def parse_args():
default=False,
help="Demo streaming response, force num-requests-per-thread=1 and num-threads=1",
)
parser.add_argument(
"--openai-api",
action=argparse.BooleanOptionalAction,
default=False,
help="Use OpenAI compatible API",
)
parser.add_argument(
"--api-endpoint",
type=str,
default="v1/completions",
help="OpenAI endpoint suffix",
)
parser.add_argument(
"--model-version",
type=str,
default="1.0",
help="Model vesion. Default: 1.0",
)

return parser.parse_args()

Expand Down
11 changes: 8 additions & 3 deletions examples/large_models/vllm/llama3/Readme.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Example showing inference with vLLM on LoRA model

This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `meta-llama/Meta-Llama-3-8B-Instruct` with continuous batching.
This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `meta-llama/Meta-Llama-3.1-8B-Instruct` with continuous batching.
This examples supports distributed inference by following [this instruction](../Readme.md#distributed-inference)

### Step 0: Install vLLM
Expand All @@ -21,7 +21,7 @@ huggingface-cli login --token $HUGGINGFACE_TOKEN
```

```bash
python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3-8B-Instruct --use_auth_token True
python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3.1-8B-Instruct --use_auth_token True
```

### Step 2: Generate model artifacts
Expand All @@ -47,7 +47,12 @@ torchserve --start --ncs --ts-config ../config.properties --model-store model_st
```

### Step 5: Run inference
Run a text completion:
```bash
python ../../utils/test_llm_streaming_response.py -m llama3-8b -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json --openai-api
```

Or use the chat interface:
```bash
python ../../utils/test_llm_streaming_response.py -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json
python ../../utils/test_llm_streaming_response.py -m llama3-8b -o 50 -t 2 -n 4 --prompt-text "@chat.json" --prompt-json --openai-api --demo-streaming --api-endpoint "v1/chat/completions"
```
11 changes: 11 additions & 0 deletions examples/large_models/vllm/llama3/chat.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"model": "llama3-8b",
"messages":[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
{"role": "user", "content": "Where was it played?"}
],
"temperature":0.0,
"max_tokens": 50
}
5 changes: 4 additions & 1 deletion examples/large_models/vllm/llama3/model-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ deviceType: "gpu"
asyncCommunication: true

handler:
model_path: "model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/"
model_path: "model/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/8c22764a7e3675c50d4c7c9a4edb474456022b16"
vllm_engine_config:
max_num_seqs: 16
max_model_len: 250
served_model_name:
- "meta-llama/Meta-Llama-3.1-8B"
- "llama3-8b"
4 changes: 1 addition & 3 deletions examples/large_models/vllm/llama3/prompt.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
{
"prompt": "A robot may not injure a human being",
"max_new_tokens": 50,
"temperature": 0.8,
"logprobs": 1,
"prompt_logprobs": 1,
"max_tokens": 128,
"adapter": "adapter_1"
"model": "llama3-8b"
}
45 changes: 36 additions & 9 deletions examples/large_models/vllm/lora/Readme.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Example showing inference with vLLM on LoRA model

This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `Llama-2-7b-hf` + LoRA model `llama-2-7b-sql-lora-test` with continuous batching.
This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `meta-llama/Meta-Llama-3.1-8B` + LoRA model `llama-duo/llama3.1-8b-summarize-gpt4o-128k` with continuous batching.
This examples supports distributed inference by following [this instruction](../Readme.md#distributed-inference)

### Step 0: Install vLLM
Expand All @@ -21,9 +21,9 @@ huggingface-cli login --token $HUGGINGFACE_TOKEN
```

```bash
python ../../utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-7b-chat-hf --use_auth_token True
python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3.1-8B --use_auth_token True
mkdir adapters && cd adapters
python ../../../utils/Download_model.py --model_path model --model_name yard1/llama-2-7b-sql-lora-test --use_auth_token True
python ../../../utils/Download_model.py --model_path model --model_name llama-duo/llama3.1-8b-summarize-gpt4o-128k --use_auth_token True
cd ..
```

Expand All @@ -32,26 +32,53 @@ cd ..
Add the downloaded path to "model_path:" and "adapter_1:" in `model-config.yaml` and run the following.

```bash
torch-model-archiver --model-name llama-7b-lora --version 1.0 --handler vllm_handler --config-file model-config.yaml --archive-format no-archive
mv model llama-7b-lora
mv adapters llama-7b-lora
torch-model-archiver --model-name llama-8b-lora --version 1.0 --handler vllm_handler --config-file model-config.yaml --archive-format no-archive
mv model llama-8b-lora
mv adapters llama-8b-lora
```

### Step 3: Add the model artifacts to model store

```bash
mkdir model_store
mv llama-7b-lora model_store
mv llama-8b-lora model_store
```

### Step 4: Start torchserve

```bash
torchserve --start --ncs --ts-config ../config.properties --model-store model_store --models llama-7b-lora --disable-token-auth --enable-model-api
torchserve --start --ncs --ts-config ../config.properties --model-store model_store --models llama-8b-lora --disable-token-auth --enable-model-api
```

### Step 5: Run inference
The vllm integration uses an OpenAI compatible interface which lets you perform inference with curl or the openai library client and supports streaming.

Curl:
```bash
python ../../utils/test_llm_streaming_response.py -m lora -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json
curl --header "Content-Type: application/json" --request POST --data @prompt.json http://localhost:8080/predictions/llama-8b-lora/1.0/v1
```

Python + Request:
```bash
python ../../utils/test_llm_streaming_response.py -m llama-8b-lora -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json --openai-api --demo-streaming
```

OpenAI client:
```python
from openai import OpenAI
model_name = "llama-8b-lora"
stream=True
openai_api_key = "EMPTY"
openai_api_base = f"http://localhost:8080/predictions/{model_name}/1.0/v1"

client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)

response = client.completions.create(
model=model_name, prompt="Hello world", temperature=0.0, stream=stream
)
for chunk in reponse:
print(f"{chunk=}")
```
8 changes: 6 additions & 2 deletions examples/large_models/vllm/lora/model-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@ deviceType: "gpu"
asyncCommunication: true

handler:
model_path: "model/models--meta-llama--Llama-2-7b-chat-hf/snapshots/f5db02db724555f92da89c216ac04704f23d4590/"
model_path: "model/models--meta-llama--Meta-Llama-3.1-8B/snapshots/48d6d0fc4e02fb1269b36940650a1b7233035cbb"
vllm_engine_config:
enable_lora: true
max_loras: 4
max_cpu_loras: 4
max_lora_rank: 32
max_num_seqs: 16
max_model_len: 250
served_model_name:
- "meta-llama/Meta-Llama-3.1-8B"
- "llama-8b-lora"

adapters:
adapter_1: "adapters/model/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/"
adapter_1: "adapters/model/models--llama-duo--llama3.1-8b-summarize-gpt4o-128k/snapshots/4ba83353f24fa38946625c8cc49bf21c80a22825"
8 changes: 3 additions & 5 deletions examples/large_models/vllm/lora/prompt.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
{
"model": "adapter_1",
"prompt": "A robot may not injure a human being",
"max_new_tokens": 50,
"temperature": 0.8,
"temperature": 0.0,
"logprobs": 1,
"prompt_logprobs": 1,
"max_tokens": 128,
"adapter": "adapter_1"
"max_tokens": 128
}
2 changes: 1 addition & 1 deletion examples/large_models/vllm/mistral/Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ torchserve --start --ncs --ts-config ../config.properties --model-store model_st
### Step 5: Run inference

```bash
python ../../utils/test_llm_streaming_response.py -m mistral -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json
python ../../utils/test_llm_streaming_response.py -m mistral -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json --openai-api
```
2 changes: 2 additions & 0 deletions examples/large_models/vllm/mistral/model-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ handler:
max_model_len: 250
max_num_seqs: 16
tensor_parallel_size: 4
served_model_name:
- "mistral"
3 changes: 1 addition & 2 deletions examples/large_models/vllm/mistral/prompt.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
{
"model": "mistral",
"prompt": "A robot may not injure a human being",
"max_new_tokens": 50,
"temperature": 0.8,
"logprobs": 1,
"prompt_logprobs": 1,
"max_tokens": 128
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,21 @@ private void handlePredictions(

String modelVersion = null;

if (segments.length == 4) {
if (segments.length >= 4) {
modelVersion = segments[3];
}
req.headers().add("url_path", "");
/**
* If url provides more segments as model_name/version we provide these as url_path in the
* request header This way users can leverage them in the custom handler to e.g. influence
* handler behavior
*/
if (segments.length > 4) {
String joinedSegments =
String.join("/", Arrays.copyOfRange(segments, 4, segments.length));
req.headers().add("url_path", joinedSegments);
}

req.headers().add("explain", "False");
if (explain) {
req.headers().add("explain", "True");
Expand Down
Loading

0 comments on commit db1a003

Please sign in to comment.