Commit: Add vLLM offline inference support for embedding (#17675)
Showing 10 changed files with 360 additions and 0 deletions.
4 changes: 4 additions & 0 deletions in llama-index-integrations/embeddings/llama-index-embeddings-vllm/BUILD
poetry_requirements(
    name="poetry",
    module_mapping={"vcrpy": ["vcr"]}
)
17 changes: 17 additions & 0 deletions in llama-index-integrations/embeddings/llama-index-embeddings-vllm/Makefile
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

help:	## Show all Makefile targets.
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format:	## Run code autoformatters (black).
	pre-commit install
	git ls-files | xargs pre-commit run black --files

lint:	## Run linters: pre-commit (black, ruff, codespell) and mypy.
	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test:	## Run tests via pytest.
	pytest tests

watch-docs:	## Build and watch documentation.
	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
1 change: 1 addition & 0 deletions in llama-index-integrations/embeddings/llama-index-embeddings-vllm/README.md
# LlamaIndex Embeddings Integration: Vllm
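The README above ships with only a title. A minimal quickstart sketch of the kind it might later carry, assuming vLLM is installed separately (`pip install vllm`); the model name is an illustrative assumption, not one named in this commit:

from llama_index.embeddings.vllm import VllmEmbedding

# Illustrative model choice: any model vLLM can load with task="embed" works.
embed_model = VllmEmbedding(model_name="BAAI/bge-small-en-v1.5")
print(embed_model.get_text_embedding("Hello, world!")[:4])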
1 change: 1 addition & 0 deletions in ...dex-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/BUILD
python_sources()
3 changes: 3 additions & 0 deletions in ...tegrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/__init__.py
from llama_index.embeddings.vllm.base import VllmEmbedding

__all__ = ["VllmEmbedding"]
244 changes: 244 additions & 0 deletions in ...x-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/base.py
import atexit
import logging
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

from llama_index.core.base.embeddings.base import DEFAULT_EMBED_BATCH_SIZE
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
from llama_index.core.schema import ImageType
from PIL import Image
from tenacity import retry, stop_after_attempt, wait_exponential

SUPPORT_EMBED_TYPES = ["image", "text"]
logger = logging.getLogger(__name__)


class VllmEmbedding(MultiModalEmbedding):
    """vLLM embedding model.

    Runs a vLLM embedding model locally, supporting text and image inputs.
    """

    tensor_parallel_size: Optional[int] = Field(
        default=1,
        description="The number of GPUs to use for distributed execution with tensor parallelism.",
    )
    trust_remote_code: Optional[bool] = Field(
        default=False,
        description="Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.",
    )
    dtype: str = Field(
        default="auto",
        description="The data type for the model weights and activations.",
    )
    download_dir: Optional[str] = Field(
        default=None,
        description="Directory to download and load the weights (defaults to the Hugging Face cache directory).",
    )
    vllm_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Holds any model parameters valid for the `vllm.LLM` call not explicitly specified.",
    )

    _client: Any = PrivateAttr()
    _image_token_id: Union[int, None] = PrivateAttr()

    def __init__(
        self,
        model_name: str = "facebook/opt-125m",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        tensor_parallel_size: int = 1,
        trust_remote_code: bool = False,
        dtype: str = "auto",
        download_dir: Optional[str] = None,
        vllm_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        callback_manager = callback_manager or CallbackManager([])
        vllm_kwargs = vllm_kwargs or {}
        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            tensor_parallel_size=tensor_parallel_size,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            download_dir=download_dir,
            vllm_kwargs=vllm_kwargs,
        )
        try:
            from vllm import LLM as VLLModel
        except ImportError:
            raise ImportError(
                "Could not import the vllm Python package. "
                "Please install it with `pip install vllm`."
            )
        self._client = VLLModel(
            model=model_name,
            task="embed",
            max_num_seqs=embed_batch_size,
            tensor_parallel_size=tensor_parallel_size,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            download_dir=download_dir,
            **vllm_kwargs,
        )
        # Multimodal models expose the image token id on their HF config;
        # text-only models do not, so fall back to None.
        try:
            self._image_token_id = (
                self._client.llm_engine.model_config.hf_config.image_token_id
            )
        except AttributeError:
            self._image_token_id = None

    @classmethod
    def class_name(cls) -> str:
        return "VllmEmbedding"

    # `atexit.register` fires while the class body is being defined, so this
    # cleanup hook (which takes no `self`) runs once at interpreter exit and
    # releases any cached CUDA memory.
    @atexit.register
    def close():
        import gc

        import torch

        if torch.cuda.is_available():
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True,
    )
    def _embed_with_retry(
        self, inputs: List[Union[str, BytesIO]], embed_type: str = "text"
    ) -> List[List[float]]:
        """Generate embeddings with a retry mechanism.

        Args:
            inputs: List of texts or images to embed.
            embed_type: Either "text" or "image".

        Returns:
            List of embedding vectors.

        Raises:
            Exception: If embedding fails after retries.
        """
        try:
            if embed_type == "image":
                inputs = [
                    {
                        "prompt_token_ids": [self._image_token_id],
                        "multi_modal_data": {"image": x},
                    }
                    for x in inputs
                ]
            emb = self._client.embed(inputs)
            return [x.outputs.embedding for x in emb]
        except Exception as e:
            logger.warning(f"Embedding attempt failed: {e!s}")
            raise

    def _embed(
        self, inputs: List[Union[str, BytesIO]], embed_type: str = "text"
    ) -> List[List[float]]:
        """Generate embeddings with input validation and a retry mechanism.

        Args:
            inputs: Texts or images to embed.
            embed_type: Either "text" or "image".

        Returns:
            List of embedding vectors.

        Raises:
            ValueError: If the embedding type is unsupported.
            Exception: If embedding fails after retries.
        """
        if embed_type not in SUPPORT_EMBED_TYPES:
            raise ValueError(f"Unsupported embed_type: {embed_type!r}")
        return self._embed_with_retry(inputs, embed_type)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Generate an embedding for a query.

        Args:
            query (str): Query text/sentence.

        Returns:
            List[float]: Embedding vector.
        """
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Generate an embedding for a query asynchronously.

        Args:
            query (str): Query text/sentence.

        Returns:
            List[float]: Embedding vector.
        """
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Generate an embedding for text asynchronously.

        Args:
            text (str): Text/sentence.

        Returns:
            List[float]: Embedding vector.
        """
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Generate an embedding for text.

        Args:
            text (str): Text/sentence.

        Returns:
            List[float]: Embedding vector.
        """
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts.

        Args:
            texts (List[str]): Texts/sentences.

        Returns:
            List[List[float]]: Embedding vectors.
        """
        return self._embed(texts)

    def _get_image_embedding(self, img_file_path: ImageType) -> List[float]:
        """Generate an embedding for an image."""
        image = Image.open(img_file_path)
        return self._embed([image], "image")[0]

    async def _aget_image_embedding(self, img_file_path: ImageType) -> List[float]:
        """Generate an embedding for an image asynchronously."""
        return self._get_image_embedding(img_file_path)

    def _get_image_embeddings(
        self, img_file_paths: List[ImageType]
    ) -> List[List[float]]:
        """Generate embeddings for multiple images."""
        images = [Image.open(x) for x in img_file_paths]
        return self._embed(images, "image")

    async def _aget_image_embeddings(
        self, img_file_paths: List[ImageType]
    ) -> List[List[float]]:
        """Generate embeddings for multiple images asynchronously."""
        return self._get_image_embeddings(img_file_paths)
64 changes: 64 additions & 0 deletions in llama-index-integrations/embeddings/llama-index-embeddings-vllm/pyproject.toml
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]

[tool.codespell]
check-filenames = true
check-hidden = true
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"

[tool.llamahub]
contains_example = false
import_path = "llama_index.embeddings.vllm"

[tool.llamahub.class_authors]
VllmEmbedding = "llama-index"

[tool.mypy]
disallow_untyped_defs = true
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"

[tool.poetry]
authors = ["Yuri <yuri@yurinet.blog>"]
description = "llama-index embeddings vllm integration"
exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-vllm"
readme = "README.md"
version = "0.0.1"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
llama-index-core = "^0.12.0"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8"
types-setuptools = "67.1.0.0"
vcrpy = "7.0.0"
vllm = "*"

[tool.poetry.group.dev.dependencies.black]
extras = ["jupyter"]
version = "<=23.9.1,>=23.7.0"

[tool.poetry.group.dev.dependencies.codespell]
extras = ["toml"]
version = ">=v2.2.6"

[[tool.poetry.packages]]
include = "llama_index/"
1 change: 1 addition & 0 deletions in llama-index-integrations/embeddings/llama-index-embeddings-vllm/tests/BUILD
python_tests()
Empty file added: tests/__init__.py.
25 changes: 25 additions & 0 deletions in ...a-index-integrations/embeddings/llama-index-embeddings-vllm/tests/test_embeddings_vllm.py
import pytest
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.embeddings.vllm import VllmEmbedding


def test_vllmembedding_class():
    names_of_base_classes = [b.__name__ for b in VllmEmbedding.__mro__]
    assert BaseEmbedding.__name__ in names_of_base_classes


def test_embedding_retry():
    try:
        embed_model = VllmEmbedding()
    except RuntimeError:
        # Loading a real vLLM model can fail in CPU-only or otherwise
        # constrained environments; skip rather than fail.
        pytest.skip("Skipping test due to environment issue")

    # Test successful embedding.
    result = embed_model._embed(["This is a test sentence"])
    assert isinstance(result, list)
    assert len(result) == 1
    assert isinstance(result[0], list)
    assert all(isinstance(x, float) for x in result[0])
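Since `test_embedding_retry` needs a working vLLM install, the retry policy itself can also be exercised in isolation. A minimal sketch assuming only the `tenacity` dependency, with shorter waits than the class's `min=4, max=10` so it finishes quickly; the `flaky_embed` helper is hypothetical, not part of the test suite:

from tenacity import retry, stop_after_attempt, wait_exponential

attempts = {"count": 0}


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=0, max=1),
    reraise=True,
)
def flaky_embed(texts):
    # Fail twice, then succeed, mirroring the policy used in VllmEmbedding.
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise RuntimeError("transient failure")
    return [[0.0] * 4 for _ in texts]


assert flaky_embed(["hello"]) == [[0.0, 0.0, 0.0, 0.0]]
assert attempts["count"] == 3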