Skip to content

Commit

Permalink
Merge pull request #44 from BatsResearch/integrate-vllm
Browse files Browse the repository at this point in the history
Integrate vLLM
  • Loading branch information
dotpyu authored Jul 19, 2023
2 parents 09f1ac2 + 53431db commit 653649f
Show file tree
Hide file tree
Showing 45 changed files with 116 additions and 130 deletions.
5 changes: 4 additions & 1 deletion alfred/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
assert self.model_type in [
"huggingface", "huggingfacevlm",
"onnx", "tensorrt",
"flexgen",
"flexgen", "vllm",
"openai", "anthropic",
"cohere", "ai21",
"torch", "dummy"
Expand Down Expand Up @@ -169,6 +169,9 @@ def __init__(
elif self.model_type == "flexgen":
from alfred.fm.flexgen import FlexGenModel
self.model = FlexGenModel(self.model, **kwargs)
elif self.model_type == "vllm":
from alfred.fm.vllm import vLLMModel
self.model = vLLMModel(self.model, **kwargs)
elif self.model_type == "tensorrt":
# self.model = TensorRTModel(self.model, **kwargs)
raise NotImplementedError
Expand Down
57 changes: 57 additions & 0 deletions alfred/fm/vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import logging
from typing import List, Any

from alfred.fm.model import LocalAccessFoundationModel
from .response import CompletionResponse

logger = logging.getLogger(__name__)


import torch
try:
from vllm import LLM, SamplingParams
except ImportError:
raise ImportError("Please install VLLM with `pip install vllm`")



class vLLMModel(LocalAccessFoundationModel):
"""
vLLMModel wraps a vLLM model. vLLM is a fast and easy-to-use library for LLM inference.
source: https://github.com/vllm-project/vllm
"""

def __init__(self, model: str, model_string: str, local_dir: str = None,
**kwargs: Any):
"""
Initialize a VLLM with MultiGPU.
:param model: (optional) The path to the model.
:type model: str
"""
self.model_string = model
super().__init__(model_string)
self.gpu_count = torch.cuda.device_count()
self.model = LLM(local_dir if local_dir is not None else model, tensor_parallel_size=self.gpu_count)
def _generate_batch(
self,
batch_instance: List[str],
**kwargs: Any,
) -> List[CompletionResponse]:
"""
Generate completions for a batch of queries.
:param batch_instance: A list of queries.
:type batch_instance: List[str]
:param kwargs: Additional keyword arguments.
:return: A list of `CompletionResponse` objects with the same prediction content as the input.
:rtype: List[CompletionResponse]
"""

temperature = kwargs.get("temperature", 0)
max_new_tokens = kwargs.get("max_new_tokens", 16)

sampling_params = SamplingParams(temperature=temperature, max_tokens=max_new_tokens, top_k=1)

return [CompletionResponse(prediction=output.outputs[0].text) for output in self.model.generate(batch_instance, sampling_params)]
5 changes: 4 additions & 1 deletion alfred/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
self.model_type = model_type.lower()
assert self.model_type in [
"huggingface", "huggingfacevlm", "onnx", "tensorrt", "openai", "anthropic",
"flexgen",
"flexgen", "vllm",
"cohere", "ai21", "torch", "dummy"
], f"Invalid model type: {self.model_type}"
if self.model_type == "huggingface":
Expand Down Expand Up @@ -77,6 +77,9 @@ def __init__(
elif self.model_type == "flexgen":
from alfred.fm.flexgen import FlexGenModel
self.model = FlexGenModel(self.model, **kwargs)
elif self.model_type == "vllm":
from alfred.fm.vllm import vLLMModel
self.model = vLLMModel(self.model, **kwargs)
elif self.model_type == "tensorrt":
# self.model = TensorRTModel(self.model, **kwargs)
raise NotImplementedError
Expand Down
1 change: 1 addition & 0 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ A full list of `Alfred` project modules.
- [RankedResponse](alfred/fm/response/ranked_response.md#rankedresponse)
- [Response](alfred/fm/response/response.md#response)
- [Utils](alfred/fm/utils.md#utils)
- [Vllm](alfred/fm/vllm.md#vllm)
- [Labeling](alfred/labeling/index.md#labeling)
- [FlyingSquid](alfred/labeling/flyingsquid.md#flyingsquid)
- [LabelModel](alfred/labeling/labelmodel.md#labelmodel)
Expand Down
4 changes: 1 addition & 3 deletions docs/alfred/client/cache/cache.md
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,4 @@ Type: *str*
```python
def to_metadata_string(**kwargs: Any) -> str:
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/client/cache/dummy.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,4 @@ Write a prompt-response pair to the cache
```python
def write(self, prompt: str, response: str, metadata: Optional[str] = None):
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/client/cache/sqlite.md
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,4 @@ def write_batch(
self, prompts: List[str], responses: List[str], metadata: Optional[str] = None
):
...
```


```
20 changes: 9 additions & 11 deletions docs/alfred/client/client.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Client:

### Client().__call__

[Show source in client.py:267](../../../alfred/client/client.py#L267)
[Show source in client.py:270](../../../alfred/client/client.py#L270)

__call__() function to run the model on the queries.
Equivalent to run() function.
Expand Down Expand Up @@ -75,7 +75,7 @@ def __call__(

### Client().calibrate

[Show source in client.py:282](../../../alfred/client/client.py#L282)
[Show source in client.py:285](../../../alfred/client/client.py#L285)

calibrate are used to calibrate foundation models contextually given the template.
A voter class may be passed to calibrate the model with a specific voter.
Expand Down Expand Up @@ -120,7 +120,7 @@ def calibrate(

### Client().chat

[Show source in client.py:384](../../../alfred/client/client.py#L384)
[Show source in client.py:387](../../../alfred/client/client.py#L387)

Chat with the model APIs.
Currently, Alfred supports Chat APIs from Anthropic and OpenAI
Expand All @@ -139,7 +139,7 @@ def chat(self, log_save_path: Optional[str] = None, **kwargs: Any):

### Client().encode

[Show source in client.py:358](../../../alfred/client/client.py#L358)
[Show source in client.py:361](../../../alfred/client/client.py#L361)

embed() function to embed the queries.

Expand All @@ -162,7 +162,7 @@ def encode(

### Client().generate

[Show source in client.py:226](../../../alfred/client/client.py#L226)
[Show source in client.py:229](../../../alfred/client/client.py#L229)

Wrapper function to generate the response(s) from the model. (For completion)

Expand Down Expand Up @@ -191,7 +191,7 @@ def generate(

### Client().remote_run

[Show source in client.py:204](../../../alfred/client/client.py#L204)
[Show source in client.py:207](../../../alfred/client/client.py#L207)

Wrapper function for running the model on the queries thru a gRPC Server.

Expand All @@ -218,7 +218,7 @@ def remote_run(

### Client().run

[Show source in client.py:184](../../../alfred/client/client.py#L184)
[Show source in client.py:187](../../../alfred/client/client.py#L187)

Run the model on the queries.

Expand All @@ -245,7 +245,7 @@ def run(

### Client().score

[Show source in client.py:243](../../../alfred/client/client.py#L243)
[Show source in client.py:246](../../../alfred/client/client.py#L246)

Wrapper function to score the response(s) from the model. (For ranking)

Expand Down Expand Up @@ -275,6 +275,4 @@ def score(
self, query: Union[RankedQuery, Dict, List[RankedQuery], List[str]], **kwargs: Any
) -> Union[Response, List[Response]]:
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/client/ssh/sshtunnel.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,4 @@ Stop the tunnel
```python
def stop(self):
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/client/ssh/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,4 @@ Finds the next available port if given port is not available
```python
def port_finder(port: Union[str, int], host: str = "") -> int:
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/data/arrow.md
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,4 @@ returns the version of the dataset
```python
def version(self) -> str:
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/data/dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,4 @@ returns the version of the dataset
@property
def version(self) -> str:
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/data/wrench.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,4 @@ returns the string representation of the dataset
```python
def __repr__(self):
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/fm/ai21.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,4 @@ This class provides a wrapper for the OpenAI API for generating completions.
class AI21Model(APIAccessFoundationModel):
def __init__(self, model_string: str = "j1-large", api_key: Optional[str] = None):
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/fm/anthropic.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,4 @@ Launch an interactive chat session with the Anthropic API.
```python
def chat(self, **kwargs: Any):
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/fm/cohere.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,4 @@ This class provides a wrapper for the OpenAI API for generating completions.
class CohereModel(APIAccessFoundationModel):
def __init__(self, model_string: str = "xlarge", api_key: Optional[str] = None):
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/fm/dummy.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,4 @@ class DummyModel(LocalAccessFoundationModel):

#### See also

- [LocalAccessFoundationModel](./model.md#localaccessfoundationmodel)


- [LocalAccessFoundationModel](./model.md#localaccessfoundationmodel)
4 changes: 1 addition & 3 deletions docs/alfred/fm/huggingface.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,4 @@ class HuggingFaceModel(LocalAccessFoundationModel):

#### See also

- [LocalAccessFoundationModel](./model.md#localaccessfoundationmodel)


- [LocalAccessFoundationModel](./model.md#localaccessfoundationmodel)
4 changes: 1 addition & 3 deletions docs/alfred/fm/huggingfacevlm.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,4 @@ class HuggingFaceCLIPModel(LocalAccessFoundationModel):

#### See also

- [LocalAccessFoundationModel](./model.md#localaccessfoundationmodel)


- [LocalAccessFoundationModel](./model.md#localaccessfoundationmodel)
3 changes: 2 additions & 1 deletion docs/alfred/fm/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ Fm
- [Query](query/index.md)
- [Remote](remote/index.md)
- [Response](response/index.md)
- [Utils](./utils.md)
- [Utils](./utils.md)
- [Vllm](./vllm.md)
4 changes: 1 addition & 3 deletions docs/alfred/fm/model.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,4 @@ class LocalAccessFoundationModel(FoundationModel):

#### See also

- [FoundationModel](#foundationmodel)


- [FoundationModel](#foundationmodel)
4 changes: 1 addition & 3 deletions docs/alfred/fm/onnx.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,4 @@ class ONNXModel(LocalAccessFoundationModel):

#### See also

- [LocalAccessFoundationModel](./model.md#localaccessfoundationmodel)


- [LocalAccessFoundationModel](./model.md#localaccessfoundationmodel)
4 changes: 1 addition & 3 deletions docs/alfred/fm/openai.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,4 @@ Launch an interactive chat session with the OpenAI API.
```python
def chat(self, **kwargs: Any):
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/fm/query/completion_query.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,4 @@ returns the raw prompt content
@property
def prompt(self):
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/fm/query/query.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,4 @@ Type: *str*
```python
def serialize(self) -> str:
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/fm/query/ranked_query.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,4 @@ returns the raw prompt content
@property
def prompt(self):
...
```


```
4 changes: 1 addition & 3 deletions docs/alfred/fm/remote/grpc.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,4 @@ def restart(self):
```python
def serve(self, credentials: Optional[grpc.ServerCredentials] = None):
...
```


```
1 change: 0 additions & 1 deletion docs/alfred/fm/remote/protos/query_pb2.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@
Query Pb2

> Auto-generated documentation for [alfred.fm.remote.protos.query_pb2](../../../../../alfred/fm/remote/protos/query_pb2.py) module.
- [Query Pb2](#query-pb2)
4 changes: 1 addition & 3 deletions docs/alfred/fm/remote/protos/query_pb2_grpc.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,4 @@ class QueryServiceStub(object):
```python
def add_QueryServiceServicer_to_server(servicer, server):
...
```


```
Loading

0 comments on commit 653649f

Please sign in to comment.