From 2e07dce15439c10915721d38db282b2784087723 Mon Sep 17 00:00:00 2001
From: Yuri-0
Date: Tue, 4 Feb 2025 07:31:33 +0900
Subject: [PATCH] Add vLLM offline inference support for embedding (#17675)

---
 .../llama-index-embeddings-vllm/BUILD         |   4 +
 .../llama-index-embeddings-vllm/Makefile      |  17 ++
 .../llama-index-embeddings-vllm/README.md     |   1 +
 .../llama_index/embeddings/vllm/BUILD         |   1 +
 .../llama_index/embeddings/vllm/__init__.py   |   3 +
 .../llama_index/embeddings/vllm/base.py       | 244 ++++++++++++++++++
 .../pyproject.toml                            |  64 +++++
 .../llama-index-embeddings-vllm/tests/BUILD   |   1 +
 .../tests/__init__.py                         |   0
 .../tests/test_embeddings_vllm.py             |  25 ++
 10 files changed, 360 insertions(+)
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-vllm/BUILD
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-vllm/Makefile
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-vllm/README.md
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/BUILD
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/__init__.py
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/base.py
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-vllm/pyproject.toml
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-vllm/tests/BUILD
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-vllm/tests/__init__.py
 create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-vllm/tests/test_embeddings_vllm.py

diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vllm/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/BUILD
new file mode 100644
index 0000000000000..91ba5edacdf5f
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/BUILD
@@ -0,0 +1,4 @@
+poetry_requirements(
+    name="poetry",
+    module_mapping={"vcrpy": ["vcr"]}
+)
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vllm/Makefile b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/Makefile
new file mode 100644
index 0000000000000..b9eab05aa3706
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/Makefile
@@ -0,0 +1,17 @@
+GIT_ROOT ?= $(shell git rev-parse --show-toplevel)
+
+help:	## Show all Makefile targets.
+	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
+
+format:	## Run code autoformatters (black).
+	pre-commit install
+	git ls-files | xargs pre-commit run black --files
+
+lint:	## Run linters: pre-commit (black, ruff, codespell) and mypy.
+	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
+
+test:	## Run tests via pytest.
+	pytest tests
+
+watch-docs:	## Build and watch documentation.
+	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vllm/README.md b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/README.md
new file mode 100644
index 0000000000000..ebcd1b066ebc4
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/README.md
@@ -0,0 +1 @@
+# LlamaIndex Embeddings Integration: Vllm
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/BUILD
new file mode 100644
index 0000000000000..db46e8d6c978c
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/BUILD
@@ -0,0 +1 @@
+python_sources()
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/__init__.py b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/__init__.py
new file mode 100644
index 0000000000000..aac655d573296
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/__init__.py
@@ -0,0 +1,3 @@
+from llama_index.embeddings.vllm.base import VllmEmbedding
+
+__all__ = ["VllmEmbedding"]
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/base.py
new file mode 100644
index 0000000000000..00e290c9a9d92
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/llama_index/embeddings/vllm/base.py
@@ -0,0 +1,244 @@
+from io import BytesIO
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from llama_index.core.base.embeddings.base import DEFAULT_EMBED_BATCH_SIZE
+from llama_index.core.bridge.pydantic import Field, PrivateAttr
+from llama_index.core.callbacks import CallbackManager
+from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
+from llama_index.core.schema import ImageType
+from PIL import Image
+from tenacity import retry, stop_after_attempt, wait_exponential
+import atexit
+
+SUPPORT_EMBED_TYPES = ["image", "text"]
+logger = logging.getLogger(__name__)
+
+
+class VllmEmbedding(MultiModalEmbedding):
+    """vLLM embedding.
+
+    This class runs a vLLM embedding model locally (offline inference; no server required).
+    """
+
+    tensor_parallel_size: Optional[int] = Field(
+        default=1,
+        description="The number of GPUs to use for distributed execution with tensor parallelism.",
+    )
+
+    trust_remote_code: Optional[bool] = Field(
+        default=False,
+        description="Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.",
+    )
+
+    dtype: str = Field(
+        default="auto",
+        description="The data type for the model weights and activations.",
+    )
+
+    download_dir: Optional[str] = Field(
+        default=None,
+        description="Directory to download and load the weights. (Defaults to the Hugging Face cache directory.)",
+    )
+
+    vllm_kwargs: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Holds any model parameters valid for the `vllm.LLM` call that are not explicitly specified.",
+    )
+
+    _client: Any = PrivateAttr()
+
+    _image_token_id: Union[int, None] = PrivateAttr()
+
+    def __init__(
+        self,
+        model_name: str = "facebook/opt-125m",
+        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
+        tensor_parallel_size: int = 1,
+        trust_remote_code: bool = False,
+        dtype: str = "auto",
+        download_dir: Optional[str] = None,
+        vllm_kwargs: Optional[Dict[str, Any]] = None,
+        callback_manager: Optional[CallbackManager] = None,
+    ) -> None:
+        vllm_kwargs = vllm_kwargs or {}
+        callback_manager = callback_manager or CallbackManager([])
+        super().__init__(
+            model_name=model_name,
+            embed_batch_size=embed_batch_size,
+            callback_manager=callback_manager,
+        )
+        try:
+            from vllm import LLM as VLLModel
+        except ImportError:
+            raise ImportError(
+                "Could not import vllm python package. "
+                "Please install it with `pip install vllm`."
+            )
+        self._client = VLLModel(
+            model=model_name,
+            task="embed",
+            max_num_seqs=embed_batch_size,
+            tensor_parallel_size=tensor_parallel_size,
+            trust_remote_code=trust_remote_code,
+            dtype=dtype,
+            download_dir=download_dir,
+            **vllm_kwargs,
+        )
+        # Multimodal checkpoints expose an image token id in their HF config;
+        # text-only models do not, in which case image embedding is unavailable.
+        try:
+            self._image_token_id = (
+                self._client.llm_engine.model_config.hf_config.image_token_id
+            )
+        except AttributeError:
+            self._image_token_id = None
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "VllmEmbedding"
+
+    # Registered once at class-definition time as a process-exit hook (not an
+    # instance method): releases CUDA memory held by the vLLM engine on exit.
+    @atexit.register
+    def close():
+        import torch
+        import gc
+
+        if torch.cuda.is_available():
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        reraise=True,
+    )
+    def _embed_with_retry(
+        self, inputs: List[Union[str, BytesIO]], embed_type: str = "text"
+    ) -> List[List[float]]:
+        """
+        Generates embeddings with retry mechanism.
+
+        Args:
+            inputs: List of texts or images to embed
+            embed_type: Either "text" or "image"
+
+        Returns:
+            List of embedding vectors
+
+        Raises:
+            Exception: If embedding fails after retries
+        """
+        try:
+            if embed_type == "image":
+                # Wrap each image in vLLM's token-prompt format, using the
+                # model's image placeholder token.
+                inputs = [
+                    {
+                        "prompt_token_ids": [self._image_token_id],
+                        "multi_modal_data": {"image": x},
+                    }
+                    for x in inputs
+                ]
+            emb = self._client.embed(inputs)
+            return [x.outputs.embedding for x in emb]
+        except Exception as e:
+            logger.warning(f"Embedding attempt failed: {e!s}")
+            raise
+
+    def _embed(
+        self, inputs: List[Union[str, BytesIO]], embed_type: str = "text"
+    ) -> List[List[float]]:
+        """
+        Generates embeddings with input validation and retry mechanism.
+
+        Args:
+            inputs: Texts or images to embed
+            embed_type: Either "text" or "image"
+
+        Returns:
+            List of embedding vectors
+
+        Raises:
+            ValueError: If the embed type is not supported
+            Exception: If embedding fails after retries
+        """
+        if embed_type not in SUPPORT_EMBED_TYPES:
+            raise ValueError(f"Unsupported embed type: {embed_type}")
+        return self._embed_with_retry(inputs, embed_type)
+
+    def _get_query_embedding(self, query: str) -> List[float]:
+        """
+        Generates embeddings for a query.
+
+        Args:
+            query (str): Query text/sentence
+
+        Returns:
+            List[float]: Embedding vector
+        """
+        return self._embed([query])[0]
+
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        """
+        Generates embeddings for a query asynchronously.
+
+        Args:
+            query (str): Query text/sentence
+
+        Returns:
+            List[float]: Embedding vector
+        """
+        return self._get_query_embedding(query)
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        """
+        Generates embeddings for text asynchronously.
+
+        Args:
+            text (str): Text/sentence
+
+        Returns:
+            List[float]: Embedding vector
+        """
+        return self._get_text_embedding(text)
+
+    def _get_text_embedding(self, text: str) -> List[float]:
+        """
+        Generates embeddings for text.
+
+        Args:
+            text (str): Text/sentence
+
+        Returns:
+            List[float]: Embedding vector
+        """
+        return self._embed([text])[0]
+
+    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """
+        Generates embeddings for multiple texts.
+
+        Args:
+            texts (List[str]): Texts/sentences
+
+        Returns:
+            List[List[float]]: List of embedding vectors
+        """
+        return self._embed(texts)
+
+    def _get_image_embedding(self, img_file_path: ImageType) -> List[float]:
+        """Generate embedding for an image."""
+        image = Image.open(img_file_path)
+        return self._embed([image], "image")[0]
+
+    async def _aget_image_embedding(self, img_file_path: ImageType) -> List[float]:
+        """Generate embedding for an image asynchronously."""
+        return self._get_image_embedding(img_file_path)
+
+    def _get_image_embeddings(
+        self, img_file_paths: List[ImageType]
+    ) -> List[List[float]]:
+        """Generate embeddings for multiple images."""
+        images = [Image.open(x) for x in img_file_paths]
+        return self._embed(images, "image")
+
+    async def _aget_image_embeddings(
+        self, img_file_paths: List[ImageType]
+    ) -> List[List[float]]:
+        """Generate embeddings for multiple images asynchronously."""
+        return self._get_image_embeddings(img_file_paths)
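Since the README added above is only a title, a usage sketch may help reviewers see how the class in base.py is meant to be driven. This example is not part of the patch; the model name, batch size, and input strings are illustrative assumptions rather than values shipped by this PR:

    from llama_index.embeddings.vllm import VllmEmbedding

    # Builds vllm.LLM(model=..., task="embed") in-process; no API server is involved.
    embed_model = VllmEmbedding(
        model_name="BAAI/bge-base-en-v1.5",  # illustrative embedding checkpoint
        embed_batch_size=32,
    )

    # Public wrapper around the _get_query_embedding hook defined above.
    query_vector = embed_model.get_query_embedding("What is offline embedding inference?")
    print(len(query_vector))  # dimensionality depends on the chosen model

    # Batched texts are routed through _get_text_embeddings.
    text_vectors = embed_model.get_text_embedding_batch(["first passage", "second passage"])

Note that construction is heavyweight: the vLLM engine loads the model weights immediately, so an instance should be created once and reused.
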
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vllm/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/pyproject.toml
new file mode 100644
index 0000000000000..0b24bc34b7074
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/pyproject.toml
@@ -0,0 +1,64 @@
+[build-system]
+build-backend = "poetry.core.masonry.api"
+requires = ["poetry-core"]
+
+[tool.codespell]
+check-filenames = true
+check-hidden = true
+skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
+
+[tool.llamahub]
+contains_example = false
+import_path = "llama_index.embeddings.vllm"
+
+[tool.llamahub.class_authors]
+VllmEmbedding = "llama-index"
+
+[tool.mypy]
+disallow_untyped_defs = true
+exclude = ["_static", "build", "examples", "notebooks", "venv"]
+ignore_missing_imports = true
+python_version = "3.8"
+
+[tool.poetry]
+authors = ["Yuri "]
+description = "llama-index embeddings vllm integration"
+exclude = ["**/BUILD"]
+license = "MIT"
+name = "llama-index-embeddings-vllm"
+readme = "README.md"
+version = "0.0.1"
+
+[tool.poetry.dependencies]
+python = ">=3.9,<4.0"
+llama-index-core = "^0.12.0"
+
+[tool.poetry.group.dev.dependencies]
+ipython = "8.10.0"
+jupyter = "^1.0.0"
+mypy = "0.991"
+pre-commit = "3.2.0"
+pylint = "2.15.10"
+pytest = "7.2.1"
+pytest-mock = "3.11.1"
+ruff = "0.0.292"
+tree-sitter-languages = "^1.8.0"
+types-Deprecated = ">=0.1.0"
+types-PyYAML = "^6.0.12.12"
+types-protobuf = "^4.24.0.4"
+types-redis = "4.5.5.0"
+types-requests = "2.28.11.8"
+types-setuptools = "67.1.0.0"
+vcrpy = "7.0.0"
+vllm = "*"
+
+[tool.poetry.group.dev.dependencies.black]
+extras = ["jupyter"]
+version = "<=23.9.1,>=23.7.0"
+
+[tool.poetry.group.dev.dependencies.codespell]
+extras = ["toml"]
+version = ">=v2.2.6"
+
+[[tool.poetry.packages]]
+include = "llama_index/"
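base.py also wires image inputs through the _image_token_id discovered at construction time. A sketch of that path, under the assumption of a vLLM-supported multimodal embedding checkpoint (the model name and file path below are illustrative; with a text-only model _image_token_id stays None and these calls would fail):

    from llama_index.embeddings.vllm import VllmEmbedding

    # Assumption: a multimodal checkpoint whose HF config exposes image_token_id.
    embed_model = VllmEmbedding(model_name="TIGER-Lab/VLM2Vec-Full")  # illustrative

    # MultiModalEmbedding entry point; base.py opens the file with PIL internally.
    image_vector = embed_model.get_image_embedding("figures/example.png")
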
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vllm/tests/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/tests/BUILD
new file mode 100644
index 0000000000000..dabf212d7e716
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/tests/BUILD
@@ -0,0 +1 @@
+python_tests()
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vllm/tests/__init__.py b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/tests/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vllm/tests/test_embeddings_vllm.py b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/tests/test_embeddings_vllm.py
new file mode 100644
index 0000000000000..a2cfc6f0b69ba
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-vllm/tests/test_embeddings_vllm.py
@@ -0,0 +1,25 @@
+import pytest
+from llama_index.core.base.embeddings.base import BaseEmbedding
+from llama_index.embeddings.vllm import VllmEmbedding
+
+
+def test_vllmembedding_class():
+    names_of_base_classes = [b.__name__ for b in VllmEmbedding.__mro__]
+    assert BaseEmbedding.__name__ in names_of_base_classes
+
+
+def test_embedding_retry():
+    try:
+        embed_model = VllmEmbedding()
+    except RuntimeError:
+        # will fail in certain environments
+        # skip test if it fails
+        pytest.skip("Skipping test due to environment issue")
+        return
+
+    # Test successful embedding
+    result = embed_model._embed(["This is a test sentence"])
+    assert isinstance(result, list)
+    assert len(result) == 1
+    assert isinstance(result[0], list)
+    assert all(isinstance(x, float) for x in result[0])
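Finally, a sketch of how the new integration would typically be consumed once released, using the public llama-index-core APIs that the pyproject above depends on (the model name and document text are illustrative):

    from llama_index.core import Document, Settings, VectorStoreIndex
    from llama_index.embeddings.vllm import VllmEmbedding

    # Route all embedding calls in this process through the local vLLM engine.
    Settings.embed_model = VllmEmbedding(model_name="BAAI/bge-base-en-v1.5")  # illustrative

    index = VectorStoreIndex.from_documents(
        [Document(text="vLLM supports offline embedding inference.")]
    )
    results = index.as_retriever(similarity_top_k=1).retrieve("What does vLLM support?")
    print(results[0].node.get_content())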