Add vLLM offline inference support for embedding (#17675)
Yuri-0 authored Feb 3, 2025
1 parent 0301a9a commit 2e07dce
Showing 10 changed files with 360 additions and 0 deletions.
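
The integration wraps vLLM's offline embedding entry point (vllm.LLM constructed with task="embed"), which is what the VllmEmbedding class below calls. A minimal sketch of that underlying call, assuming a vLLM-supported embedding model (the model name here is only illustrative):

from vllm import LLM

# Illustrative model; any embedding model supported by vLLM should work.
llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed", max_num_seqs=10)
outputs = llm.embed(["Hello, world!"])
vector = outputs[0].outputs.embedding  # plain list of floats
print(len(vector))
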
@@ -0,0 +1,4 @@
poetry_requirements(
name="poetry",
module_mapping={"vcrpy": ["vcr"]}
)
@@ -0,0 +1,17 @@
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

help: ## Show all Makefile targets.
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format: ## Run code autoformatters (black).
pre-commit install
git ls-files | xargs pre-commit run black --files

lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test: ## Run tests via pytest.
pytest tests

watch-docs: ## Build and watch documentation.
sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
@@ -0,0 +1 @@
# LlamaIndex Embeddings Integration: Vllm
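
A minimal usage sketch, assuming vllm and this package are installed (the model name is only illustrative):

from llama_index.embeddings.vllm import VllmEmbedding

# Illustrative model; the class defaults to "facebook/opt-125m" if omitted.
embed_model = VllmEmbedding(model_name="BAAI/bge-base-en-v1.5")
embedding = embed_model.get_text_embedding("LlamaIndex with vLLM embeddings")
print(len(embedding))
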
@@ -0,0 +1 @@
python_sources()
@@ -0,0 +1,3 @@
from llama_index.embeddings.vllm.base import VllmEmbedding

__all__ = ["VllmEmbedding"]
@@ -0,0 +1,244 @@
from io import BytesIO
import logging
from typing import Any, Dict, List, Optional, Union

from llama_index.core.base.embeddings.base import DEFAULT_EMBED_BATCH_SIZE
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
from llama_index.core.schema import ImageType
from PIL import Image
from tenacity import retry, stop_after_attempt, wait_exponential
import atexit

SUPPORT_EMBED_TYPES = ["image", "text"]
logger = logging.getLogger(__name__)


class VllmEmbedding(MultiModalEmbedding):
"""Vllm LLM.
This class runs a vLLM embedding model locally.
"""

tensor_parallel_size: Optional[int] = Field(
default=1,
description="The number of GPUs to use for distributed execution with tensor parallelism.",
)

trust_remote_code: Optional[bool] = Field(
default=True,
description="Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.",
)

dtype: str = Field(
default="auto",
description="The data type for the model weights and activations.",
)

download_dir: Optional[str] = Field(
default=None,
description="Directory to download and load the weights. (Default to the default cache dir of huggingface)",
)

vllm_kwargs: Dict[str, Any] = Field(
default_factory=dict,
description="Holds any model parameters valid for `vllm.LLM` call not explicitly specified.",
)

_client: Any = PrivateAttr()

_image_token_id: Union[int, None] = PrivateAttr()

def __init__(
self,
model_name: str = "facebook/opt-125m",
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
tensor_parallel_size: int = 1,
trust_remote_code: bool = False,
dtype: str = "auto",
download_dir: Optional[str] = None,
vllm_kwargs: Dict[str, Any] = {},
callback_manager: Optional[CallbackManager] = None,
) -> None:
callback_manager = callback_manager or CallbackManager([])
super().__init__(
model_name=model_name,
embed_batch_size=embed_batch_size,
callback_manager=callback_manager,
)
try:
from vllm import LLM as VLLModel
except ImportError:
raise ImportError(
"Could not import vllm python package. "
"Please install it with `pip install vllm`."
)
self._client = VLLModel(
model=model_name,
task="embed",
max_num_seqs=embed_batch_size,
tensor_parallel_size=tensor_parallel_size,
trust_remote_code=trust_remote_code,
dtype=dtype,
download_dir=download_dir,
**vllm_kwargs,
)
try:
self._image_token_id = (
self._client.llm_engine.model_config.hf_config.image_token_id
)
except AttributeError:
self._image_token_id = None

@classmethod
def class_name(cls) -> str:
return "VllmEmbedding"

    # Registered at class-definition time: releases cached GPU memory when the
    # interpreter exits.
    @atexit.register
    def close():
import torch
import gc

if torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()

@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
reraise=True,
)
def _embed_with_retry(
self, inputs: List[Union[str, BytesIO]], embed_type: str = "text"
) -> List[List[float]]:
"""
Generates embeddings with retry mechanism.
Args:
inputs: List of texts or images to embed
Returns:
List of embedding vectors
Raises:
Exception: If embedding fails after retries
"""
try:
if embed_type == "image":
inputs = [
{
"prompt_token_ids": [self._image_token_id],
"multi_modal_data": {"image": x},
}
for x in inputs
]
emb = self._client.embed(inputs)
return [x.outputs.embedding for x in emb]
except Exception as e:
logger.warning(f"Embedding attempt failed: {e!s}")
raise

def _embed(
self, inputs: List[Union[str, BytesIO]], embed_type: str = "text"
) -> List[List[float]]:
"""
Generates Embeddings with input validation and retry mechanism.
Args:
sentences: Texts or Sentences to embed
prompt_name: The name of the prompt to use for encoding
Returns:
List of embedding vectors
Raises:
ValueError: If any input text is invalid
Exception: If embedding fails after retries
"""
        if embed_type not in SUPPORT_EMBED_TYPES:
            raise ValueError(f"Unsupported embed_type: {embed_type!r}")
return self._embed_with_retry(inputs, embed_type)

def _get_query_embedding(self, query: str) -> List[float]:
"""
Generates Embeddings for Query.
Args:
query (str): Query text/sentence
Returns:
List[float]: numpy array of embeddings
"""
return self._embed([query])[0]

async def _aget_query_embedding(self, query: str) -> List[float]:
"""
Generates Embeddings for Query Asynchronously.
Args:
query (str): Query text/sentence
Returns:
List[float]: numpy array of embeddings
"""
return self._get_query_embedding(query)

async def _aget_text_embedding(self, text: str) -> List[float]:
"""
Generates Embeddings for text Asynchronously.
Args:
text (str): Text/Sentence
Returns:
List[float]: numpy array of embeddings
"""
return self._get_text_embedding(text)

def _get_text_embedding(self, text: str) -> List[float]:
"""
Generates Embeddings for text.
Args:
text (str): Text/sentences
Returns:
List[float]: numpy array of embeddings
"""
return self._embed([text])[0]

def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""
Generates Embeddings for text.
Args:
texts (List[str]): Texts / Sentences
Returns:
List[List[float]]: numpy array of embeddings
"""
return self._embed(texts)

def _get_image_embedding(self, img_file_path: ImageType) -> List[float]:
"""Generate embedding for an image."""
image = Image.open(img_file_path)
return self._embed([image], "image")[0]

async def _aget_image_embedding(self, img_file_path: ImageType) -> List[float]:
"""Generate embedding for an image asynchronously."""
return self._get_image_embedding(img_file_path)

def _get_image_embeddings(
self, img_file_paths: List[ImageType]
) -> List[List[float]]:
        """Generate embeddings for multiple images."""
        images = [Image.open(x) for x in img_file_paths]
return self._embed(images, "image")

async def _aget_image_embeddings(
self, img_file_paths: List[ImageType]
) -> List[List[float]]:
"""Generate embeddings for multiple images asynchronously."""
return self._get_image_embeddings(img_file_paths)
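
Because VllmEmbedding subclasses MultiModalEmbedding, the same object can also embed images through the inherited public helpers. A minimal sketch, assuming a vLLM-supported multimodal embedding model and a local image file (both names are only illustrative):

from llama_index.embeddings.vllm import VllmEmbedding

# Illustrative multimodal model and image path; requires a model whose
# vLLM config exposes an image_token_id.
embed_model = VllmEmbedding(model_name="TIGER-Lab/VLM2Vec-Full", trust_remote_code=True)
text_vec = embed_model.get_text_embedding("a photo of a cat")
image_vec = embed_model.get_image_embedding("cat.jpg")
print(len(text_vec), len(image_vec))
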
@@ -0,0 +1,64 @@
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]

[tool.codespell]
check-filenames = true
check-hidden = true
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"

[tool.llamahub]
contains_example = false
import_path = "llama_index.embeddings.vllm"

[tool.llamahub.class_authors]
VllmEmbedding = "llama-index"

[tool.mypy]
disallow_untyped_defs = true
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"

[tool.poetry]
authors = ["Yuri <yuri@yurinet.blog>"]
description = "llama-index embeddings vllm integration"
exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-vllm"
readme = "README.md"
version = "0.0.1"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
llama-index-core = "^0.12.0"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8"
types-setuptools = "67.1.0.0"
vcrpy = "7.0.0"
vllm = "*"

[tool.poetry.group.dev.dependencies.black]
extras = ["jupyter"]
version = "<=23.9.1,>=23.7.0"

[tool.poetry.group.dev.dependencies.codespell]
extras = ["toml"]
version = ">=v2.2.6"

[[tool.poetry.packages]]
include = "llama_index/"
@@ -0,0 +1 @@
python_tests()
Empty file.
@@ -0,0 +1,25 @@
import pytest
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.embeddings.vllm import VllmEmbedding


def test_vllmembedding_class():
names_of_base_classes = [b.__name__ for b in VllmEmbedding.__mro__]
assert BaseEmbedding.__name__ in names_of_base_classes


def test_embedding_retry():
try:
embed_model = VllmEmbedding()
except RuntimeError:
# will fail in certain environments
# skip test if it fails
pytest.skip("Skipping test due to environment issue")
return

# Test successful embedding
result = embed_model._embed(["This is a test sentence"])
assert isinstance(result, list)
assert len(result) == 1
assert isinstance(result[0], list)
assert all(isinstance(x, float) for x in result[0])
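
A hedged sketch (not part of this commit) of how the retry path in _embed_with_retry could be exercised with pytest-mock; it reuses the imports from the test module above and assumes VllmEmbedding() can be constructed in the test environment and that a few seconds of exponential backoff are acceptable:

def test_embed_retry_recovers(mocker):
    try:
        embed_model = VllmEmbedding()
    except RuntimeError:
        pytest.skip("Skipping test due to environment issue")

    real_embed = embed_model._client.embed
    calls = {"n": 0}

    def flaky(inputs):
        # Fail twice, then fall through to the real vLLM embed call.
        calls["n"] += 1
        if calls["n"] < 3:
            raise RuntimeError("transient failure")
        return real_embed(inputs)

    mocker.patch.object(embed_model._client, "embed", side_effect=flaky)
    result = embed_model._embed(["retry me"])
    assert calls["n"] == 3
    assert isinstance(result[0], list)
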
